From 332d7fb4e2d5f8b7ea61ef1ec53225d6105b9b2f Mon Sep 17 00:00:00 2001 From: Adam Miller Date: Wed, 8 Apr 2026 16:36:25 -0500 Subject: [PATCH] feat(providers): add OAuth2 credential lifecycle support via credential poll loop Add OAuth2 token exchange, caching, and refresh to the gateway proxy. The gateway server performs all OAuth2 operations (token exchange, refresh, rotation persistence) via a new TokenVendingService. The sandbox supervisor polls for fresh access tokens on a server-dictated interval, atomically updating the SecretResolver. Core design properties: - Real OAuth2 secrets (client_id, client_secret, refresh_token) never leave the gateway process - Short-lived access tokens follow the existing credential isolation path (placeholder in sandbox, resolved at proxy egress boundary) - Zero overhead for existing static-credential sandboxes (no poll loop spawned when refresh_after_secs=0) - Backward-compatible proto change (new field on existing response) Changes: - proto: add refresh_after_secs field to GetSandboxProviderEnvironmentResponse - openshell-server: new token_vending module with OAuth2 token exchange, per-provider caching with dedup, refresh token rotation persistence - openshell-server: extend resolve_provider_environment() to handle OAuth2 providers, filter internal credentials, compute refresh interval - openshell-server: add OAuth2 config validation at provider creation - openshell-sandbox: SecretResolver uses RwLock for atomic credential updates, add replace_secrets() method - openshell-sandbox: new run_credential_poll_loop() modeled on existing policy poll loop, with adaptive retry on failure - openshell-sandbox: grpc_client returns ProviderEnvironmentResult with refresh_after_secs - openshell-cli: provider get displays Auth method (Static vs OAuth2) - architecture: document OAuth2 lifecycle in sandbox-providers.md --- Cargo.lock | 15 + Cargo.toml | 3 + architecture/sandbox-providers.md | 148 +++ crates/openshell-cli/src/run.rs | 16 + crates/openshell-sandbox/src/grpc_client.rs | 44 +- crates/openshell-sandbox/src/lib.rs | 127 ++- crates/openshell-sandbox/src/secrets.rs | 216 ++++- crates/openshell-server/Cargo.toml | 1 + crates/openshell-server/src/grpc.rs | 185 +++- crates/openshell-server/src/lib.rs | 6 + crates/openshell-server/src/token_vending.rs | 939 +++++++++++++++++++ proto/openshell.proto | 5 + 12 files changed, 1630 insertions(+), 75 deletions(-) create mode 100644 crates/openshell-server/src/token_vending.rs diff --git a/Cargo.lock b/Cargo.lock index fd3b68d1e..c76faaa18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1030,6 +1030,20 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -3003,6 +3017,7 @@ dependencies = [ "axum 0.8.8", "bytes", "clap", + "dashmap", "futures", "futures-util", "hex", diff --git a/Cargo.toml b/Cargo.toml index 3380e040b..6bf359b2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] } # Clipboard (OSC 52) base64 = "0.22" +# Concurrent data structures +dashmap = "6" + # Utilities futures = "0.3" bytes = "1" diff --git a/architecture/sandbox-providers.md b/architecture/sandbox-providers.md index fe5d48a97..fed97da27 100644 --- a/architecture/sandbox-providers.md +++ b/architecture/sandbox-providers.md @@ -374,6 +374,145 @@ The gateway enforces: Providers are stored with `object_type = "provider"` in the shared object store. +## OAuth2 Credential Lifecycle + +### Overview + +Providers can use OAuth2 for credential lifecycle management instead of static tokens. +The gateway server performs all OAuth2 operations (token exchange, refresh, rotation +persistence). The sandbox supervisor polls for fresh access tokens on a server-dictated +interval, atomically updating the `SecretResolver`. Sandboxes with only static credentials +incur zero overhead — no poll loop is spawned. + +The core invariant is preserved: real credentials (access tokens, refresh tokens, client +secrets) never enter the sandbox runtime. Child processes see only stable placeholder +strings. + +### Configuration + +OAuth2 is an auth method, not a provider type. Any provider type (`github`, `gitlab`, +`generic`, etc.) can use OAuth2 by setting config keys: + +| Config Key | Required | Example | Purpose | +|---|---|---|---| +| `auth_method` | Yes | `oauth2` | Discriminator (absence means static) | +| `oauth_token_endpoint` | Yes | `https://github.com/login/oauth/access_token` | Token exchange URL (HTTPS only) | +| `oauth_grant_type` | Yes | `refresh_token` or `client_credentials` | OAuth2 flow type | +| `oauth_scopes` | No | `api read_user` | Space-separated scopes | +| `oauth_access_token_env` | No | `MY_TOKEN` | Override output env var name | + +OAuth2 secret material is stored in `Provider.credentials`: + +| Credential Key | Required For | Purpose | +|---|---|---| +| `OAUTH_CLIENT_ID` | All OAuth2 | Client identifier | +| `OAUTH_CLIENT_SECRET` | All OAuth2 | Client secret | +| `OAUTH_REFRESH_TOKEN` | `refresh_token` grant | Refresh token (may be rotated) | + +### CLI Usage + +```bash +# Refresh token flow: +openshell provider create \ + --name github-oauth --type github \ + --credential OAUTH_CLIENT_ID=Iv1.abc123 \ + --credential OAUTH_CLIENT_SECRET=secret456 \ + --credential OAUTH_REFRESH_TOKEN=ghr_xyz789 \ + --config auth_method=oauth2 \ + --config oauth_grant_type=refresh_token \ + --config oauth_token_endpoint=https://github.com/login/oauth/access_token + +# Client credentials flow: +openshell provider create \ + --name service-account --type generic \ + --credential OAUTH_CLIENT_ID=client-id \ + --credential OAUTH_CLIENT_SECRET=client-secret \ + --config auth_method=oauth2 \ + --config oauth_grant_type=client_credentials \ + --config oauth_token_endpoint=https://auth.example.com/oauth2/token +``` + +### Gateway-Side Token Vending + +The `TokenVendingService` (`crates/openshell-server/src/token_vending.rs`) handles all +OAuth2 token exchange, caching, and refresh: + +- **Per-provider caching**: access tokens are cached in memory with their TTL. +- **Lazy refresh**: tokens are refreshed when a sandbox calls + `GetSandboxProviderEnvironment` and the cached token is within its safety margin + (`max(60s, ttl * 0.1)` before expiry). +- **Concurrent deduplication**: a per-provider `tokio::sync::Mutex` ensures only one + HTTP request to the IdP runs at a time; concurrent callers await the result. +- **Refresh token rotation**: if the IdP returns a new refresh token, the gateway + persists it to the store via `UpdateProvider`. + +The `resolve_provider_environment()` function detects OAuth2 providers via +`config["auth_method"] == "oauth2"`, calls the token vending service, and returns +the access token as a credential entry (e.g., `GITHUB_ACCESS_TOKEN`). OAuth2 internal +credentials (`OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, `OAUTH_REFRESH_TOKEN`) are +filtered out and never injected into the sandbox environment. + +### Response-Driven Polling + +`GetSandboxProviderEnvironmentResponse` includes a `refresh_after_secs` field: + +- **0**: all credentials are static — supervisor does not spawn a poll loop. +- **>0**: computed as `min(token_ttl / 2)` across all OAuth2 providers. The supervisor + spawns a background credential poll loop. + +### Supervisor Credential Poll Loop + +When `refresh_after_secs > 0`, the sandbox supervisor spawns +`run_credential_poll_loop()` (modeled on `run_policy_poll_loop()`): + +1. Sleeps for the server-dictated interval. +2. Calls `GetSandboxProviderEnvironment` via `CachedOpenShellClient`. +3. Atomically replaces all `SecretResolver` mappings via `replace_secrets()`. +4. On failure, tightens the retry interval to 30 seconds. +5. On recovery, restores the server-dictated interval. +6. If the server returns `refresh_after_secs == 0`, exits cleanly. + +The `SecretResolver` uses `std::sync::RwLock` to allow atomic value replacement +without blocking concurrent reads (credential resolution during request forwarding). + +### Credential Isolation + +| Secret | Gateway | Supervisor | Child Env | Proxy Wire | +|---|---|---|---|---| +| `OAUTH_CLIENT_ID` | ✅ Store + cache | ❌ | ❌ | ❌ | +| `OAUTH_CLIENT_SECRET` | ✅ Store + cache | ❌ | ❌ | ❌ | +| `OAUTH_REFRESH_TOKEN` | ✅ Store + cache | ❌ | ❌ | ❌ | +| Access token (ephemeral) | ✅ Cache | ✅ SecretResolver | ❌ Placeholder | ✅ Egress | + +### End-to-End Flow (OAuth2) + +``` +CLI: openshell provider create --type github --config auth_method=oauth2 ... + | + +-- Gateway validates OAuth2 config (HTTPS endpoint, required credentials) + +-- Persists Provider with credentials + config + | +CLI: openshell sandbox create --provider github-oauth -- claude + | + +-- Gateway: create_sandbox() validates provider exists + | + Sandbox supervisor: run_sandbox() + +-- GetSandboxProviderEnvironment + | +-- Gateway: resolve_provider_environment() + | | +-- Detects auth_method=oauth2 + | | +-- TokenVendingService::get_or_refresh() + | | | +-- POST to oauth_token_endpoint (lazy refresh) + | | | +-- Caches access_token with TTL + | | +-- Returns {GITHUB_ACCESS_TOKEN: "gho-abc123...", refresh_after_secs: 1800} + | | +-- Filters out OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REFRESH_TOKEN + +-- SecretResolver::from_provider_env() + | +-- child env: {GITHUB_ACCESS_TOKEN: "openshell:resolve:env:GITHUB_ACCESS_TOKEN"} + | +-- resolver: {"openshell:resolve:env:GITHUB_ACCESS_TOKEN": "gho-abc123..."} + +-- Spawns credential poll loop (refresh_after_secs=1800 → poll every 30min) + | +-- Every 30min: GetSandboxProviderEnvironment → replace_secrets() + +-- Proxy rewrites outbound headers with current access token +``` + ## Security Notes - Provider credentials are stored in `credentials` map and treated as sensitive. @@ -385,6 +524,11 @@ Providers are stored with `object_type = "provider"` in the shared object store. placeholders, and the supervisor resolves those placeholders during outbound proxying. - `OPENSHELL_SSH_HANDSHAKE_SECRET` is required by the supervisor/SSH server path but is explicitly kept out of spawned sandbox child-process environments. +- OAuth2 long-lived secrets (client ID, client secret, refresh token) never leave the + gateway process. Only short-lived access tokens are sent to the supervisor. +- OAuth2 token endpoints must use HTTPS (enforced at provider creation). +- Token endpoint responses are validated: access token values pass through + `validate_resolved_secret()` to reject header-injection characters. ## Test Strategy @@ -396,3 +540,7 @@ Providers are stored with `object_type = "provider"` in the shared object store. - sandbox unit tests validate placeholder generation and header rewriting. - E2E sandbox tests verify placeholders are visible in child env, outbound proxy traffic is rewritten with the real secret, and the SSH handshake secret is absent from exec env. +- OAuth2 token vending unit tests in `crates/openshell-server/src/token_vending.rs`: + mock HTTP server tests for token exchange, caching, rotation, and error handling. +- `SecretResolver` concurrency tests validate `replace_secrets()` under concurrent + read/write access. diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 351a9346e..e55096908 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -3450,11 +3450,27 @@ pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result< let credential_keys = provider.credentials.keys().cloned().collect::>(); let config_keys = provider.config.keys().cloned().collect::>(); + // Derive auth method display. + let auth_display = if provider + .config + .get("auth_method") + .is_some_and(|v| v == "oauth2") + { + let grant = provider + .config + .get("oauth_grant_type") + .map_or("unknown", |v| v.as_str()); + format!("OAuth2 ({grant})") + } else { + "Static".to_string() + }; + println!("{}", "Provider:".cyan().bold()); println!(); println!(" {} {}", "Id:".dimmed(), provider.id); println!(" {} {}", "Name:".dimmed(), provider.name); println!(" {} {}", "Type:".dimmed(), provider.r#type); + println!(" {} {}", "Auth:".dimmed(), auth_display); println!( " {} {}", "Credential keys:".dimmed(), diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 5503637ee..646c98034 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -181,15 +181,22 @@ pub async fn sync_policy(endpoint: &str, sandbox: &str, policy: &ProtoSandboxPol sync_policy_with_client(&mut client, sandbox, policy).await } +/// Result of fetching provider environment. +pub struct ProviderEnvironmentResult { + /// Credential environment variables (key → secret value). + pub environment: HashMap, + /// Seconds until next refresh. 0 = static credentials only. + pub refresh_after_secs: u32, +} + /// Fetch provider environment variables for a sandbox from OpenShell server via gRPC. /// -/// Returns a map of environment variable names to values derived from provider -/// credentials configured on the sandbox. Returns an empty map if the sandbox -/// has no providers or the call fails. +/// Returns credential env vars and a refresh interval. A `refresh_after_secs` +/// of 0 means all credentials are static and no polling is needed. pub async fn fetch_provider_environment( endpoint: &str, sandbox_id: &str, -) -> Result> { +) -> Result { debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Fetching provider environment"); let mut client = connect(endpoint).await?; @@ -201,7 +208,11 @@ pub async fn fetch_provider_environment( .await .into_diagnostic()?; - Ok(response.into_inner().environment) + let inner = response.into_inner(); + Ok(ProviderEnvironmentResult { + environment: inner.environment, + refresh_after_secs: inner.refresh_after_secs, + }) } /// A reusable gRPC client for the OpenShell service. @@ -221,7 +232,7 @@ pub struct SettingsPollResult { pub config_revision: u64, pub policy_source: PolicySource, /// Effective settings keyed by name. - pub settings: std::collections::HashMap, + pub settings: HashMap, /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, } @@ -264,6 +275,27 @@ impl CachedOpenShellClient { }) } + /// Fetch provider environment for credential refresh polling. + pub async fn fetch_provider_environment( + &self, + sandbox_id: &str, + ) -> Result { + let response = self + .client + .clone() + .get_sandbox_provider_environment(GetSandboxProviderEnvironmentRequest { + sandbox_id: sandbox_id.to_string(), + }) + .await + .into_diagnostic()?; + + let inner = response.into_inner(); + Ok(ProviderEnvironmentResult { + environment: inner.environment, + refresh_after_secs: inner.refresh_after_secs, + }) + } + /// Submit denial summaries for policy analysis. pub async fn submit_policy_analysis( &self, diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index b160cdefc..c58c402b3 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -187,22 +187,27 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. // This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { - match grpc_client::fetch_provider_environment(endpoint, id).await { - Ok(env) => { - info!(env_count = env.len(), "Fetched provider environment"); - env - } - Err(e) => { - warn!(error = %e, "Failed to fetch provider environment, continuing without"); - std::collections::HashMap::new() + let (provider_env_map, refresh_after_secs) = + if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { + match grpc_client::fetch_provider_environment(endpoint, id).await { + Ok(result) => { + info!( + env_count = result.environment.len(), + refresh_after_secs = result.refresh_after_secs, + "Fetched provider environment" + ); + (result.environment, result.refresh_after_secs) + } + Err(e) => { + warn!(error = %e, "Failed to fetch provider environment, continuing without"); + (std::collections::HashMap::new(), 0) + } } - } - } else { - std::collections::HashMap::new() - }; + } else { + (std::collections::HashMap::new(), 0) + }; - let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); + let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env_map); let secret_resolver = secret_resolver.map(Arc::new); // Create identity cache for SHA256 TOFU when OPA is active @@ -589,6 +594,30 @@ pub async fn run_sandbox( } }); + // Spawn credential poll loop (only if credentials are time-bounded). + if refresh_after_secs > 0 { + if let Some(ref resolver) = secret_resolver { + let cred_id = id.clone(); + let cred_endpoint = endpoint.clone(); + let cred_resolver = resolver.clone(); + + tokio::spawn(async move { + if let Err(e) = run_credential_poll_loop( + &cred_endpoint, + &cred_id, + &cred_resolver, + refresh_after_secs, + ) + .await + { + warn!(error = %e, "Credential poll loop exited with error"); + } + }); + } + } else { + debug!("All provider credentials are static, skipping credential poll loop"); + } + // Spawn denial aggregator (gRPC mode only, when proxy is active). if let Some(rx) = denial_rx { // SubmitPolicyAnalysis resolves by sandbox *name*, not UUID. @@ -1637,6 +1666,76 @@ async fn run_policy_poll_loop( } /// Log individual setting changes between two snapshots. +/// Minimum poll interval for credential refresh (floor). +const CREDENTIAL_MIN_POLL_INTERVAL_SECS: u64 = 30; +/// Retry interval when credential refresh fails. +const CREDENTIAL_FAILURE_RETRY_SECS: u64 = 30; + +/// Background loop that periodically re-fetches provider credentials from the +/// gateway. Modeled on [`run_policy_poll_loop`] but focused on credential +/// lifecycle. +/// +/// The loop exits cleanly when the gateway returns `refresh_after_secs == 0`, +/// indicating all credentials have become static (e.g. provider reconfigured). +async fn run_credential_poll_loop( + endpoint: &str, + sandbox_id: &str, + secret_resolver: &Arc, + initial_refresh_secs: u32, +) -> Result<()> { + use crate::grpc_client::CachedOpenShellClient; + + let client = CachedOpenShellClient::connect(endpoint).await?; + let mut current_interval = + Duration::from_secs((initial_refresh_secs as u64).max(CREDENTIAL_MIN_POLL_INTERVAL_SECS)); + let mut consecutive_failures: u32 = 0; + + loop { + tokio::time::sleep(current_interval).await; + + match client.fetch_provider_environment(sandbox_id).await { + Ok(result) => { + if consecutive_failures > 0 { + info!( + consecutive_failures, + "Credential refresh recovered after failures" + ); + } + consecutive_failures = 0; + + // Atomically replace the secret values. + secret_resolver.replace_secrets(result.environment); + + // Use server-dictated interval for next poll. + let server_interval = + u64::from(result.refresh_after_secs).max(CREDENTIAL_MIN_POLL_INTERVAL_SECS); + current_interval = Duration::from_secs(server_interval); + + debug!( + refresh_after_secs = result.refresh_after_secs, + "Credential refresh completed" + ); + + // If server says 0, credentials became static. Exit loop. + if result.refresh_after_secs == 0 { + info!("All credentials now static, stopping credential poll loop"); + return Ok(()); + } + } + Err(e) => { + consecutive_failures += 1; + current_interval = Duration::from_secs(CREDENTIAL_FAILURE_RETRY_SECS); + warn!( + error = %e, + consecutive_failures, + retry_in_secs = CREDENTIAL_FAILURE_RETRY_SECS, + "Credential refresh failed, retrying at reduced interval" + ); + } + } + } +} + fn log_setting_changes( old: &std::collections::HashMap, new: &std::collections::HashMap, diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index a27537c91..b87153a41 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -4,6 +4,7 @@ use base64::Engine as _; use std::collections::HashMap; use std::fmt; +use std::sync::RwLock; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; @@ -61,9 +62,23 @@ pub(crate) struct RewriteTargetResult { // SecretResolver // --------------------------------------------------------------------------- -#[derive(Debug, Clone, Default)] +/// Resolves `openshell:resolve:env:*` placeholder strings to real secret values. +/// +/// This type intentionally does **not** implement `Clone` because the inner +/// `RwLock` requires atomic access for credential refresh. All call +/// sites wrap the resolver in `Arc` for shared ownership. +#[derive(Default)] pub struct SecretResolver { - by_placeholder: HashMap, + by_placeholder: RwLock>, +} + +impl fmt::Debug for SecretResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let count = self.by_placeholder.read().map(|m| m.len()).unwrap_or(0); + f.debug_struct("SecretResolver") + .field("placeholder_count", &count) + .finish() + } } impl SecretResolver { @@ -83,17 +98,41 @@ impl SecretResolver { by_placeholder.insert(placeholder, value); } - (child_env, Some(Self { by_placeholder })) + ( + child_env, + Some(Self { + by_placeholder: RwLock::new(by_placeholder), + }), + ) + } + + /// Atomically replace all secret mappings with new values. + /// + /// Used by the credential poll loop to swap in fresh OAuth2 access tokens. + /// Takes env-key-keyed input (e.g. `{"GITHUB_TOKEN": "new-val"}`) and + /// converts to placeholder-keyed internally. + pub(crate) fn replace_secrets(&self, new_provider_env: HashMap) { + let mut new_map = HashMap::with_capacity(new_provider_env.len()); + for (key, value) in new_provider_env { + let placeholder = placeholder_for_env_key(&key); + new_map.insert(placeholder, value); + } + let mut guard = self.by_placeholder.write().expect("secret resolver lock"); + *guard = new_map; } /// Resolve a placeholder string to the real secret value. /// /// Returns `None` if the placeholder is unknown or the resolved value /// contains prohibited control characters (CRLF, null byte). - pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - let secret = self.by_placeholder.get(value).map(String::as_str)?; + /// + /// Returns an owned `String` because the value is cloned under a read + /// lock — callers must not hold a borrow across the lock boundary. + pub(crate) fn resolve_placeholder(&self, value: &str) -> Option { + let guard = self.by_placeholder.read().expect("secret resolver lock"); + let secret = guard.get(value)?; match validate_resolved_secret(secret) { - Ok(s) => Some(s), + Ok(s) => Some(s.to_string()), Err(reason) => { tracing::warn!( location = "resolve_placeholder", @@ -108,7 +147,7 @@ impl SecretResolver { pub(crate) fn rewrite_header_value(&self, value: &str) -> Option { // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { - return Some(secret.to_string()); + return Some(secret); } let trimmed = value.trim(); @@ -147,20 +186,29 @@ impl SecretResolver { return None; } - // Rewrite all placeholder occurrences in the decoded string + // Rewrite all placeholder occurrences in the decoded string. + // Collect replacements under a single read-lock acquisition to + // avoid holding the lock across the string rewriting loop. + let replacements: Vec<(String, String)> = { + let guard = self.by_placeholder.read().expect("secret resolver lock"); + guard + .iter() + .filter(|(ph, _)| decoded.contains(ph.as_str())) + .map(|(ph, sec)| (ph.clone(), sec.clone())) + .collect() + }; + let mut rewritten = decoded.to_string(); - for (placeholder, secret) in &self.by_placeholder { - if rewritten.contains(placeholder.as_str()) { - // Validate the resolved secret for control characters - if validate_resolved_secret(secret).is_err() { - tracing::warn!( - location = "basic_auth", - "credential resolution rejected: resolved value contains prohibited characters" - ); - return None; - } - rewritten = rewritten.replace(placeholder.as_str(), secret); + for (placeholder, secret) in &replacements { + // Validate the resolved secret for control characters + if validate_resolved_secret(secret).is_err() { + tracing::warn!( + location = "basic_auth", + "credential resolution rejected: resolved value contains prohibited characters" + ); + return None; } + rewritten = rewritten.replace(placeholder.as_str(), secret); } // Only return if we actually changed something @@ -479,7 +527,7 @@ fn rewrite_path_segment( let full_placeholder = &segment[abs_start..key_end]; if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { - validate_credential_for_path(secret).map_err(|reason| { + validate_credential_for_path(&secret).map_err(|reason| { tracing::warn!( location = "path", %reason, @@ -487,7 +535,7 @@ fn rewrite_path_segment( ); UnresolvedPlaceholderError { location: "path" } })?; - resolved.push_str(secret); + resolved.push_str(&secret); redacted.push_str("[CREDENTIAL]"); } else { return Err(UnresolvedPlaceholderError { location: "path" }); @@ -523,7 +571,7 @@ fn rewrite_uri_query_params( if let Some((key, value)) = param.split_once('=') { let decoded_value = percent_decode(value); if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { - resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + resolved_params.push(format!("{key}={}", percent_encode_query(&secret))); redacted_params.push(format!("{key}=[CREDENTIAL]")); any_rewritten = true; } else if decoded_value.contains(PLACEHOLDER_PREFIX) { @@ -717,7 +765,7 @@ mod tests { .as_ref() .and_then(|resolver| resolver .resolve_placeholder("openshell:resolve:env:ANTHROPIC_API_KEY")), - Some("sk-test") + Some("sk-test".to_string()) ); } @@ -890,7 +938,7 @@ mod tests { let resolver = resolver.expect("resolver"); assert_eq!( resolver.resolve_placeholder("openshell:resolve:env:KEY"), - Some("sk-abc123_DEF.456~xyz") + Some("sk-abc123_DEF.456~xyz".to_string()) ); } @@ -1473,4 +1521,124 @@ mod tests { assert_eq!(result.resolved, "/bottok123/method?key=key456"); assert_eq!(result.redacted, "/bot[CREDENTIAL]/method?key=[CREDENTIAL]"); } + + // === RwLock and replace_secrets tests === + + #[test] + fn replace_secrets_updates_resolved_values() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_KEY".to_string(), "old-secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:API_KEY"), + Some("old-secret".to_string()) + ); + + // Replace with new value. + resolver.replace_secrets( + [("API_KEY".to_string(), "new-secret".to_string())] + .into_iter() + .collect(), + ); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:API_KEY"), + Some("new-secret".to_string()) + ); + } + + #[test] + fn replace_secrets_removes_stale_keys() { + let (_, resolver) = SecretResolver::from_provider_env( + [ + ("KEY_A".to_string(), "val-a".to_string()), + ("KEY_B".to_string(), "val-b".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + // Replace with only KEY_A — KEY_B should disappear. + resolver.replace_secrets( + [("KEY_A".to_string(), "val-a-new".to_string())] + .into_iter() + .collect(), + ); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:KEY_A"), + Some("val-a-new".to_string()) + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:KEY_B"), + None + ); + } + + #[test] + fn replace_secrets_adds_new_keys() { + let (_, resolver) = SecretResolver::from_provider_env( + [("EXISTING".to_string(), "val".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + resolver.replace_secrets( + [ + ("EXISTING".to_string(), "val".to_string()), + ("NEW_KEY".to_string(), "new-val".to_string()), + ] + .into_iter() + .collect(), + ); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:NEW_KEY"), + Some("new-val".to_string()) + ); + } + + #[test] + fn concurrent_read_during_replace_does_not_panic() { + use std::sync::Arc; + + let (_, resolver) = SecretResolver::from_provider_env( + [("TOKEN".to_string(), "initial".to_string())] + .into_iter() + .collect(), + ); + let resolver = Arc::new(resolver.expect("resolver")); + + // Spawn readers and one writer concurrently. + let mut handles = Vec::new(); + for _ in 0..8 { + let r = resolver.clone(); + handles.push(std::thread::spawn(move || { + for _ in 0..100 { + let _ = r.resolve_placeholder("openshell:resolve:env:TOKEN"); + } + })); + } + + let w = resolver.clone(); + handles.push(std::thread::spawn(move || { + for i in 0..100 { + w.replace_secrets( + [("TOKEN".to_string(), format!("val-{i}"))] + .into_iter() + .collect(), + ); + } + })); + + for h in handles { + h.join().expect("thread should not panic"); + } + } } diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 0308f30ff..831a08017 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -63,6 +63,7 @@ serde_json = { workspace = true } tokio-stream = { workspace = true } sqlx = { workspace = true } reqwest = { workspace = true } +dashmap = { workspace = true } kube = { workspace = true } kube-runtime = { workspace = true } k8s-openapi = { workspace = true } diff --git a/crates/openshell-server/src/grpc.rs b/crates/openshell-server/src/grpc.rs index d7ef4ccf5..b7b209068 100644 --- a/crates/openshell-server/src/grpc.rs +++ b/crates/openshell-server/src/grpc.rs @@ -929,18 +929,24 @@ impl OpenShell for OpenShellService { .spec .ok_or_else(|| Status::internal("sandbox has no spec"))?; - let environment = - resolve_provider_environment(self.state.store.as_ref(), &spec.providers).await?; + let resolved = resolve_provider_environment( + self.state.store.as_ref(), + &self.state.token_vending, + &spec.providers, + ) + .await?; info!( sandbox_id = %sandbox_id, provider_count = spec.providers.len(), - env_count = environment.len(), + env_count = resolved.environment.len(), + refresh_after_secs = resolved.refresh_after_secs, "GetSandboxProviderEnvironment request completed successfully" ); Ok(Response::new(GetSandboxProviderEnvironmentResponse { - environment, + environment: resolved.environment, + refresh_after_secs: resolved.refresh_after_secs, })) } @@ -3634,21 +3640,39 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result Ok(result) } +/// Resolved provider environment with optional refresh interval. +#[derive(Debug)] +struct ResolvedProviderEnv { + environment: std::collections::HashMap, + /// Seconds until the supervisor should re-fetch. 0 = static only. + refresh_after_secs: u32, +} + /// Resolve provider credentials into environment variables. /// /// For each provider name in the list, fetches the provider from the store and /// collects credential key-value pairs. Returns a map of environment variables /// to inject into the sandbox. When duplicate keys appear across providers, the /// first provider's value wins. +/// +/// For OAuth2 providers, performs token exchange/refresh via the +/// `TokenVendingService` and returns the access token as a regular credential +/// entry. OAuth2 internal credentials (`OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, +/// `OAUTH_REFRESH_TOKEN`) are filtered and never injected into the sandbox. async fn resolve_provider_environment( store: &crate::persistence::Store, + token_vending: &crate::token_vending::TokenVendingService, provider_names: &[String], -) -> Result, Status> { +) -> Result { if provider_names.is_empty() { - return Ok(std::collections::HashMap::new()); + return Ok(ResolvedProviderEnv { + environment: std::collections::HashMap::new(), + refresh_after_secs: 0, + }); } let mut env = std::collections::HashMap::new(); + let mut min_refresh_secs: Option = None; for name in provider_names { let provider = store @@ -3657,7 +3681,60 @@ async fn resolve_provider_environment( .map_err(|e| Status::internal(format!("failed to fetch provider '{name}': {e}")))? .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; + let is_oauth2 = crate::token_vending::is_oauth2_provider(&provider); + + // Handle OAuth2 token resolution. + if is_oauth2 { + match token_vending.get_or_refresh(&provider).await { + Ok(result) => { + let token_key = crate::token_vending::oauth_access_token_key(&provider); + env.entry(token_key).or_insert(result.access_token); + + // If the IdP returned a rotated refresh token, persist it. + // We construct a Provider with empty `id` and `config` + // because `update_provider_record` resolves by `name` (not + // id) and `merge_map` treats empty maps as no-ops, so only + // the credentials map is merged into the existing record. + if let Some(new_refresh) = result.new_refresh_token { + let mut updated_creds = std::collections::HashMap::new(); + updated_creds.insert("OAUTH_REFRESH_TOKEN".to_string(), new_refresh); + let updated = Provider { + id: String::new(), + name: provider.name.clone(), + r#type: provider.r#type.clone(), + credentials: updated_creds, + config: std::collections::HashMap::new(), + }; + if let Err(e) = update_provider_record(store, updated).await { + warn!( + provider_name = %name, + error = %e.message(), + "failed to persist rotated refresh token" + ); + } + } + + // Track shortest TTL for poll interval. Floor at 1 so + // that very short-lived tokens (expires_in <= 1) don't + // produce 0, which the supervisor interprets as "static". + let refresh = (result.expires_in_secs / 2).max(1); + min_refresh_secs = Some(min_refresh_secs.map_or(refresh, |m| m.min(refresh))); + } + Err(e) => { + return Err(Status::internal(format!( + "OAuth2 token refresh failed for provider '{name}': {e}" + ))); + } + } + } + + // Inject static credentials; skip OAuth2 internal keys only for + // OAuth2 providers so that a static provider using a credential key + // like "OAUTH_CLIENT_ID" is not accidentally suppressed. for (key, value) in &provider.credentials { + if is_oauth2 && crate::token_vending::is_oauth2_internal_credential(key) { + continue; + } if is_valid_env_key(key) { env.entry(key.clone()).or_insert_with(|| value.clone()); } else { @@ -3670,7 +3747,10 @@ async fn resolve_provider_environment( } } - Ok(env) + Ok(ResolvedProviderEnv { + environment: env, + refresh_after_secs: min_refresh_secs.unwrap_or(0), + }) } fn is_valid_env_key(key: &str) -> bool { @@ -4157,6 +4237,9 @@ async fn create_provider_record( // Validate field sizes before any I/O. validate_provider_fields(&provider)?; + // Validate OAuth2-specific configuration if present. + crate::token_vending::validate_oauth2_config(&provider)?; + let existing = store .get_message_by_name::(&provider.name) .await @@ -4270,6 +4353,10 @@ async fn update_provider_record( validate_provider_fields(&updated)?; + // Validate OAuth2-specific configuration on the merged result, so a + // user cannot add auth_method=oauth2 without the required credentials. + crate::token_vending::validate_oauth2_config(&updated)?; + store .put_message(&updated) .await @@ -4872,8 +4959,11 @@ mod tests { #[tokio::test] async fn resolve_provider_env_empty_list_returns_empty() { let store = Store::connect("sqlite::memory:").await.unwrap(); - let result = resolve_provider_environment(&store, &[]).await.unwrap(); - assert!(result.is_empty()); + let token_vending = crate::token_vending::TokenVendingService::new(); + let result = resolve_provider_environment(&store, &token_vending, &[]) + .await + .unwrap(); + assert!(result.environment.is_empty()); } #[tokio::test] @@ -4897,21 +4987,31 @@ mod tests { }; create_provider_record(&store, provider).await.unwrap(); - let result = resolve_provider_environment(&store, &["claude-local".to_string()]) - .await - .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("CLAUDE_API_KEY"), Some(&"sk-abc".to_string())); + let token_vending = crate::token_vending::TokenVendingService::new(); + let result = + resolve_provider_environment(&store, &token_vending, &["claude-local".to_string()]) + .await + .unwrap(); + assert_eq!( + result.environment.get("ANTHROPIC_API_KEY"), + Some(&"sk-abc".to_string()) + ); + assert_eq!( + result.environment.get("CLAUDE_API_KEY"), + Some(&"sk-abc".to_string()) + ); // Config values should NOT be injected. - assert!(!result.contains_key("endpoint")); + assert!(!result.environment.contains_key("endpoint")); } #[tokio::test] async fn resolve_provider_env_unknown_name_returns_error() { let store = Store::connect("sqlite::memory:").await.unwrap(); - let err = resolve_provider_environment(&store, &["nonexistent".to_string()]) - .await - .unwrap_err(); + let token_vending = crate::token_vending::TokenVendingService::new(); + let err = + resolve_provider_environment(&store, &token_vending, &["nonexistent".to_string()]) + .await + .unwrap_err(); assert_eq!(err.code(), Code::FailedPrecondition); assert!(err.message().contains("nonexistent")); } @@ -4934,12 +5034,17 @@ mod tests { }; create_provider_record(&store, provider).await.unwrap(); - let result = resolve_provider_environment(&store, &["test-provider".to_string()]) - .await - .unwrap(); - assert_eq!(result.get("VALID_KEY"), Some(&"value".to_string())); - assert!(!result.contains_key("nested.api_key")); - assert!(!result.contains_key("bad-key")); + let token_vending = crate::token_vending::TokenVendingService::new(); + let result = + resolve_provider_environment(&store, &token_vending, &["test-provider".to_string()]) + .await + .unwrap(); + assert_eq!( + result.environment.get("VALID_KEY"), + Some(&"value".to_string()) + ); + assert!(!result.environment.contains_key("nested.api_key")); + assert!(!result.environment.contains_key("bad-key")); } #[tokio::test] @@ -4975,14 +5080,22 @@ mod tests { .await .unwrap(); + let token_vending = crate::token_vending::TokenVendingService::new(); let result = resolve_provider_environment( &store, + &token_vending, &["claude-local".to_string(), "gitlab-local".to_string()], ) .await .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("GITLAB_TOKEN"), Some(&"glpat-xyz".to_string())); + assert_eq!( + result.environment.get("ANTHROPIC_API_KEY"), + Some(&"sk-abc".to_string()) + ); + assert_eq!( + result.environment.get("GITLAB_TOKEN"), + Some(&"glpat-xyz".to_string()) + ); } #[tokio::test] @@ -5018,13 +5131,18 @@ mod tests { .await .unwrap(); + let token_vending = crate::token_vending::TokenVendingService::new(); let result = resolve_provider_environment( &store, + &token_vending, &["provider-a".to_string(), "provider-b".to_string()], ) .await .unwrap(); - assert_eq!(result.get("SHARED_KEY"), Some(&"first-value".to_string())); + assert_eq!( + result.environment.get("SHARED_KEY"), + Some(&"first-value".to_string()) + ); } /// Simulates the handler flow: persist a sandbox with providers, then resolve @@ -5075,11 +5193,15 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) + let token_vending = crate::token_vending::TokenVendingService::new(); + let env = resolve_provider_environment(&store, &token_vending, &spec.providers) .await .unwrap(); - assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-test".to_string())); + assert_eq!( + env.environment.get("ANTHROPIC_API_KEY"), + Some(&"sk-test".to_string()) + ); } /// Handler flow returns empty map when sandbox has no providers. @@ -5106,11 +5228,12 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) + let token_vending = crate::token_vending::TokenVendingService::new(); + let env = resolve_provider_environment(&store, &token_vending, &spec.providers) .await .unwrap(); - assert!(env.is_empty()); + assert!(env.environment.is_empty()); } /// Handler returns not-found when sandbox doesn't exist. diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index e827b3628..5547eff4d 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -20,6 +20,7 @@ mod sandbox_index; mod sandbox_watch; mod ssh_tunnel; mod tls; +pub mod token_vending; pub mod tracing_bus; mod ws_tunnel; @@ -38,6 +39,7 @@ use sandbox::{SandboxClient, spawn_sandbox_watcher, spawn_store_reconciler}; use sandbox_index::SandboxIndex; use sandbox_watch::{SandboxWatchBus, spawn_kube_event_tailer}; pub use tls::TlsAcceptor; +use token_vending::TokenVendingService; use tracing_bus::TracingLogBus; /// Server state shared across handlers. @@ -72,6 +74,9 @@ pub struct ServerState { /// set/delete operation, including the precedence check on sandbox /// mutations that reads global state. pub settings_mutex: tokio::sync::Mutex<()>, + + /// Token vending service for `OAuth2` providers. + pub token_vending: TokenVendingService, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -102,6 +107,7 @@ impl ServerState { ssh_connections_by_token: Mutex::new(HashMap::new()), ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), + token_vending: TokenVendingService::new(), } } } diff --git a/crates/openshell-server/src/token_vending.rs b/crates/openshell-server/src/token_vending.rs new file mode 100644 index 000000000..235e949bf --- /dev/null +++ b/crates/openshell-server/src/token_vending.rs @@ -0,0 +1,939 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! `OAuth2` token vending service. +//! +//! Performs `OAuth2` token exchanges, caches access tokens per provider, and +//! deduplicates concurrent refresh requests. Used by +//! `resolve_provider_environment()` to transparently return current access +//! tokens for `OAuth2` providers. + +use std::collections::HashMap; +use std::sync::Arc; + +use openshell_core::proto::Provider; +use tokio::sync::Mutex; +use tracing::{debug, error, info}; + +// --------------------------------------------------------------------------- +// Error types +// --------------------------------------------------------------------------- + +/// Errors returned by the token vending service. +#[derive(Debug, thiserror::Error)] +pub enum TokenError { + #[error("OAuth2 token exchange failed: {0}")] + Exchange(String), + + #[error("OAuth2 provider misconfigured: {0}")] + Config(String), + + #[error("HTTP request to token endpoint failed: {0}")] + Http(String), + + #[error("token endpoint returned error: {status} — {body}")] + EndpointError { status: u16, body: String }, +} + +// --------------------------------------------------------------------------- +// Token result +// --------------------------------------------------------------------------- + +/// Result of a successful token acquisition. +#[derive(Debug, Clone)] +pub struct TokenResult { + /// The current access token value. + pub access_token: String, + /// Seconds until this token expires. + pub expires_in_secs: u32, + /// If the `IdP` returned a rotated refresh token, it is captured here + /// so the caller can persist it. + pub new_refresh_token: Option, +} + +// --------------------------------------------------------------------------- +// Cached token +// --------------------------------------------------------------------------- + +/// Cached access token with expiry metadata. +#[derive(Debug, Clone)] +struct CachedToken { + access_token: String, + /// Absolute expiry time (monotonic). + expires_at: tokio::time::Instant, + /// Original TTL in seconds as reported by the `IdP`. + ttl_secs: u32, +} + +impl CachedToken { + /// Returns `true` if the token is still usable (with safety margin). + fn is_valid(&self) -> bool { + let margin = std::cmp::max(60, u64::from(self.ttl_secs / 10)); + let margin_dur = std::time::Duration::from_secs(margin); + tokio::time::Instant::now() + margin_dur < self.expires_at + } +} + +// --------------------------------------------------------------------------- +// Per-provider state +// --------------------------------------------------------------------------- + +/// Per-provider token state, protected by a Tokio mutex for async refresh. +struct ProviderTokenState { + cached: Option, +} + +// --------------------------------------------------------------------------- +// Token vending service +// --------------------------------------------------------------------------- + +/// Token vending service. One per gateway process. +/// +/// Caches access tokens per provider and deduplicates concurrent refresh +/// requests via a per-provider `Mutex`. +pub struct TokenVendingService { + /// Per-provider token state, keyed by provider ID. + state: dashmap::DashMap>>, + /// Shared HTTP client for token endpoint requests. + http_client: reqwest::Client, +} + +impl std::fmt::Debug for TokenVendingService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TokenVendingService") + .field("cached_providers", &self.state.len()) + .finish() + } +} + +impl TokenVendingService { + /// Create a new token vending service. + #[must_use] + pub fn new() -> Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("failed to build reqwest client"); + Self { + state: dashmap::DashMap::new(), + http_client, + } + } + + /// Get a current access token for an `OAuth2` provider. + /// + /// Returns the cached token if still valid (with safety margin). + /// Otherwise performs a token refresh, caches the result, and returns it. + /// Concurrent callers for the same provider are deduplicated — only one + /// HTTP request to the `IdP` is made; others await the result. + pub async fn get_or_refresh(&self, provider: &Provider) -> Result { + let provider_id = &provider.id; + + // Get or create per-provider state. + let state = self + .state + .entry(provider_id.clone()) + .or_insert_with(|| Arc::new(Mutex::new(ProviderTokenState { cached: None }))) + .clone(); + + // Acquire the per-provider mutex. This serializes concurrent refresh + // requests for the same provider — only the first caller performs the + // HTTP request; others await the mutex and get the cached result. + let mut guard = state.lock().await; + + // Check cache under the lock. + if let Some(ref cached) = guard.cached + && cached.is_valid() + { + debug!( + provider_id = %provider_id, + ttl_remaining_secs = cached.expires_at.duration_since(tokio::time::Instant::now()).as_secs(), + "Returning cached OAuth2 access token" + ); + return Ok(TokenResult { + access_token: cached.access_token.clone(), + expires_in_secs: cached.ttl_secs, + new_refresh_token: None, + }); + } + + // Cache miss or expired — perform token exchange. + let config = OAuth2Config::from_provider(provider)?; + let result = self.exchange_token(&config).await?; + + info!( + provider_id = %provider_id, + provider_name = %provider.name, + expires_in_secs = result.expires_in_secs, + refresh_token_rotated = result.new_refresh_token.is_some(), + "OAuth2 token exchange completed" + ); + + // Cache the new token. + guard.cached = Some(CachedToken { + access_token: result.access_token.clone(), + expires_at: tokio::time::Instant::now() + + std::time::Duration::from_secs(u64::from(result.expires_in_secs)), + ttl_secs: result.expires_in_secs, + }); + + Ok(result) + } + + /// Perform the actual HTTP token exchange. + async fn exchange_token(&self, config: &OAuth2Config) -> Result { + let mut params = HashMap::new(); + params.insert("client_id", config.client_id.as_str()); + params.insert("client_secret", config.client_secret.as_str()); + params.insert("grant_type", config.grant_type.as_str()); + + if let Some(ref refresh_token) = config.refresh_token { + params.insert("refresh_token", refresh_token.as_str()); + } + + if let Some(ref scopes) = config.scopes { + params.insert("scope", scopes.as_str()); + } + + debug!( + token_endpoint = %config.token_endpoint, + grant_type = %config.grant_type, + "Sending OAuth2 token exchange request" + ); + + // Explicitly request JSON. Some OAuth2 providers (notably GitHub) + // default to application/x-www-form-urlencoded responses without this. + let response = self + .http_client + .post(&config.token_endpoint) + .header("Accept", "application/json") + .form(¶ms) + .send() + .await + .map_err(|e| TokenError::Http(e.to_string()))?; + + let status = response.status().as_u16(); + let body = response + .text() + .await + .map_err(|e| TokenError::Http(format!("failed to read response body: {e}")))?; + + if status != 200 { + // Truncate body for logging safety. Find a char boundary at + // or before byte 256 to avoid panicking on multi-byte UTF-8. + let truncated = if body.len() > 256 { + let end = (0..=256) + .rev() + .find(|&i| body.is_char_boundary(i)) + .unwrap_or(0); + format!("{}...", &body[..end]) + } else { + body + }; + error!( + token_endpoint = %config.token_endpoint, + status, + response_body = %truncated, + "OAuth2 token endpoint returned error" + ); + return Err(TokenError::EndpointError { + status, + body: truncated, + }); + } + + let token_response: OAuth2TokenResponse = serde_json::from_str(&body) + .map_err(|e| TokenError::Exchange(format!("failed to parse token response: {e}")))?; + + let access_token = token_response.access_token.ok_or_else(|| { + TokenError::Exchange("token response missing 'access_token' field".into()) + })?; + + if access_token.is_empty() { + return Err(TokenError::Exchange( + "token response contains empty 'access_token'".into(), + )); + } + + // Fail-fast: reject access tokens containing header-injection + // characters (CR, LF, NUL). The downstream `SecretResolver` also + // validates during placeholder resolution, but catching it here + // provides a clearer error and avoids caching a poisoned token. + validate_token_value(&access_token)?; + + // Validate rotated refresh token if present. A poisoned refresh + // token would be persisted to the store and break future refreshes. + let new_refresh_token = match token_response.refresh_token { + Some(rt) => { + validate_token_value(&rt).map_err(|_| { + TokenError::Exchange( + "rotated refresh token contains prohibited control characters".into(), + ) + })?; + Some(rt) + } + None => None, + }; + + // Default to 3600s if the IdP doesn't specify expires_in. + let expires_in = token_response.expires_in.unwrap_or(3600); + + Ok(TokenResult { + access_token, + expires_in_secs: expires_in, + new_refresh_token, + }) + } +} + +impl Default for TokenVendingService { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// OAuth2 config extraction +// --------------------------------------------------------------------------- + +/// Parsed `OAuth2` configuration from a Provider record. +struct OAuth2Config { + token_endpoint: String, + grant_type: String, + client_id: String, + client_secret: String, + refresh_token: Option, + scopes: Option, +} + +impl OAuth2Config { + fn from_provider(provider: &Provider) -> Result { + let config = &provider.config; + let credentials = &provider.credentials; + + let token_endpoint = config + .get("oauth_token_endpoint") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing oauth_token_endpoint in provider config".into()) + })? + .clone(); + + let grant_type = config + .get("oauth_grant_type") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing oauth_grant_type in provider config".into()) + })? + .clone(); + + let client_id = credentials + .get("OAUTH_CLIENT_ID") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing OAUTH_CLIENT_ID in provider credentials".into()) + })? + .clone(); + + let client_secret = credentials + .get("OAUTH_CLIENT_SECRET") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing OAUTH_CLIENT_SECRET in provider credentials".into()) + })? + .clone(); + + let refresh_token = credentials.get("OAUTH_REFRESH_TOKEN").cloned(); + let scopes = config.get("oauth_scopes").cloned(); + + Ok(Self { + token_endpoint, + grant_type, + client_id, + client_secret, + refresh_token, + scopes, + }) + } +} + +// --------------------------------------------------------------------------- +// OAuth2 token response (RFC 6749 §5.1) +// --------------------------------------------------------------------------- + +#[derive(Debug, serde::Deserialize)] +struct OAuth2TokenResponse { + access_token: Option, + #[allow(dead_code)] + token_type: Option, + expires_in: Option, + refresh_token: Option, + #[allow(dead_code)] + scope: Option, +} + +// --------------------------------------------------------------------------- +// Helper functions for provider classification +// --------------------------------------------------------------------------- + +/// Reject access tokens containing header-injection characters. +/// +/// Mirrors the `validate_resolved_secret()` check in the sandbox's +/// `SecretResolver`, but applied at the token vending layer for fail-fast +/// behavior. Prevents caching a poisoned token that would be rejected +/// downstream on every request. +fn validate_token_value(value: &str) -> Result<(), TokenError> { + if value + .bytes() + .any(|b| b == b'\r' || b == b'\n' || b == b'\0') + { + return Err(TokenError::Exchange( + "access token contains prohibited control characters (CR, LF, or NUL)".into(), + )); + } + Ok(()) +} + +/// Determine if a provider uses `OAuth2` based on its config map. +pub fn is_oauth2_provider(provider: &Provider) -> bool { + provider + .config + .get("auth_method") + .is_some_and(|v| v == "oauth2") +} + +/// `OAuth2` credential keys that must never be injected into sandbox env. +pub fn is_oauth2_internal_credential(key: &str) -> bool { + matches!( + key, + "OAUTH_CLIENT_ID" | "OAUTH_CLIENT_SECRET" | "OAUTH_REFRESH_TOKEN" + ) +} + +/// Derive the access token env var key for an `OAuth2` provider. +pub fn oauth_access_token_key(provider: &Provider) -> String { + provider + .config + .get("oauth_access_token_env") + .filter(|v| !v.trim().is_empty()) + .cloned() + .unwrap_or_else(|| format!("{}_ACCESS_TOKEN", provider.r#type.to_ascii_uppercase())) +} + +/// Validate OAuth2-specific configuration at provider creation time. +/// +/// Returns `Ok(())` for non-OAuth2 providers (no-op). +#[allow(clippy::result_large_err)] +pub fn validate_oauth2_config(provider: &Provider) -> Result<(), tonic::Status> { + let config = &provider.config; + + if !is_oauth2_provider(provider) { + return Ok(()); + } + + // Required config keys. + for key in &["oauth_token_endpoint", "oauth_grant_type"] { + if config.get(*key).is_none_or(|v| v.trim().is_empty()) { + return Err(tonic::Status::invalid_argument(format!( + "OAuth2 provider requires config key '{key}'" + ))); + } + } + + let token_endpoint = config.get("oauth_token_endpoint").unwrap(); + if !token_endpoint.starts_with("https://") { + return Err(tonic::Status::invalid_argument( + "oauth_token_endpoint must use HTTPS", + )); + } + + let grant_type = config.get("oauth_grant_type").unwrap(); + match grant_type.as_str() { + "refresh_token" => { + for key in &[ + "OAUTH_CLIENT_ID", + "OAUTH_CLIENT_SECRET", + "OAUTH_REFRESH_TOKEN", + ] { + if !provider.credentials.contains_key(*key) { + return Err(tonic::Status::invalid_argument(format!( + "OAuth2 refresh_token grant requires credential '{key}'" + ))); + } + } + } + "client_credentials" => { + for key in &["OAUTH_CLIENT_ID", "OAUTH_CLIENT_SECRET"] { + if !provider.credentials.contains_key(*key) { + return Err(tonic::Status::invalid_argument(format!( + "OAuth2 client_credentials grant requires credential '{key}'" + ))); + } + } + } + other => { + return Err(tonic::Status::invalid_argument(format!( + "unsupported oauth_grant_type: '{other}' (expected 'refresh_token' or 'client_credentials')" + ))); + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_oauth2_provider(grant_type: &str, include_refresh: bool) -> Provider { + let mut credentials = HashMap::new(); + credentials.insert("OAUTH_CLIENT_ID".into(), "test-client-id".into()); + credentials.insert("OAUTH_CLIENT_SECRET".into(), "test-client-secret".into()); + if include_refresh { + credentials.insert("OAUTH_REFRESH_TOKEN".into(), "test-refresh-token".into()); + } + + let mut config = HashMap::new(); + config.insert("auth_method".into(), "oauth2".into()); + config.insert("oauth_grant_type".into(), grant_type.into()); + config.insert( + "oauth_token_endpoint".into(), + "https://auth.example.com/oauth/token".into(), + ); + + Provider { + id: "test-id".into(), + name: "test-oauth".into(), + r#type: "github".into(), + credentials, + config, + } + } + + fn make_static_provider() -> Provider { + let mut credentials = HashMap::new(); + credentials.insert("GITHUB_TOKEN".into(), "ghp-static-token".into()); + + Provider { + id: "static-id".into(), + name: "test-static".into(), + r#type: "github".into(), + credentials, + config: HashMap::new(), + } + } + + // --- is_oauth2_provider --- + + #[test] + fn detects_oauth2_provider() { + let provider = make_oauth2_provider("refresh_token", true); + assert!(is_oauth2_provider(&provider)); + } + + #[test] + fn static_provider_is_not_oauth2() { + let provider = make_static_provider(); + assert!(!is_oauth2_provider(&provider)); + } + + #[test] + fn empty_config_is_not_oauth2() { + let provider = Provider { + id: String::new(), + name: "empty".into(), + r#type: "generic".into(), + credentials: HashMap::new(), + config: HashMap::new(), + }; + assert!(!is_oauth2_provider(&provider)); + } + + // --- is_oauth2_internal_credential --- + + #[test] + fn identifies_internal_credentials() { + assert!(is_oauth2_internal_credential("OAUTH_CLIENT_ID")); + assert!(is_oauth2_internal_credential("OAUTH_CLIENT_SECRET")); + assert!(is_oauth2_internal_credential("OAUTH_REFRESH_TOKEN")); + } + + #[test] + fn non_oauth_keys_are_not_internal() { + assert!(!is_oauth2_internal_credential("GITHUB_TOKEN")); + assert!(!is_oauth2_internal_credential("ANTHROPIC_API_KEY")); + assert!(!is_oauth2_internal_credential("OAUTH_ACCESS_TOKEN")); + } + + // --- oauth_access_token_key --- + + #[test] + fn derives_access_token_key_from_type() { + let provider = make_oauth2_provider("refresh_token", true); + assert_eq!(oauth_access_token_key(&provider), "GITHUB_ACCESS_TOKEN"); + } + + #[test] + fn uses_custom_access_token_key_when_configured() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider + .config + .insert("oauth_access_token_env".into(), "MY_CUSTOM_TOKEN".into()); + assert_eq!(oauth_access_token_key(&provider), "MY_CUSTOM_TOKEN"); + } + + #[test] + fn ignores_empty_custom_access_token_key() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider + .config + .insert("oauth_access_token_env".into(), " ".into()); + assert_eq!(oauth_access_token_key(&provider), "GITHUB_ACCESS_TOKEN"); + } + + // --- validate_oauth2_config --- + + #[test] + fn validates_refresh_token_grant() { + let provider = make_oauth2_provider("refresh_token", true); + assert!(validate_oauth2_config(&provider).is_ok()); + } + + #[test] + fn validates_client_credentials_grant() { + let provider = make_oauth2_provider("client_credentials", false); + assert!(validate_oauth2_config(&provider).is_ok()); + } + + #[test] + fn skips_validation_for_static_providers() { + let provider = make_static_provider(); + assert!(validate_oauth2_config(&provider).is_ok()); + } + + #[test] + fn rejects_missing_token_endpoint() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider.config.remove("oauth_token_endpoint"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("oauth_token_endpoint")); + } + + #[test] + fn rejects_http_token_endpoint() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider.config.insert( + "oauth_token_endpoint".into(), + "http://insecure.example.com/token".into(), + ); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("HTTPS")); + } + + #[test] + fn rejects_missing_grant_type() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider.config.remove("oauth_grant_type"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("oauth_grant_type")); + } + + #[test] + fn rejects_unsupported_grant_type() { + let provider = make_oauth2_provider("implicit", false); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("implicit")); + } + + #[test] + fn rejects_refresh_token_grant_without_refresh_token() { + let provider = make_oauth2_provider("refresh_token", false); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("OAUTH_REFRESH_TOKEN")); + } + + #[test] + fn rejects_missing_client_id() { + let mut provider = make_oauth2_provider("client_credentials", false); + provider.credentials.remove("OAUTH_CLIENT_ID"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("OAUTH_CLIENT_ID")); + } + + #[test] + fn rejects_missing_client_secret() { + let mut provider = make_oauth2_provider("client_credentials", false); + provider.credentials.remove("OAUTH_CLIENT_SECRET"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("OAUTH_CLIENT_SECRET")); + } + + // --- OAuth2Config::from_provider --- + + #[test] + fn extracts_config_from_provider() { + let provider = make_oauth2_provider("refresh_token", true); + let config = OAuth2Config::from_provider(&provider).unwrap(); + assert_eq!( + config.token_endpoint, + "https://auth.example.com/oauth/token" + ); + assert_eq!(config.grant_type, "refresh_token"); + assert_eq!(config.client_id, "test-client-id"); + assert_eq!(config.client_secret, "test-client-secret"); + assert_eq!(config.refresh_token, Some("test-refresh-token".into())); + assert!(config.scopes.is_none()); + } + + #[test] + fn extracts_scopes_when_present() { + let mut provider = make_oauth2_provider("client_credentials", false); + provider + .config + .insert("oauth_scopes".into(), "api read_user".into()); + let config = OAuth2Config::from_provider(&provider).unwrap(); + assert_eq!(config.scopes, Some("api read_user".into())); + } + + // --- TokenVendingService (unit-level) --- + + #[tokio::test] + async fn token_vending_service_constructs() { + let _service = TokenVendingService::new(); + } + + // --- validate_token_value --- + + #[test] + fn validates_clean_token() { + assert!(validate_token_value("gho_abc123XYZ").is_ok()); + } + + #[test] + fn rejects_token_with_newline() { + let err = validate_token_value("token\ninjection").unwrap_err(); + assert!(matches!(err, TokenError::Exchange(_))); + } + + #[test] + fn rejects_token_with_carriage_return() { + let err = validate_token_value("token\rinjection").unwrap_err(); + assert!(matches!(err, TokenError::Exchange(_))); + } + + #[test] + fn rejects_token_with_null_byte() { + let err = validate_token_value("token\0injection").unwrap_err(); + assert!(matches!(err, TokenError::Exchange(_))); + } + + #[test] + fn cached_token_validity() { + let valid = CachedToken { + access_token: "tok".into(), + expires_at: tokio::time::Instant::now() + std::time::Duration::from_secs(3600), + ttl_secs: 3600, + }; + assert!(valid.is_valid()); + + let expired = CachedToken { + access_token: "tok".into(), + expires_at: tokio::time::Instant::now(), + ttl_secs: 3600, + }; + assert!(!expired.is_valid()); + } + + // --- Integration tests (mock HTTP server) --- + + #[tokio::test] + async fn exchange_token_with_mock_server() { + use tokio::net::TcpListener; + + // Start a minimal mock HTTPS-less HTTP server for testing. + // In production, token endpoints are HTTPS. For unit tests we + // bypass the HTTPS requirement by constructing OAuth2Config directly. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Spawn mock server that returns a valid token response. + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + let request = String::from_utf8_lossy(&buf[..n]); + assert!(request.contains("grant_type=client_credentials")); + assert!(request.contains("client_id=test-id")); + assert!(request.contains("client_secret=test-secret")); + + let body = + r#"{"access_token":"mock-access-token","token_type":"Bearer","expires_in":3600}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let config = OAuth2Config { + token_endpoint: format!("http://127.0.0.1:{}", addr.port()), + grant_type: "client_credentials".into(), + client_id: "test-id".into(), + client_secret: "test-secret".into(), + refresh_token: None, + scopes: None, + }; + + let result = service.exchange_token(&config).await.unwrap(); + assert_eq!(result.access_token, "mock-access-token"); + assert_eq!(result.expires_in_secs, 3600); + assert!(result.new_refresh_token.is_none()); + } + + #[tokio::test] + async fn exchange_token_with_refresh_rotation() { + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + + let body = r#"{"access_token":"new-access","token_type":"Bearer","expires_in":7200,"refresh_token":"new-refresh"}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let config = OAuth2Config { + token_endpoint: format!("http://127.0.0.1:{}", addr.port()), + grant_type: "refresh_token".into(), + client_id: "test-id".into(), + client_secret: "test-secret".into(), + refresh_token: Some("old-refresh".into()), + scopes: None, + }; + + let result = service.exchange_token(&config).await.unwrap(); + assert_eq!(result.access_token, "new-access"); + assert_eq!(result.expires_in_secs, 7200); + assert_eq!(result.new_refresh_token, Some("new-refresh".into())); + } + + #[tokio::test] + async fn exchange_token_error_response() { + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + + let body = r#"{"error":"invalid_grant","error_description":"refresh token revoked"}"#; + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let config = OAuth2Config { + token_endpoint: format!("http://127.0.0.1:{}", addr.port()), + grant_type: "refresh_token".into(), + client_id: "test-id".into(), + client_secret: "test-secret".into(), + refresh_token: Some("revoked-token".into()), + scopes: None, + }; + + let err = service.exchange_token(&config).await.unwrap_err(); + assert!(matches!(err, TokenError::EndpointError { status: 400, .. })); + } + + #[tokio::test] + async fn caching_returns_same_token() { + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Server only handles one connection — second call must hit cache. + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + + let body = r#"{"access_token":"cached-token","token_type":"Bearer","expires_in":3600}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let provider = Provider { + id: "cache-test".into(), + name: "cache-test".into(), + r#type: "github".into(), + credentials: [ + ("OAUTH_CLIENT_ID".into(), "id".into()), + ("OAUTH_CLIENT_SECRET".into(), "secret".into()), + ] + .into_iter() + .collect(), + config: [ + ("auth_method".into(), "oauth2".into()), + ("oauth_grant_type".into(), "client_credentials".into()), + ( + "oauth_token_endpoint".into(), + format!("http://127.0.0.1:{}", addr.port()), + ), + ] + .into_iter() + .collect(), + }; + + let result1 = service.get_or_refresh(&provider).await.unwrap(); + assert_eq!(result1.access_token, "cached-token"); + + // Second call should return cached token (server has closed). + let result2 = service.get_or_refresh(&provider).await.unwrap(); + assert_eq!(result2.access_token, "cached-token"); + } +} diff --git a/proto/openshell.proto b/proto/openshell.proto index 04f705020..af86efda5 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -433,6 +433,11 @@ message GetSandboxProviderEnvironmentRequest { message GetSandboxProviderEnvironmentResponse { // Provider credential environment variables. map environment = 1; + // Seconds until the supervisor should re-fetch credentials. + // 0 means all credentials are static — do not poll. + // The server computes this as min(ttl * 0.5) across all + // time-bounded credentials attached to the sandbox. + uint32 refresh_after_secs = 2; } // ---------------------------------------------------------------------------