diff --git a/Cargo.lock b/Cargo.lock index fd3b68d1..c76faaa1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1030,6 +1030,20 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -3003,6 +3017,7 @@ dependencies = [ "axum 0.8.8", "bytes", "clap", + "dashmap", "futures", "futures-util", "hex", diff --git a/Cargo.toml b/Cargo.toml index 3380e040..6bf359b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] } # Clipboard (OSC 52) base64 = "0.22" +# Concurrent data structures +dashmap = "6" + # Utilities futures = "0.3" bytes = "1" diff --git a/architecture/sandbox-providers.md b/architecture/sandbox-providers.md index fe5d48a9..fed97da2 100644 --- a/architecture/sandbox-providers.md +++ b/architecture/sandbox-providers.md @@ -374,6 +374,145 @@ The gateway enforces: Providers are stored with `object_type = "provider"` in the shared object store. +## OAuth2 Credential Lifecycle + +### Overview + +Providers can use OAuth2 for credential lifecycle management instead of static tokens. +The gateway server performs all OAuth2 operations (token exchange, refresh, rotation +persistence). The sandbox supervisor polls for fresh access tokens on a server-dictated +interval, atomically updating the `SecretResolver`. Sandboxes with only static credentials +incur zero overhead — no poll loop is spawned. + +The core invariant is preserved: real credentials (access tokens, refresh tokens, client +secrets) never enter the sandbox runtime. Child processes see only stable placeholder +strings. 
+ +### Configuration + +OAuth2 is an auth method, not a provider type. Any provider type (`github`, `gitlab`, +`generic`, etc.) can use OAuth2 by setting config keys: + +| Config Key | Required | Example | Purpose | +|---|---|---|---| +| `auth_method` | Yes | `oauth2` | Discriminator (absence means static) | +| `oauth_token_endpoint` | Yes | `https://github.com/login/oauth/access_token` | Token exchange URL (HTTPS only) | +| `oauth_grant_type` | Yes | `refresh_token` or `client_credentials` | OAuth2 flow type | +| `oauth_scopes` | No | `api read_user` | Space-separated scopes | +| `oauth_access_token_env` | No | `MY_TOKEN` | Override output env var name | + +OAuth2 secret material is stored in `Provider.credentials`: + +| Credential Key | Required For | Purpose | +|---|---|---| +| `OAUTH_CLIENT_ID` | All OAuth2 | Client identifier | +| `OAUTH_CLIENT_SECRET` | All OAuth2 | Client secret | +| `OAUTH_REFRESH_TOKEN` | `refresh_token` grant | Refresh token (may be rotated) | + +### CLI Usage + +```bash +# Refresh token flow: +openshell provider create \ + --name github-oauth --type github \ + --credential OAUTH_CLIENT_ID=Iv1.abc123 \ + --credential OAUTH_CLIENT_SECRET=secret456 \ + --credential OAUTH_REFRESH_TOKEN=ghr_xyz789 \ + --config auth_method=oauth2 \ + --config oauth_grant_type=refresh_token \ + --config oauth_token_endpoint=https://github.com/login/oauth/access_token + +# Client credentials flow: +openshell provider create \ + --name service-account --type generic \ + --credential OAUTH_CLIENT_ID=client-id \ + --credential OAUTH_CLIENT_SECRET=client-secret \ + --config auth_method=oauth2 \ + --config oauth_grant_type=client_credentials \ + --config oauth_token_endpoint=https://auth.example.com/oauth2/token +``` + +### Gateway-Side Token Vending + +The `TokenVendingService` (`crates/openshell-server/src/token_vending.rs`) handles all +OAuth2 token exchange, caching, and refresh: + +- **Per-provider caching**: access tokens are cached in memory with their TTL. 
+- **Lazy refresh**: tokens are refreshed when a sandbox calls + `GetSandboxProviderEnvironment` and the cached token is within its safety margin + (`max(60s, ttl * 0.1)` before expiry). +- **Concurrent deduplication**: a per-provider `tokio::sync::Mutex` ensures only one + HTTP request to the IdP runs at a time; concurrent callers await the result. +- **Refresh token rotation**: if the IdP returns a new refresh token, the gateway + persists it to the store via `update_provider_record()`; a persistence failure is logged and does not fail the request. + +The `resolve_provider_environment()` function detects OAuth2 providers via +`config["auth_method"] == "oauth2"`, calls the token vending service, and returns +the access token as a credential entry (e.g., `GITHUB_ACCESS_TOKEN`). OAuth2 internal +credentials (`OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, `OAUTH_REFRESH_TOKEN`) are +filtered out and never injected into the sandbox environment. + +### Response-Driven Polling + +`GetSandboxProviderEnvironmentResponse` includes a `refresh_after_secs` field: + +- **0**: all credentials are static — supervisor does not spawn a poll loop. +- **>0**: the minimum of `max(token_ttl / 2, 1)` across all OAuth2 providers. The supervisor + spawns a background credential poll loop. + +### Supervisor Credential Poll Loop + +When `refresh_after_secs > 0`, the sandbox supervisor spawns +`run_credential_poll_loop()` (modeled on `run_policy_poll_loop()`): + +1. Sleeps for the server-dictated interval, floored at 30 seconds. +2. Calls `GetSandboxProviderEnvironment` via `CachedOpenShellClient`. +3. Atomically replaces all `SecretResolver` mappings via `replace_secrets()`. +4. On failure, tightens the retry interval to 30 seconds. +5. On recovery, restores the server-dictated interval. +6. If the server returns `refresh_after_secs == 0`, exits cleanly. + +The `SecretResolver` guards its mapping with `std::sync::RwLock`: concurrent reads (credential resolution during request forwarding) share the lock, +and `replace_secrets()` holds the write lock only for the brief swap of the prebuilt map, so readers never observe a partially updated mapping. 
+ +### Credential Isolation + +| Secret | Gateway | Supervisor | Child Env | Proxy Wire | +|---|---|---|---|---| +| `OAUTH_CLIENT_ID` | ✅ Store + cache | ❌ | ❌ | ❌ | +| `OAUTH_CLIENT_SECRET` | ✅ Store + cache | ❌ | ❌ | ❌ | +| `OAUTH_REFRESH_TOKEN` | ✅ Store + cache | ❌ | ❌ | ❌ | +| Access token (ephemeral) | ✅ Cache | ✅ SecretResolver | ❌ Placeholder | ✅ Egress | + +### End-to-End Flow (OAuth2) + +``` +CLI: openshell provider create --type github --config auth_method=oauth2 ... + | + +-- Gateway validates OAuth2 config (HTTPS endpoint, required credentials) + +-- Persists Provider with credentials + config + | +CLI: openshell sandbox create --provider github-oauth -- claude + | + +-- Gateway: create_sandbox() validates provider exists + | + Sandbox supervisor: run_sandbox() + +-- GetSandboxProviderEnvironment + | +-- Gateway: resolve_provider_environment() + | | +-- Detects auth_method=oauth2 + | | +-- TokenVendingService::get_or_refresh() + | | | +-- POST to oauth_token_endpoint (lazy refresh) + | | | +-- Caches access_token with TTL + | | +-- Returns {GITHUB_ACCESS_TOKEN: "gho-abc123...", refresh_after_secs: 1800} + | | +-- Filters out OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OAUTH_REFRESH_TOKEN + +-- SecretResolver::from_provider_env() + | +-- child env: {GITHUB_ACCESS_TOKEN: "openshell:resolve:env:GITHUB_ACCESS_TOKEN"} + | +-- resolver: {"openshell:resolve:env:GITHUB_ACCESS_TOKEN": "gho-abc123..."} + +-- Spawns credential poll loop (refresh_after_secs=1800 → poll every 30min) + | +-- Every 30min: GetSandboxProviderEnvironment → replace_secrets() + +-- Proxy rewrites outbound headers with current access token +``` + ## Security Notes - Provider credentials are stored in `credentials` map and treated as sensitive. @@ -385,6 +524,11 @@ Providers are stored with `object_type = "provider"` in the shared object store. placeholders, and the supervisor resolves those placeholders during outbound proxying. 
- `OPENSHELL_SSH_HANDSHAKE_SECRET` is required by the supervisor/SSH server path but is explicitly kept out of spawned sandbox child-process environments. +- OAuth2 long-lived secrets (client ID, client secret, refresh token) never leave the + gateway process. Only short-lived access tokens are sent to the supervisor. +- OAuth2 token endpoints must use HTTPS (enforced at provider creation). +- Token endpoint responses are validated: access token values pass through + `validate_resolved_secret()` to reject header-injection characters. ## Test Strategy @@ -396,3 +540,7 @@ Providers are stored with `object_type = "provider"` in the shared object store. - sandbox unit tests validate placeholder generation and header rewriting. - E2E sandbox tests verify placeholders are visible in child env, outbound proxy traffic is rewritten with the real secret, and the SSH handshake secret is absent from exec env. +- OAuth2 token vending unit tests in `crates/openshell-server/src/token_vending.rs`: + mock HTTP server tests for token exchange, caching, rotation, and error handling. +- `SecretResolver` concurrency tests validate `replace_secrets()` under concurrent + read/write access. diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 351a9346..e5509690 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -3450,11 +3450,27 @@ pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result< let credential_keys = provider.credentials.keys().cloned().collect::>(); let config_keys = provider.config.keys().cloned().collect::>(); + // Derive auth method display. 
+ let auth_display = if provider + .config + .get("auth_method") + .is_some_and(|v| v == "oauth2") + { + let grant = provider + .config + .get("oauth_grant_type") + .map_or("unknown", |v| v.as_str()); + format!("OAuth2 ({grant})") + } else { + "Static".to_string() + }; + println!("{}", "Provider:".cyan().bold()); println!(); println!(" {} {}", "Id:".dimmed(), provider.id); println!(" {} {}", "Name:".dimmed(), provider.name); println!(" {} {}", "Type:".dimmed(), provider.r#type); + println!(" {} {}", "Auth:".dimmed(), auth_display); println!( " {} {}", "Credential keys:".dimmed(), diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 5503637e..646c9803 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -181,15 +181,22 @@ pub async fn sync_policy(endpoint: &str, sandbox: &str, policy: &ProtoSandboxPol sync_policy_with_client(&mut client, sandbox, policy).await } +/// Result of fetching provider environment. +pub struct ProviderEnvironmentResult { + /// Credential environment variables (key → secret value). + pub environment: HashMap, + /// Seconds until next refresh. 0 = static credentials only. + pub refresh_after_secs: u32, +} + /// Fetch provider environment variables for a sandbox from OpenShell server via gRPC. /// -/// Returns a map of environment variable names to values derived from provider -/// credentials configured on the sandbox. Returns an empty map if the sandbox -/// has no providers or the call fails. +/// Returns credential env vars and a refresh interval. A `refresh_after_secs` +/// of 0 means all credentials are static and no polling is needed. 
pub async fn fetch_provider_environment( endpoint: &str, sandbox_id: &str, -) -> Result> { +) -> Result { debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Fetching provider environment"); let mut client = connect(endpoint).await?; @@ -201,7 +208,11 @@ pub async fn fetch_provider_environment( .await .into_diagnostic()?; - Ok(response.into_inner().environment) + let inner = response.into_inner(); + Ok(ProviderEnvironmentResult { + environment: inner.environment, + refresh_after_secs: inner.refresh_after_secs, + }) } /// A reusable gRPC client for the OpenShell service. @@ -221,7 +232,7 @@ pub struct SettingsPollResult { pub config_revision: u64, pub policy_source: PolicySource, /// Effective settings keyed by name. - pub settings: std::collections::HashMap, + pub settings: HashMap, /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, } @@ -264,6 +275,27 @@ impl CachedOpenShellClient { }) } + /// Fetch provider environment for credential refresh polling. + pub async fn fetch_provider_environment( + &self, + sandbox_id: &str, + ) -> Result { + let response = self + .client + .clone() + .get_sandbox_provider_environment(GetSandboxProviderEnvironmentRequest { + sandbox_id: sandbox_id.to_string(), + }) + .await + .into_diagnostic()?; + + let inner = response.into_inner(); + Ok(ProviderEnvironmentResult { + environment: inner.environment, + refresh_after_secs: inner.refresh_after_secs, + }) + } + /// Submit denial summaries for policy analysis. pub async fn submit_policy_analysis( &self, diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index b160cdef..c58c402b 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -187,22 +187,27 @@ pub async fn run_sandbox( // Fetch provider environment variables from the server. 
// This is done after loading the policy so the sandbox can still start // even if provider env fetch fails (graceful degradation). - let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { - match grpc_client::fetch_provider_environment(endpoint, id).await { - Ok(env) => { - info!(env_count = env.len(), "Fetched provider environment"); - env - } - Err(e) => { - warn!(error = %e, "Failed to fetch provider environment, continuing without"); - std::collections::HashMap::new() + let (provider_env_map, refresh_after_secs) = + if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { + match grpc_client::fetch_provider_environment(endpoint, id).await { + Ok(result) => { + info!( + env_count = result.environment.len(), + refresh_after_secs = result.refresh_after_secs, + "Fetched provider environment" + ); + (result.environment, result.refresh_after_secs) + } + Err(e) => { + warn!(error = %e, "Failed to fetch provider environment, continuing without"); + (std::collections::HashMap::new(), 0) + } } - } - } else { - std::collections::HashMap::new() - }; + } else { + (std::collections::HashMap::new(), 0) + }; - let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env); + let (provider_env, secret_resolver) = SecretResolver::from_provider_env(provider_env_map); let secret_resolver = secret_resolver.map(Arc::new); // Create identity cache for SHA256 TOFU when OPA is active @@ -589,6 +594,30 @@ pub async fn run_sandbox( } }); + // Spawn credential poll loop (only if credentials are time-bounded). 
+ if refresh_after_secs > 0 { + if let Some(ref resolver) = secret_resolver { + let cred_id = id.clone(); + let cred_endpoint = endpoint.clone(); + let cred_resolver = resolver.clone(); + + tokio::spawn(async move { + if let Err(e) = run_credential_poll_loop( + &cred_endpoint, + &cred_id, + &cred_resolver, + refresh_after_secs, + ) + .await + { + warn!(error = %e, "Credential poll loop exited with error"); + } + }); + } + } else { + debug!("All provider credentials are static, skipping credential poll loop"); + } + // Spawn denial aggregator (gRPC mode only, when proxy is active). if let Some(rx) = denial_rx { // SubmitPolicyAnalysis resolves by sandbox *name*, not UUID. @@ -1637,6 +1666,76 @@ async fn run_policy_poll_loop( } /// Log individual setting changes between two snapshots. +/// Minimum poll interval for credential refresh (floor). +const CREDENTIAL_MIN_POLL_INTERVAL_SECS: u64 = 30; +/// Retry interval when credential refresh fails. +const CREDENTIAL_FAILURE_RETRY_SECS: u64 = 30; + +/// Background loop that periodically re-fetches provider credentials from the +/// gateway. Modeled on [`run_policy_poll_loop`] but focused on credential +/// lifecycle. +/// +/// The loop exits cleanly when the gateway returns `refresh_after_secs == 0`, +/// indicating all credentials have become static (e.g. provider reconfigured). 
+async fn run_credential_poll_loop( + endpoint: &str, + sandbox_id: &str, + secret_resolver: &Arc, + initial_refresh_secs: u32, +) -> Result<()> { + use crate::grpc_client::CachedOpenShellClient; + + let client = CachedOpenShellClient::connect(endpoint).await?; + let mut current_interval = + Duration::from_secs((initial_refresh_secs as u64).max(CREDENTIAL_MIN_POLL_INTERVAL_SECS)); + let mut consecutive_failures: u32 = 0; + + loop { + tokio::time::sleep(current_interval).await; + + match client.fetch_provider_environment(sandbox_id).await { + Ok(result) => { + if consecutive_failures > 0 { + info!( + consecutive_failures, + "Credential refresh recovered after failures" + ); + } + consecutive_failures = 0; + + // Atomically replace the secret values. + secret_resolver.replace_secrets(result.environment); + + // Use server-dictated interval for next poll. + let server_interval = + u64::from(result.refresh_after_secs).max(CREDENTIAL_MIN_POLL_INTERVAL_SECS); + current_interval = Duration::from_secs(server_interval); + + debug!( + refresh_after_secs = result.refresh_after_secs, + "Credential refresh completed" + ); + + // If server says 0, credentials became static. Exit loop. 
+ if result.refresh_after_secs == 0 { + info!("All credentials now static, stopping credential poll loop"); + return Ok(()); + } + } + Err(e) => { + consecutive_failures += 1; + current_interval = Duration::from_secs(CREDENTIAL_FAILURE_RETRY_SECS); + warn!( + error = %e, + consecutive_failures, + retry_in_secs = CREDENTIAL_FAILURE_RETRY_SECS, + "Credential refresh failed, retrying at reduced interval" + ); + } + } + } +} + fn log_setting_changes( old: &std::collections::HashMap, new: &std::collections::HashMap, diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index a27537c9..b87153a4 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -4,6 +4,7 @@ use base64::Engine as _; use std::collections::HashMap; use std::fmt; +use std::sync::RwLock; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; @@ -61,9 +62,23 @@ pub(crate) struct RewriteTargetResult { // SecretResolver // --------------------------------------------------------------------------- -#[derive(Debug, Clone, Default)] +/// Resolves `openshell:resolve:env:*` placeholder strings to real secret values. +/// +/// This type intentionally does **not** implement `Clone` because the inner +/// `RwLock` requires atomic access for credential refresh. All call +/// sites wrap the resolver in `Arc` for shared ownership. 
+#[derive(Default)] pub struct SecretResolver { - by_placeholder: HashMap, + by_placeholder: RwLock>, +} + +impl fmt::Debug for SecretResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let count = self.by_placeholder.read().map(|m| m.len()).unwrap_or(0); + f.debug_struct("SecretResolver") + .field("placeholder_count", &count) + .finish() + } } impl SecretResolver { @@ -83,17 +98,41 @@ impl SecretResolver { by_placeholder.insert(placeholder, value); } - (child_env, Some(Self { by_placeholder })) + ( + child_env, + Some(Self { + by_placeholder: RwLock::new(by_placeholder), + }), + ) + } + + /// Atomically replace all secret mappings with new values. + /// + /// Used by the credential poll loop to swap in fresh OAuth2 access tokens. + /// Takes env-key-keyed input (e.g. `{"GITHUB_TOKEN": "new-val"}`) and + /// converts to placeholder-keyed internally. + pub(crate) fn replace_secrets(&self, new_provider_env: HashMap) { + let mut new_map = HashMap::with_capacity(new_provider_env.len()); + for (key, value) in new_provider_env { + let placeholder = placeholder_for_env_key(&key); + new_map.insert(placeholder, value); + } + let mut guard = self.by_placeholder.write().expect("secret resolver lock"); + *guard = new_map; } /// Resolve a placeholder string to the real secret value. /// /// Returns `None` if the placeholder is unknown or the resolved value /// contains prohibited control characters (CRLF, null byte). - pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - let secret = self.by_placeholder.get(value).map(String::as_str)?; + /// + /// Returns an owned `String` because the value is cloned under a read + /// lock — callers must not hold a borrow across the lock boundary. 
+ pub(crate) fn resolve_placeholder(&self, value: &str) -> Option { + let guard = self.by_placeholder.read().expect("secret resolver lock"); + let secret = guard.get(value)?; match validate_resolved_secret(secret) { - Ok(s) => Some(s), + Ok(s) => Some(s.to_string()), Err(reason) => { tracing::warn!( location = "resolve_placeholder", @@ -108,7 +147,7 @@ impl SecretResolver { pub(crate) fn rewrite_header_value(&self, value: &str) -> Option { // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { - return Some(secret.to_string()); + return Some(secret); } let trimmed = value.trim(); @@ -147,20 +186,29 @@ impl SecretResolver { return None; } - // Rewrite all placeholder occurrences in the decoded string + // Rewrite all placeholder occurrences in the decoded string. + // Collect replacements under a single read-lock acquisition to + // avoid holding the lock across the string rewriting loop. + let replacements: Vec<(String, String)> = { + let guard = self.by_placeholder.read().expect("secret resolver lock"); + guard + .iter() + .filter(|(ph, _)| decoded.contains(ph.as_str())) + .map(|(ph, sec)| (ph.clone(), sec.clone())) + .collect() + }; + let mut rewritten = decoded.to_string(); - for (placeholder, secret) in &self.by_placeholder { - if rewritten.contains(placeholder.as_str()) { - // Validate the resolved secret for control characters - if validate_resolved_secret(secret).is_err() { - tracing::warn!( - location = "basic_auth", - "credential resolution rejected: resolved value contains prohibited characters" - ); - return None; - } - rewritten = rewritten.replace(placeholder.as_str(), secret); + for (placeholder, secret) in &replacements { + // Validate the resolved secret for control characters + if validate_resolved_secret(secret).is_err() { + tracing::warn!( + location = "basic_auth", + "credential resolution rejected: resolved value contains prohibited characters" + ); + return None; } + 
rewritten = rewritten.replace(placeholder.as_str(), secret); } // Only return if we actually changed something @@ -479,7 +527,7 @@ fn rewrite_path_segment( let full_placeholder = &segment[abs_start..key_end]; if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { - validate_credential_for_path(secret).map_err(|reason| { + validate_credential_for_path(&secret).map_err(|reason| { tracing::warn!( location = "path", %reason, @@ -487,7 +535,7 @@ fn rewrite_path_segment( ); UnresolvedPlaceholderError { location: "path" } })?; - resolved.push_str(secret); + resolved.push_str(&secret); redacted.push_str("[CREDENTIAL]"); } else { return Err(UnresolvedPlaceholderError { location: "path" }); @@ -523,7 +571,7 @@ fn rewrite_uri_query_params( if let Some((key, value)) = param.split_once('=') { let decoded_value = percent_decode(value); if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { - resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + resolved_params.push(format!("{key}={}", percent_encode_query(&secret))); redacted_params.push(format!("{key}=[CREDENTIAL]")); any_rewritten = true; } else if decoded_value.contains(PLACEHOLDER_PREFIX) { @@ -717,7 +765,7 @@ mod tests { .as_ref() .and_then(|resolver| resolver .resolve_placeholder("openshell:resolve:env:ANTHROPIC_API_KEY")), - Some("sk-test") + Some("sk-test".to_string()) ); } @@ -890,7 +938,7 @@ mod tests { let resolver = resolver.expect("resolver"); assert_eq!( resolver.resolve_placeholder("openshell:resolve:env:KEY"), - Some("sk-abc123_DEF.456~xyz") + Some("sk-abc123_DEF.456~xyz".to_string()) ); } @@ -1473,4 +1521,124 @@ mod tests { assert_eq!(result.resolved, "/bottok123/method?key=key456"); assert_eq!(result.redacted, "/bot[CREDENTIAL]/method?key=[CREDENTIAL]"); } + + // === RwLock and replace_secrets tests === + + #[test] + fn replace_secrets_updates_resolved_values() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_KEY".to_string(), 
"old-secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:API_KEY"), + Some("old-secret".to_string()) + ); + + // Replace with new value. + resolver.replace_secrets( + [("API_KEY".to_string(), "new-secret".to_string())] + .into_iter() + .collect(), + ); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:API_KEY"), + Some("new-secret".to_string()) + ); + } + + #[test] + fn replace_secrets_removes_stale_keys() { + let (_, resolver) = SecretResolver::from_provider_env( + [ + ("KEY_A".to_string(), "val-a".to_string()), + ("KEY_B".to_string(), "val-b".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + // Replace with only KEY_A — KEY_B should disappear. + resolver.replace_secrets( + [("KEY_A".to_string(), "val-a-new".to_string())] + .into_iter() + .collect(), + ); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:KEY_A"), + Some("val-a-new".to_string()) + ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:KEY_B"), + None + ); + } + + #[test] + fn replace_secrets_adds_new_keys() { + let (_, resolver) = SecretResolver::from_provider_env( + [("EXISTING".to_string(), "val".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + resolver.replace_secrets( + [ + ("EXISTING".to_string(), "val".to_string()), + ("NEW_KEY".to_string(), "new-val".to_string()), + ] + .into_iter() + .collect(), + ); + + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:NEW_KEY"), + Some("new-val".to_string()) + ); + } + + #[test] + fn concurrent_read_during_replace_does_not_panic() { + use std::sync::Arc; + + let (_, resolver) = SecretResolver::from_provider_env( + [("TOKEN".to_string(), "initial".to_string())] + .into_iter() + .collect(), + ); + let resolver = Arc::new(resolver.expect("resolver")); + + // Spawn 
readers and one writer concurrently. + let mut handles = Vec::new(); + for _ in 0..8 { + let r = resolver.clone(); + handles.push(std::thread::spawn(move || { + for _ in 0..100 { + let _ = r.resolve_placeholder("openshell:resolve:env:TOKEN"); + } + })); + } + + let w = resolver.clone(); + handles.push(std::thread::spawn(move || { + for i in 0..100 { + w.replace_secrets( + [("TOKEN".to_string(), format!("val-{i}"))] + .into_iter() + .collect(), + ); + } + })); + + for h in handles { + h.join().expect("thread should not panic"); + } + } } diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 0308f30f..831a0801 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -63,6 +63,7 @@ serde_json = { workspace = true } tokio-stream = { workspace = true } sqlx = { workspace = true } reqwest = { workspace = true } +dashmap = { workspace = true } kube = { workspace = true } kube-runtime = { workspace = true } k8s-openapi = { workspace = true } diff --git a/crates/openshell-server/src/grpc.rs b/crates/openshell-server/src/grpc.rs index d7ef4ccf..b7b20906 100644 --- a/crates/openshell-server/src/grpc.rs +++ b/crates/openshell-server/src/grpc.rs @@ -929,18 +929,24 @@ impl OpenShell for OpenShellService { .spec .ok_or_else(|| Status::internal("sandbox has no spec"))?; - let environment = - resolve_provider_environment(self.state.store.as_ref(), &spec.providers).await?; + let resolved = resolve_provider_environment( + self.state.store.as_ref(), + &self.state.token_vending, + &spec.providers, + ) + .await?; info!( sandbox_id = %sandbox_id, provider_count = spec.providers.len(), - env_count = environment.len(), + env_count = resolved.environment.len(), + refresh_after_secs = resolved.refresh_after_secs, "GetSandboxProviderEnvironment request completed successfully" ); Ok(Response::new(GetSandboxProviderEnvironmentResponse { - environment, + environment: resolved.environment, + refresh_after_secs: 
resolved.refresh_after_secs, })) } @@ -3634,21 +3640,39 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result Ok(result) } +/// Resolved provider environment with optional refresh interval. +#[derive(Debug)] +struct ResolvedProviderEnv { + environment: std::collections::HashMap, + /// Seconds until the supervisor should re-fetch. 0 = static only. + refresh_after_secs: u32, +} + /// Resolve provider credentials into environment variables. /// /// For each provider name in the list, fetches the provider from the store and /// collects credential key-value pairs. Returns a map of environment variables /// to inject into the sandbox. When duplicate keys appear across providers, the /// first provider's value wins. +/// +/// For OAuth2 providers, performs token exchange/refresh via the +/// `TokenVendingService` and returns the access token as a regular credential +/// entry. OAuth2 internal credentials (`OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, +/// `OAUTH_REFRESH_TOKEN`) are filtered and never injected into the sandbox. async fn resolve_provider_environment( store: &crate::persistence::Store, + token_vending: &crate::token_vending::TokenVendingService, provider_names: &[String], -) -> Result, Status> { +) -> Result { if provider_names.is_empty() { - return Ok(std::collections::HashMap::new()); + return Ok(ResolvedProviderEnv { + environment: std::collections::HashMap::new(), + refresh_after_secs: 0, + }); } let mut env = std::collections::HashMap::new(); + let mut min_refresh_secs: Option = None; for name in provider_names { let provider = store @@ -3657,7 +3681,60 @@ async fn resolve_provider_environment( .map_err(|e| Status::internal(format!("failed to fetch provider '{name}': {e}")))? .ok_or_else(|| Status::failed_precondition(format!("provider '{name}' not found")))?; + let is_oauth2 = crate::token_vending::is_oauth2_provider(&provider); + + // Handle OAuth2 token resolution. 
+ if is_oauth2 { + match token_vending.get_or_refresh(&provider).await { + Ok(result) => { + let token_key = crate::token_vending::oauth_access_token_key(&provider); + env.entry(token_key).or_insert(result.access_token); + + // If the IdP returned a rotated refresh token, persist it. + // We construct a Provider with empty `id` and `config` + // because `update_provider_record` resolves by `name` (not + // id) and `merge_map` treats empty maps as no-ops, so only + // the credentials map is merged into the existing record. + if let Some(new_refresh) = result.new_refresh_token { + let mut updated_creds = std::collections::HashMap::new(); + updated_creds.insert("OAUTH_REFRESH_TOKEN".to_string(), new_refresh); + let updated = Provider { + id: String::new(), + name: provider.name.clone(), + r#type: provider.r#type.clone(), + credentials: updated_creds, + config: std::collections::HashMap::new(), + }; + if let Err(e) = update_provider_record(store, updated).await { + warn!( + provider_name = %name, + error = %e.message(), + "failed to persist rotated refresh token" + ); + } + } + + // Track shortest TTL for poll interval. Floor at 1 so + // that very short-lived tokens (expires_in <= 1) don't + // produce 0, which the supervisor interprets as "static". + let refresh = (result.expires_in_secs / 2).max(1); + min_refresh_secs = Some(min_refresh_secs.map_or(refresh, |m| m.min(refresh))); + } + Err(e) => { + return Err(Status::internal(format!( + "OAuth2 token refresh failed for provider '{name}': {e}" + ))); + } + } + } + + // Inject static credentials; skip OAuth2 internal keys only for + // OAuth2 providers so that a static provider using a credential key + // like "OAUTH_CLIENT_ID" is not accidentally suppressed. 
for (key, value) in &provider.credentials { + if is_oauth2 && crate::token_vending::is_oauth2_internal_credential(key) { + continue; + } if is_valid_env_key(key) { env.entry(key.clone()).or_insert_with(|| value.clone()); } else { @@ -3670,7 +3747,10 @@ async fn resolve_provider_environment( } } - Ok(env) + Ok(ResolvedProviderEnv { + environment: env, + refresh_after_secs: min_refresh_secs.unwrap_or(0), + }) } fn is_valid_env_key(key: &str) -> bool { @@ -4157,6 +4237,9 @@ async fn create_provider_record( // Validate field sizes before any I/O. validate_provider_fields(&provider)?; + // Validate OAuth2-specific configuration if present. + crate::token_vending::validate_oauth2_config(&provider)?; + let existing = store .get_message_by_name::(&provider.name) .await @@ -4270,6 +4353,10 @@ async fn update_provider_record( validate_provider_fields(&updated)?; + // Validate OAuth2-specific configuration on the merged result, so a + // user cannot add auth_method=oauth2 without the required credentials. 
+ crate::token_vending::validate_oauth2_config(&updated)?; + store .put_message(&updated) .await @@ -4872,8 +4959,11 @@ mod tests { #[tokio::test] async fn resolve_provider_env_empty_list_returns_empty() { let store = Store::connect("sqlite::memory:").await.unwrap(); - let result = resolve_provider_environment(&store, &[]).await.unwrap(); - assert!(result.is_empty()); + let token_vending = crate::token_vending::TokenVendingService::new(); + let result = resolve_provider_environment(&store, &token_vending, &[]) + .await + .unwrap(); + assert!(result.environment.is_empty()); } #[tokio::test] @@ -4897,21 +4987,31 @@ mod tests { }; create_provider_record(&store, provider).await.unwrap(); - let result = resolve_provider_environment(&store, &["claude-local".to_string()]) - .await - .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("CLAUDE_API_KEY"), Some(&"sk-abc".to_string())); + let token_vending = crate::token_vending::TokenVendingService::new(); + let result = + resolve_provider_environment(&store, &token_vending, &["claude-local".to_string()]) + .await + .unwrap(); + assert_eq!( + result.environment.get("ANTHROPIC_API_KEY"), + Some(&"sk-abc".to_string()) + ); + assert_eq!( + result.environment.get("CLAUDE_API_KEY"), + Some(&"sk-abc".to_string()) + ); // Config values should NOT be injected. 
- assert!(!result.contains_key("endpoint")); + assert!(!result.environment.contains_key("endpoint")); } #[tokio::test] async fn resolve_provider_env_unknown_name_returns_error() { let store = Store::connect("sqlite::memory:").await.unwrap(); - let err = resolve_provider_environment(&store, &["nonexistent".to_string()]) - .await - .unwrap_err(); + let token_vending = crate::token_vending::TokenVendingService::new(); + let err = + resolve_provider_environment(&store, &token_vending, &["nonexistent".to_string()]) + .await + .unwrap_err(); assert_eq!(err.code(), Code::FailedPrecondition); assert!(err.message().contains("nonexistent")); } @@ -4934,12 +5034,17 @@ mod tests { }; create_provider_record(&store, provider).await.unwrap(); - let result = resolve_provider_environment(&store, &["test-provider".to_string()]) - .await - .unwrap(); - assert_eq!(result.get("VALID_KEY"), Some(&"value".to_string())); - assert!(!result.contains_key("nested.api_key")); - assert!(!result.contains_key("bad-key")); + let token_vending = crate::token_vending::TokenVendingService::new(); + let result = + resolve_provider_environment(&store, &token_vending, &["test-provider".to_string()]) + .await + .unwrap(); + assert_eq!( + result.environment.get("VALID_KEY"), + Some(&"value".to_string()) + ); + assert!(!result.environment.contains_key("nested.api_key")); + assert!(!result.environment.contains_key("bad-key")); } #[tokio::test] @@ -4975,14 +5080,22 @@ mod tests { .await .unwrap(); + let token_vending = crate::token_vending::TokenVendingService::new(); let result = resolve_provider_environment( &store, + &token_vending, &["claude-local".to_string(), "gitlab-local".to_string()], ) .await .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("GITLAB_TOKEN"), Some(&"glpat-xyz".to_string())); + assert_eq!( + result.environment.get("ANTHROPIC_API_KEY"), + Some(&"sk-abc".to_string()) + ); + assert_eq!( + 
result.environment.get("GITLAB_TOKEN"), + Some(&"glpat-xyz".to_string()) + ); } #[tokio::test] @@ -5018,13 +5131,18 @@ mod tests { .await .unwrap(); + let token_vending = crate::token_vending::TokenVendingService::new(); let result = resolve_provider_environment( &store, + &token_vending, &["provider-a".to_string(), "provider-b".to_string()], ) .await .unwrap(); - assert_eq!(result.get("SHARED_KEY"), Some(&"first-value".to_string())); + assert_eq!( + result.environment.get("SHARED_KEY"), + Some(&"first-value".to_string()) + ); } /// Simulates the handler flow: persist a sandbox with providers, then resolve @@ -5075,11 +5193,15 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) + let token_vending = crate::token_vending::TokenVendingService::new(); + let env = resolve_provider_environment(&store, &token_vending, &spec.providers) .await .unwrap(); - assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-test".to_string())); + assert_eq!( + env.environment.get("ANTHROPIC_API_KEY"), + Some(&"sk-test".to_string()) + ); } /// Handler flow returns empty map when sandbox has no providers. @@ -5106,11 +5228,12 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) + let token_vending = crate::token_vending::TokenVendingService::new(); + let env = resolve_provider_environment(&store, &token_vending, &spec.providers) .await .unwrap(); - assert!(env.is_empty()); + assert!(env.environment.is_empty()); } /// Handler returns not-found when sandbox doesn't exist. 
diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index e827b362..5547eff4 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -20,6 +20,7 @@ mod sandbox_index; mod sandbox_watch; mod ssh_tunnel; mod tls; +pub mod token_vending; pub mod tracing_bus; mod ws_tunnel; @@ -38,6 +39,7 @@ use sandbox::{SandboxClient, spawn_sandbox_watcher, spawn_store_reconciler}; use sandbox_index::SandboxIndex; use sandbox_watch::{SandboxWatchBus, spawn_kube_event_tailer}; pub use tls::TlsAcceptor; +use token_vending::TokenVendingService; use tracing_bus::TracingLogBus; /// Server state shared across handlers. @@ -72,6 +74,9 @@ pub struct ServerState { /// set/delete operation, including the precedence check on sandbox /// mutations that reads global state. pub settings_mutex: tokio::sync::Mutex<()>, + + /// Token vending service for `OAuth2` providers. + pub token_vending: TokenVendingService, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -102,6 +107,7 @@ impl ServerState { ssh_connections_by_token: Mutex::new(HashMap::new()), ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), + token_vending: TokenVendingService::new(), } } } diff --git a/crates/openshell-server/src/token_vending.rs b/crates/openshell-server/src/token_vending.rs new file mode 100644 index 00000000..235e949b --- /dev/null +++ b/crates/openshell-server/src/token_vending.rs @@ -0,0 +1,939 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! `OAuth2` token vending service. +//! +//! Performs `OAuth2` token exchanges, caches access tokens per provider, and +//! deduplicates concurrent refresh requests. Used by +//! `resolve_provider_environment()` to transparently return current access +//! tokens for `OAuth2` providers. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use openshell_core::proto::Provider; +use tokio::sync::Mutex; +use tracing::{debug, error, info}; + +// --------------------------------------------------------------------------- +// Error types +// --------------------------------------------------------------------------- + +/// Errors returned by the token vending service. +#[derive(Debug, thiserror::Error)] +pub enum TokenError { + #[error("OAuth2 token exchange failed: {0}")] + Exchange(String), + + #[error("OAuth2 provider misconfigured: {0}")] + Config(String), + + #[error("HTTP request to token endpoint failed: {0}")] + Http(String), + + #[error("token endpoint returned error: {status} — {body}")] + EndpointError { status: u16, body: String }, +} + +// --------------------------------------------------------------------------- +// Token result +// --------------------------------------------------------------------------- + +/// Result of a successful token acquisition. +#[derive(Debug, Clone)] +pub struct TokenResult { + /// The current access token value. + pub access_token: String, + /// Seconds until this token expires. + pub expires_in_secs: u32, + /// If the `IdP` returned a rotated refresh token, it is captured here + /// so the caller can persist it. + pub new_refresh_token: Option<String>, +} + +// --------------------------------------------------------------------------- +// Cached token +// --------------------------------------------------------------------------- + +/// Cached access token with expiry metadata. +#[derive(Debug, Clone)] +struct CachedToken { + access_token: String, + /// Absolute expiry time (monotonic). + expires_at: tokio::time::Instant, + /// Original TTL in seconds as reported by the `IdP`. + ttl_secs: u32, +} + +impl CachedToken { + /// Returns `true` if the token is still usable (with safety margin).
+ fn is_valid(&self) -> bool { + let margin = std::cmp::max(60, u64::from(self.ttl_secs / 10)); + let margin_dur = std::time::Duration::from_secs(margin); + tokio::time::Instant::now() + margin_dur < self.expires_at + } +} + +// --------------------------------------------------------------------------- +// Per-provider state +// --------------------------------------------------------------------------- + +/// Per-provider token state, protected by a Tokio mutex for async refresh. +struct ProviderTokenState { + cached: Option<CachedToken>, +} + +// --------------------------------------------------------------------------- +// Token vending service +// --------------------------------------------------------------------------- + +/// Token vending service. One per gateway process. +/// +/// Caches access tokens per provider and deduplicates concurrent refresh +/// requests via a per-provider `Mutex`. +pub struct TokenVendingService { + /// Per-provider token state, keyed by provider ID. + state: dashmap::DashMap<String, Arc<Mutex<ProviderTokenState>>>, + /// Shared HTTP client for token endpoint requests. + http_client: reqwest::Client, +} + +impl std::fmt::Debug for TokenVendingService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TokenVendingService") + .field("cached_providers", &self.state.len()) + .finish() + } +} + +impl TokenVendingService { + /// Create a new token vending service. + #[must_use] + pub fn new() -> Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("failed to build reqwest client"); + Self { + state: dashmap::DashMap::new(), + http_client, + } + } + + /// Get a current access token for an `OAuth2` provider. + /// + /// Returns the cached token if still valid (with safety margin). + /// Otherwise performs a token refresh, caches the result, and returns it.
+ /// Concurrent callers for the same provider are deduplicated — only one + /// HTTP request to the `IdP` is made; others await the result. + pub async fn get_or_refresh(&self, provider: &Provider) -> Result<TokenResult, TokenError> { + let provider_id = &provider.id; + + // Get or create per-provider state. + let state = self + .state + .entry(provider_id.clone()) + .or_insert_with(|| Arc::new(Mutex::new(ProviderTokenState { cached: None }))) + .clone(); + + // Acquire the per-provider mutex. This serializes concurrent refresh + // requests for the same provider — only the first caller performs the + // HTTP request; others await the mutex and get the cached result. + let mut guard = state.lock().await; + + // Check cache under the lock. + if let Some(ref cached) = guard.cached + && cached.is_valid() + { + debug!( + provider_id = %provider_id, + ttl_remaining_secs = cached.expires_at.duration_since(tokio::time::Instant::now()).as_secs(), + "Returning cached OAuth2 access token" + ); + return Ok(TokenResult { + access_token: cached.access_token.clone(), + expires_in_secs: cached.ttl_secs, + new_refresh_token: None, + }); + } + + // Cache miss or expired — perform token exchange. + let config = OAuth2Config::from_provider(provider)?; + let result = self.exchange_token(&config).await?; + + info!( + provider_id = %provider_id, + provider_name = %provider.name, + expires_in_secs = result.expires_in_secs, + refresh_token_rotated = result.new_refresh_token.is_some(), + "OAuth2 token exchange completed" + ); + + // Cache the new token. + guard.cached = Some(CachedToken { + access_token: result.access_token.clone(), + expires_at: tokio::time::Instant::now() + + std::time::Duration::from_secs(u64::from(result.expires_in_secs)), + ttl_secs: result.expires_in_secs, + }); + + Ok(result) + } + + /// Perform the actual HTTP token exchange.
+ async fn exchange_token(&self, config: &OAuth2Config) -> Result<TokenResult, TokenError> { + let mut params = HashMap::new(); + params.insert("client_id", config.client_id.as_str()); + params.insert("client_secret", config.client_secret.as_str()); + params.insert("grant_type", config.grant_type.as_str()); + + if let Some(ref refresh_token) = config.refresh_token { + params.insert("refresh_token", refresh_token.as_str()); + } + + if let Some(ref scopes) = config.scopes { + params.insert("scope", scopes.as_str()); + } + + debug!( + token_endpoint = %config.token_endpoint, + grant_type = %config.grant_type, + "Sending OAuth2 token exchange request" + ); + + // Explicitly request JSON. Some OAuth2 providers (notably GitHub) + // default to application/x-www-form-urlencoded responses without this. + let response = self + .http_client + .post(&config.token_endpoint) + .header("Accept", "application/json") + .form(&params) + .send() + .await + .map_err(|e| TokenError::Http(e.to_string()))?; + + let status = response.status().as_u16(); + let body = response + .text() + .await + .map_err(|e| TokenError::Http(format!("failed to read response body: {e}")))?; + + if status != 200 { + // Truncate body for logging safety. Find a char boundary at + // or before byte 256 to avoid panicking on multi-byte UTF-8.
+ let truncated = if body.len() > 256 { + let end = (0..=256) + .rev() + .find(|&i| body.is_char_boundary(i)) + .unwrap_or(0); + format!("{}...", &body[..end]) + } else { + body + }; + error!( + token_endpoint = %config.token_endpoint, + status, + response_body = %truncated, + "OAuth2 token endpoint returned error" + ); + return Err(TokenError::EndpointError { + status, + body: truncated, + }); + } + + let token_response: OAuth2TokenResponse = serde_json::from_str(&body) + .map_err(|e| TokenError::Exchange(format!("failed to parse token response: {e}")))?; + + let access_token = token_response.access_token.ok_or_else(|| { + TokenError::Exchange("token response missing 'access_token' field".into()) + })?; + + if access_token.is_empty() { + return Err(TokenError::Exchange( + "token response contains empty 'access_token'".into(), + )); + } + + // Fail-fast: reject access tokens containing header-injection + // characters (CR, LF, NUL). The downstream `SecretResolver` also + // validates during placeholder resolution, but catching it here + // provides a clearer error and avoids caching a poisoned token. + validate_token_value(&access_token)?; + + // Validate rotated refresh token if present. A poisoned refresh + // token would be persisted to the store and break future refreshes. + let new_refresh_token = match token_response.refresh_token { + Some(rt) => { + validate_token_value(&rt).map_err(|_| { + TokenError::Exchange( + "rotated refresh token contains prohibited control characters".into(), + ) + })?; + Some(rt) + } + None => None, + }; + + // Default to 3600s if the IdP doesn't specify expires_in. 
+ let expires_in = token_response.expires_in.unwrap_or(3600); + + Ok(TokenResult { + access_token, + expires_in_secs: expires_in, + new_refresh_token, + }) + } +} + +impl Default for TokenVendingService { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// OAuth2 config extraction +// --------------------------------------------------------------------------- + +/// Parsed `OAuth2` configuration from a Provider record. +struct OAuth2Config { + token_endpoint: String, + grant_type: String, + client_id: String, + client_secret: String, + refresh_token: Option<String>, + scopes: Option<String>, +} + +impl OAuth2Config { + fn from_provider(provider: &Provider) -> Result<Self, TokenError> { + let config = &provider.config; + let credentials = &provider.credentials; + + let token_endpoint = config + .get("oauth_token_endpoint") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing oauth_token_endpoint in provider config".into()) + })? + .clone(); + + let grant_type = config + .get("oauth_grant_type") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing oauth_grant_type in provider config".into()) + })? + .clone(); + + let client_id = credentials + .get("OAUTH_CLIENT_ID") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing OAUTH_CLIENT_ID in provider credentials".into()) + })? + .clone(); + + let client_secret = credentials + .get("OAUTH_CLIENT_SECRET") + .filter(|v| !v.trim().is_empty()) + .ok_or_else(|| { + TokenError::Config("missing OAUTH_CLIENT_SECRET in provider credentials".into()) + })?
+ .clone(); + + let refresh_token = credentials.get("OAUTH_REFRESH_TOKEN").cloned(); + let scopes = config.get("oauth_scopes").cloned(); + + Ok(Self { + token_endpoint, + grant_type, + client_id, + client_secret, + refresh_token, + scopes, + }) + } +} + +// --------------------------------------------------------------------------- +// OAuth2 token response (RFC 6749 §5.1) +// --------------------------------------------------------------------------- + +#[derive(Debug, serde::Deserialize)] +struct OAuth2TokenResponse { + access_token: Option<String>, + #[allow(dead_code)] + token_type: Option<String>, + expires_in: Option<u32>, + refresh_token: Option<String>, + #[allow(dead_code)] + scope: Option<String>, +} + +// --------------------------------------------------------------------------- +// Helper functions for provider classification +// --------------------------------------------------------------------------- + +/// Reject access tokens containing header-injection characters. +/// +/// Mirrors the `validate_resolved_secret()` check in the sandbox's +/// `SecretResolver`, but applied at the token vending layer for fail-fast +/// behavior. Prevents caching a poisoned token that would be rejected +/// downstream on every request. +fn validate_token_value(value: &str) -> Result<(), TokenError> { + if value + .bytes() + .any(|b| b == b'\r' || b == b'\n' || b == b'\0') + { + return Err(TokenError::Exchange( + "access token contains prohibited control characters (CR, LF, or NUL)".into(), + )); + } + Ok(()) +} + +/// Determine if a provider uses `OAuth2` based on its config map. +pub fn is_oauth2_provider(provider: &Provider) -> bool { + provider + .config + .get("auth_method") + .is_some_and(|v| v == "oauth2") +} + +/// `OAuth2` credential keys that must never be injected into sandbox env.
+pub fn is_oauth2_internal_credential(key: &str) -> bool { + matches!( + key, + "OAUTH_CLIENT_ID" | "OAUTH_CLIENT_SECRET" | "OAUTH_REFRESH_TOKEN" + ) +} + +/// Derive the access token env var key for an `OAuth2` provider. +pub fn oauth_access_token_key(provider: &Provider) -> String { + provider + .config + .get("oauth_access_token_env") + .filter(|v| !v.trim().is_empty()) + .cloned() + .unwrap_or_else(|| format!("{}_ACCESS_TOKEN", provider.r#type.to_ascii_uppercase())) +} + +/// Validate OAuth2-specific configuration at provider creation time. +/// +/// Returns `Ok(())` for non-OAuth2 providers (no-op). +#[allow(clippy::result_large_err)] +pub fn validate_oauth2_config(provider: &Provider) -> Result<(), tonic::Status> { + let config = &provider.config; + + if !is_oauth2_provider(provider) { + return Ok(()); + } + + // Required config keys. + for key in &["oauth_token_endpoint", "oauth_grant_type"] { + if config.get(*key).is_none_or(|v| v.trim().is_empty()) { + return Err(tonic::Status::invalid_argument(format!( + "OAuth2 provider requires config key '{key}'" + ))); + } + } + + let token_endpoint = config.get("oauth_token_endpoint").unwrap(); + if !token_endpoint.starts_with("https://") { + return Err(tonic::Status::invalid_argument( + "oauth_token_endpoint must use HTTPS", + )); + } + + let grant_type = config.get("oauth_grant_type").unwrap(); + match grant_type.as_str() { + "refresh_token" => { + for key in &[ + "OAUTH_CLIENT_ID", + "OAUTH_CLIENT_SECRET", + "OAUTH_REFRESH_TOKEN", + ] { + if !provider.credentials.contains_key(*key) { + return Err(tonic::Status::invalid_argument(format!( + "OAuth2 refresh_token grant requires credential '{key}'" + ))); + } + } + } + "client_credentials" => { + for key in &["OAUTH_CLIENT_ID", "OAUTH_CLIENT_SECRET"] { + if !provider.credentials.contains_key(*key) { + return Err(tonic::Status::invalid_argument(format!( + "OAuth2 client_credentials grant requires credential '{key}'" + ))); + } + } + } + other => { + return 
Err(tonic::Status::invalid_argument(format!( + "unsupported oauth_grant_type: '{other}' (expected 'refresh_token' or 'client_credentials')" + ))); + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_oauth2_provider(grant_type: &str, include_refresh: bool) -> Provider { + let mut credentials = HashMap::new(); + credentials.insert("OAUTH_CLIENT_ID".into(), "test-client-id".into()); + credentials.insert("OAUTH_CLIENT_SECRET".into(), "test-client-secret".into()); + if include_refresh { + credentials.insert("OAUTH_REFRESH_TOKEN".into(), "test-refresh-token".into()); + } + + let mut config = HashMap::new(); + config.insert("auth_method".into(), "oauth2".into()); + config.insert("oauth_grant_type".into(), grant_type.into()); + config.insert( + "oauth_token_endpoint".into(), + "https://auth.example.com/oauth/token".into(), + ); + + Provider { + id: "test-id".into(), + name: "test-oauth".into(), + r#type: "github".into(), + credentials, + config, + } + } + + fn make_static_provider() -> Provider { + let mut credentials = HashMap::new(); + credentials.insert("GITHUB_TOKEN".into(), "ghp-static-token".into()); + + Provider { + id: "static-id".into(), + name: "test-static".into(), + r#type: "github".into(), + credentials, + config: HashMap::new(), + } + } + + // --- is_oauth2_provider --- + + #[test] + fn detects_oauth2_provider() { + let provider = make_oauth2_provider("refresh_token", true); + assert!(is_oauth2_provider(&provider)); + } + + #[test] + fn static_provider_is_not_oauth2() { + let provider = make_static_provider(); + assert!(!is_oauth2_provider(&provider)); + } + + #[test] + fn empty_config_is_not_oauth2() { + let provider = Provider { + id: String::new(), + name: "empty".into(), + r#type: "generic".into(), + credentials: HashMap::new(), + config: HashMap::new(), + }; 
+ assert!(!is_oauth2_provider(&provider)); + } + + // --- is_oauth2_internal_credential --- + + #[test] + fn identifies_internal_credentials() { + assert!(is_oauth2_internal_credential("OAUTH_CLIENT_ID")); + assert!(is_oauth2_internal_credential("OAUTH_CLIENT_SECRET")); + assert!(is_oauth2_internal_credential("OAUTH_REFRESH_TOKEN")); + } + + #[test] + fn non_oauth_keys_are_not_internal() { + assert!(!is_oauth2_internal_credential("GITHUB_TOKEN")); + assert!(!is_oauth2_internal_credential("ANTHROPIC_API_KEY")); + assert!(!is_oauth2_internal_credential("OAUTH_ACCESS_TOKEN")); + } + + // --- oauth_access_token_key --- + + #[test] + fn derives_access_token_key_from_type() { + let provider = make_oauth2_provider("refresh_token", true); + assert_eq!(oauth_access_token_key(&provider), "GITHUB_ACCESS_TOKEN"); + } + + #[test] + fn uses_custom_access_token_key_when_configured() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider + .config + .insert("oauth_access_token_env".into(), "MY_CUSTOM_TOKEN".into()); + assert_eq!(oauth_access_token_key(&provider), "MY_CUSTOM_TOKEN"); + } + + #[test] + fn ignores_empty_custom_access_token_key() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider + .config + .insert("oauth_access_token_env".into(), " ".into()); + assert_eq!(oauth_access_token_key(&provider), "GITHUB_ACCESS_TOKEN"); + } + + // --- validate_oauth2_config --- + + #[test] + fn validates_refresh_token_grant() { + let provider = make_oauth2_provider("refresh_token", true); + assert!(validate_oauth2_config(&provider).is_ok()); + } + + #[test] + fn validates_client_credentials_grant() { + let provider = make_oauth2_provider("client_credentials", false); + assert!(validate_oauth2_config(&provider).is_ok()); + } + + #[test] + fn skips_validation_for_static_providers() { + let provider = make_static_provider(); + assert!(validate_oauth2_config(&provider).is_ok()); + } + + #[test] + fn rejects_missing_token_endpoint() { + let 
mut provider = make_oauth2_provider("refresh_token", true); + provider.config.remove("oauth_token_endpoint"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("oauth_token_endpoint")); + } + + #[test] + fn rejects_http_token_endpoint() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider.config.insert( + "oauth_token_endpoint".into(), + "http://insecure.example.com/token".into(), + ); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("HTTPS")); + } + + #[test] + fn rejects_missing_grant_type() { + let mut provider = make_oauth2_provider("refresh_token", true); + provider.config.remove("oauth_grant_type"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("oauth_grant_type")); + } + + #[test] + fn rejects_unsupported_grant_type() { + let provider = make_oauth2_provider("implicit", false); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("implicit")); + } + + #[test] + fn rejects_refresh_token_grant_without_refresh_token() { + let provider = make_oauth2_provider("refresh_token", false); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("OAUTH_REFRESH_TOKEN")); + } + + #[test] + fn rejects_missing_client_id() { + let mut provider = make_oauth2_provider("client_credentials", false); + provider.credentials.remove("OAUTH_CLIENT_ID"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("OAUTH_CLIENT_ID")); + } + + #[test] + fn rejects_missing_client_secret() { + let mut provider = make_oauth2_provider("client_credentials", false); + provider.credentials.remove("OAUTH_CLIENT_SECRET"); + let err = validate_oauth2_config(&provider).unwrap_err(); + assert!(err.message().contains("OAUTH_CLIENT_SECRET")); + } + + // --- OAuth2Config::from_provider --- + + #[test] + fn extracts_config_from_provider() 
{ + let provider = make_oauth2_provider("refresh_token", true); + let config = OAuth2Config::from_provider(&provider).unwrap(); + assert_eq!( + config.token_endpoint, + "https://auth.example.com/oauth/token" + ); + assert_eq!(config.grant_type, "refresh_token"); + assert_eq!(config.client_id, "test-client-id"); + assert_eq!(config.client_secret, "test-client-secret"); + assert_eq!(config.refresh_token, Some("test-refresh-token".into())); + assert!(config.scopes.is_none()); + } + + #[test] + fn extracts_scopes_when_present() { + let mut provider = make_oauth2_provider("client_credentials", false); + provider + .config + .insert("oauth_scopes".into(), "api read_user".into()); + let config = OAuth2Config::from_provider(&provider).unwrap(); + assert_eq!(config.scopes, Some("api read_user".into())); + } + + // --- TokenVendingService (unit-level) --- + + #[tokio::test] + async fn token_vending_service_constructs() { + let _service = TokenVendingService::new(); + } + + // --- validate_token_value --- + + #[test] + fn validates_clean_token() { + assert!(validate_token_value("gho_abc123XYZ").is_ok()); + } + + #[test] + fn rejects_token_with_newline() { + let err = validate_token_value("token\ninjection").unwrap_err(); + assert!(matches!(err, TokenError::Exchange(_))); + } + + #[test] + fn rejects_token_with_carriage_return() { + let err = validate_token_value("token\rinjection").unwrap_err(); + assert!(matches!(err, TokenError::Exchange(_))); + } + + #[test] + fn rejects_token_with_null_byte() { + let err = validate_token_value("token\0injection").unwrap_err(); + assert!(matches!(err, TokenError::Exchange(_))); + } + + #[test] + fn cached_token_validity() { + let valid = CachedToken { + access_token: "tok".into(), + expires_at: tokio::time::Instant::now() + std::time::Duration::from_secs(3600), + ttl_secs: 3600, + }; + assert!(valid.is_valid()); + + let expired = CachedToken { + access_token: "tok".into(), + expires_at: tokio::time::Instant::now(), + ttl_secs: 3600, + }; + 
assert!(!expired.is_valid()); + } + + // --- Integration tests (mock HTTP server) --- + + #[tokio::test] + async fn exchange_token_with_mock_server() { + use tokio::net::TcpListener; + + // Start a minimal mock HTTPS-less HTTP server for testing. + // In production, token endpoints are HTTPS. For unit tests we + // bypass the HTTPS requirement by constructing OAuth2Config directly. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Spawn mock server that returns a valid token response. + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + let request = String::from_utf8_lossy(&buf[..n]); + assert!(request.contains("grant_type=client_credentials")); + assert!(request.contains("client_id=test-id")); + assert!(request.contains("client_secret=test-secret")); + + let body = + r#"{"access_token":"mock-access-token","token_type":"Bearer","expires_in":3600}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let config = OAuth2Config { + token_endpoint: format!("http://127.0.0.1:{}", addr.port()), + grant_type: "client_credentials".into(), + client_id: "test-id".into(), + client_secret: "test-secret".into(), + refresh_token: None, + scopes: None, + }; + + let result = service.exchange_token(&config).await.unwrap(); + assert_eq!(result.access_token, "mock-access-token"); + assert_eq!(result.expires_in_secs, 3600); + assert!(result.new_refresh_token.is_none()); + } + + #[tokio::test] + async fn exchange_token_with_refresh_rotation() { + use tokio::net::TcpListener; + + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + + let body = r#"{"access_token":"new-access","token_type":"Bearer","expires_in":7200,"refresh_token":"new-refresh"}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let config = OAuth2Config { + token_endpoint: format!("http://127.0.0.1:{}", addr.port()), + grant_type: "refresh_token".into(), + client_id: "test-id".into(), + client_secret: "test-secret".into(), + refresh_token: Some("old-refresh".into()), + scopes: None, + }; + + let result = service.exchange_token(&config).await.unwrap(); + assert_eq!(result.access_token, "new-access"); + assert_eq!(result.expires_in_secs, 7200); + assert_eq!(result.new_refresh_token, Some("new-refresh".into())); + } + + #[tokio::test] + async fn exchange_token_error_response() { + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + + let body = r#"{"error":"invalid_grant","error_description":"refresh token revoked"}"#; + let response = format!( + "HTTP/1.1 400 Bad Request\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = 
TokenVendingService::new(); + let config = OAuth2Config { + token_endpoint: format!("http://127.0.0.1:{}", addr.port()), + grant_type: "refresh_token".into(), + client_id: "test-id".into(), + client_secret: "test-secret".into(), + refresh_token: Some("revoked-token".into()), + scopes: None, + }; + + let err = service.exchange_token(&config).await.unwrap_err(); + assert!(matches!(err, TokenError::EndpointError { status: 400, .. })); + } + + #[tokio::test] + async fn caching_returns_same_token() { + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Server only handles one connection — second call must hit cache. + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut stream, &mut buf) + .await + .unwrap(); + + let body = r#"{"access_token":"cached-token","token_type":"Bearer","expires_in":3600}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + tokio::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .unwrap(); + }); + + let service = TokenVendingService::new(); + let provider = Provider { + id: "cache-test".into(), + name: "cache-test".into(), + r#type: "github".into(), + credentials: [ + ("OAUTH_CLIENT_ID".into(), "id".into()), + ("OAUTH_CLIENT_SECRET".into(), "secret".into()), + ] + .into_iter() + .collect(), + config: [ + ("auth_method".into(), "oauth2".into()), + ("oauth_grant_type".into(), "client_credentials".into()), + ( + "oauth_token_endpoint".into(), + format!("http://127.0.0.1:{}", addr.port()), + ), + ] + .into_iter() + .collect(), + }; + + let result1 = service.get_or_refresh(&provider).await.unwrap(); + assert_eq!(result1.access_token, "cached-token"); + + // Second call should return cached token (server has closed). 
+ let result2 = service.get_or_refresh(&provider).await.unwrap(); + assert_eq!(result2.access_token, "cached-token"); + } +} diff --git a/proto/openshell.proto b/proto/openshell.proto index 04f70502..af86efda 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -433,6 +433,11 @@ message GetSandboxProviderEnvironmentRequest { message GetSandboxProviderEnvironmentResponse { // Provider credential environment variables. map<string, string> environment = 1; + // Seconds until the supervisor should re-fetch credentials. + // 0 means all credentials are static — do not poll. + // The server computes this as min(ttl * 0.5) across all + // time-bounded credentials attached to the sandbox. + uint32 refresh_after_secs = 2; } // ---------------------------------------------------------------------------