diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 000000000..4dfdc42a8 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,8 @@ +{ + "permissions": { + "allow": [ + "Bash(cargo build:*)", + "Bash(cargo test:*)" + ] + } +} diff --git a/AGENTS.md b/AGENTS.md index ee1fd03cc..51d588895 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -100,6 +100,66 @@ These pipelines connect skills into end-to-end workflows. Individual skill files - Bollard (the Rust Docker client library) connects to Podman via its Docker-compatible API — no separate Podman client is needed. - When referencing host gateway aliases, use both `host.docker.internal` and `host.containers.internal` for cross-runtime compatibility. +### Debugging with Podman + +When using Podman (especially on macOS where Podman runs in a VM), debugging requires accessing the Podman machine: + +**Accessing the Podman VM:** +```bash +podman machine ssh +``` + +**Common debugging commands:** +```bash +# Check cluster logs via kubectl (inside podman machine or via ssh) +podman machine ssh -- "podman exec openshell-cluster-openshell kubectl logs -n openshell " + +# Check running containers +podman machine ssh -- "podman ps -a" + +# Check images and timestamps +podman machine ssh -- "podman images" + +# Verify binary in cluster +podman machine ssh -- "podman exec openshell-cluster-openshell ls -lh /opt/openshell/bin/openshell-sandbox" + +# Check for specific strings in binary +podman machine ssh -- "podman exec openshell-cluster-openshell strings /opt/openshell/bin/openshell-sandbox | grep " + +# Get sandbox pod logs +podman machine ssh -- "podman exec openshell-cluster-openshell kubectl logs -n openshell --container agent --tail 100" +``` + +**Important: Cross-compilation requirement** + +Running `cargo build --release` on macOS produces a macOS binary, not a Linux binary. The cluster runs Linux containers, so using a macOS binary causes "exec format error". + +- ✅ **Correct:** Use `mise run cluster:build:full` which handles cross-compilation +- ❌ **Incorrect:** `cargo build --release` then manually copying the binary + +**Fast iteration workflow:** + +After modifying Rust code in `crates/openshell-sandbox/`: + +```bash +# Force clean rebuild to avoid cargo cache issues +cargo clean -p openshell-sandbox + +# Full cluster rebuild (handles cross-compilation) +mise run cluster:build:full + +# Recreate sandbox to pick up new binary +openshell sandbox delete +openshell sandbox create --name --provider --policy -- bash +``` + +**Common issues:** + +- **"exec format error"**: Binary is for wrong architecture (macOS vs Linux) +- **Binary not updating**: Cargo is using cached artifacts - run `cargo clean -p openshell-sandbox` +- **Empty logs**: `RUST_LOG` environment variable not set in sandbox agent - logs are disabled by default +- **Changes not reflected**: Sandbox was created before cluster rebuild - always recreate sandboxes after deploying new binaries + ## Cluster Infrastructure Changes - If you change cluster bootstrap infrastructure (e.g., `openshell-bootstrap` crate, `deploy/docker/Dockerfile.images`, `cluster-entrypoint.sh`, `cluster-healthcheck.sh`, deploy logic in `openshell-cli`), update the `debug-openshell-cluster` skill in `.agents/skills/debug-openshell-cluster/SKILL.md` to reflect those changes. diff --git a/Cargo.lock b/Cargo.lock index fd3b68d1e..88fc45820 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2929,8 +2929,15 @@ dependencies = [ name = "openshell-providers" version = "0.0.0" dependencies = [ + "async-trait", + "chrono", "openshell-core", + "reqwest", + "serde", + "serde_json", "thiserror 2.0.18", + "tokio", + "tracing", ] [[package]] @@ -3019,6 +3026,7 @@ dependencies = [ "miette", "openshell-core", "openshell-policy", + "openshell-providers", "openshell-router", "petname", "pin-project-lite", diff --git a/Cargo.toml b/Cargo.toml index 3380e040b..90fec8c83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,6 +94,12 @@ k8s-openapi = { version = "0.21.1", features = ["v1_26"] } # IDs uuid = { version = "1.10", features = ["v4"] } +# Time/Date +chrono = "0.4" + +# Async +async-trait = "0.1" + [workspace.lints.rust] unsafe_code = "warn" rust_2018_idioms = { level = "warn", priority = -1 } diff --git a/cleanup-openshell-podman-macos.sh b/cleanup-openshell-podman-macos.sh index 43efd8dd5..d6b80a411 100755 --- a/cleanup-openshell-podman-macos.sh +++ b/cleanup-openshell-podman-macos.sh @@ -11,19 +11,43 @@ set -e echo "=== OpenShell Podman Cleanup Script ===" echo "" +# Delete all sandboxes first (before destroying gateway) +echo "Deleting all sandboxes..." +if command -v openshell &>/dev/null; then + # Get list of sandboxes and delete each one + openshell sandbox list --no-header 2>/dev/null | awk '{print $1}' | while read -r sandbox; do + if [ -n "$sandbox" ]; then + echo " Deleting sandbox: $sandbox" + openshell sandbox delete "$sandbox" 2>/dev/null || true + fi + done +fi + # Destroy OpenShell gateway (if it exists) echo "Destroying OpenShell gateway..." if command -v openshell &>/dev/null; then openshell gateway destroy --name openshell 2>/dev/null || true fi -# Stop and remove any running OpenShell containers -echo "Stopping OpenShell containers..." -podman ps -a | grep openshell | awk '{print $1}' | xargs -r podman rm -f || true +# Stop and remove cluster container +echo "Stopping cluster container..." +podman stop openshell-cluster-openshell 2>/dev/null || true +podman rm openshell-cluster-openshell 2>/dev/null || true + +# Stop and remove local registry container +echo "Stopping local registry..." +podman stop openshell-local-registry 2>/dev/null || true +podman rm openshell-local-registry 2>/dev/null || true + +# Stop and remove any other OpenShell containers +echo "Cleaning up remaining OpenShell containers..." +podman ps -a | grep openshell | awk '{print $1}' | xargs -r podman rm -f 2>/dev/null || true # Remove OpenShell images echo "Removing OpenShell images..." -podman images | grep -E "openshell|cluster" | awk '{print $3}' | xargs -r podman rmi -f || true +podman rmi localhost/openshell/cluster:dev 2>/dev/null || true +podman rmi localhost/openshell/gateway:dev 2>/dev/null || true +podman images | grep -E "openshell|127.0.0.1:5000/openshell" | awk '{print $3}' | xargs -r podman rmi -f 2>/dev/null || true # Remove CLI binary echo "Removing CLI binary..." @@ -41,8 +65,11 @@ rm -rf ~/.openshell echo "Removing build artifacts..." SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR" -rm -rf target/ -rm -rf deploy/docker/.build/ +if command -v cargo &>/dev/null; then + echo " Running cargo clean..." + cargo clean 2>/dev/null || true +fi +rm -rf deploy/docker/.build/ 2>/dev/null || true # Clean Podman cache echo "Cleaning Podman build cache..." @@ -51,6 +78,13 @@ podman system prune -af --volumes echo "" echo "=== Cleanup Complete ===" echo "" +echo "OpenShell containers, images, and configuration have been removed." +echo "" +echo "To reinstall OpenShell:" +echo " 1. source scripts/podman.env" +echo " 2. mise run cluster:build:full" +echo " 3. cargo install --path crates/openshell-cli --root ~/.local" +echo "" echo "To completely remove the OpenShell Podman machine:" echo " podman machine stop openshell" echo " podman machine rm openshell" diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index d1cd7fd69..9ecc9b000 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -615,6 +615,7 @@ enum CliProviderType { Gitlab, Github, Outlook, + Vertex, } #[derive(Clone, Debug, ValueEnum)] @@ -646,6 +647,7 @@ impl CliProviderType { Self::Gitlab => "gitlab", Self::Github => "github", Self::Outlook => "outlook", + Self::Vertex => "vertex", } } } diff --git a/crates/openshell-core/src/inference.rs b/crates/openshell-core/src/inference.rs index a06c427f8..0973f25db 100644 --- a/crates/openshell-core/src/inference.rs +++ b/crates/openshell-core/src/inference.rs @@ -86,6 +86,19 @@ static NVIDIA_PROFILE: InferenceProviderProfile = InferenceProviderProfile { default_headers: &[], }; +static VERTEX_PROFILE: InferenceProviderProfile = InferenceProviderProfile { + provider_type: "vertex", + // Base URL template - actual URL constructed at request time with project/region/model + default_base_url: "https://us-central1-aiplatform.googleapis.com/v1", + protocols: ANTHROPIC_PROTOCOLS, + // Look for OAuth token first, fallback to project ID (for manual config) + credential_key_names: &["VERTEX_OAUTH_TOKEN", "ANTHROPIC_VERTEX_PROJECT_ID"], + base_url_config_keys: &["VERTEX_BASE_URL", "ANTHROPIC_VERTEX_REGION"], + // Vertex uses OAuth Bearer tokens, not x-api-key + auth: AuthHeader::Bearer, + default_headers: &[("anthropic-version", "vertex-2023-10-16")], +}; + /// Look up the inference provider profile for a given provider type. /// /// Returns `None` for provider types that don't support inference routing @@ -95,6 +108,7 @@ pub fn profile_for(provider_type: &str) -> Option<&'static InferenceProviderProf "openai" => Some(&OPENAI_PROFILE), "anthropic" => Some(&ANTHROPIC_PROFILE), "nvidia" => Some(&NVIDIA_PROFILE), + "vertex" => Some(&VERTEX_PROFILE), _ => None, } } @@ -176,6 +190,7 @@ mod tests { assert!(profile_for("openai").is_some()); assert!(profile_for("anthropic").is_some()); assert!(profile_for("nvidia").is_some()); + assert!(profile_for("vertex").is_some()); assert!(profile_for("OpenAI").is_some()); // case insensitive } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 9cf543bdf..b82c00a5d 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -16,7 +16,8 @@ use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, L7Allow, L7QueryMatcher, L7Rule, LandlockPolicy, NetworkBinary, - NetworkEndpoint, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, + NetworkEndpoint, NetworkPolicyRule, OAuthCredentialsPolicy, OAuthInjectionConfig, + ProcessPolicy, SandboxPolicy, }; use serde::{Deserialize, Serialize}; @@ -36,6 +37,8 @@ struct PolicyFile { process: Option, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] network_policies: BTreeMap, + #[serde(default, skip_serializing_if = "Option::is_none")] + oauth_credentials: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -65,6 +68,25 @@ struct ProcessDef { run_as_group: String, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct OAuthCredentialsDef { + #[serde(default)] + auto_refresh: bool, + #[serde(default, skip_serializing_if = "is_zero_i64")] + refresh_margin_seconds: i64, + #[serde(default, skip_serializing_if = "is_zero_i64")] + max_lifetime_seconds: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct OAuthInjectionConfigDef { + token_env_var: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + header_format: String, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct NetworkPolicyRuleDef { @@ -100,12 +122,18 @@ struct NetworkEndpointDef { rules: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] allowed_ips: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + oauth: Option, } fn is_zero(v: &u16) -> bool { *v == 0 } +fn is_zero_i64(v: &i64) -> bool { + *v == 0 +} + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct L7RuleDef { @@ -185,6 +213,14 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { tls: e.tls, enforcement: e.enforcement, access: e.access, + oauth: e.oauth.map(|oauth| OAuthInjectionConfig { + token_env_var: oauth.token_env_var, + header_format: if oauth.header_format.is_empty() { + "Bearer {token}".to_string() + } else { + oauth.header_format + }, + }), rules: e .rules .into_iter() @@ -245,6 +281,11 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { run_as_group: p.run_as_group, }), network_policies, + oauth_credentials: raw.oauth_credentials.map(|oauth| OAuthCredentialsPolicy { + auto_refresh: oauth.auto_refresh, + refresh_margin_seconds: oauth.refresh_margin_seconds, + max_lifetime_seconds: oauth.max_lifetime_seconds, + }), } } @@ -330,6 +371,10 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), allowed_ips: e.allowed_ips.clone(), + oauth: e.oauth.as_ref().map(|oauth| OAuthInjectionConfigDef { + token_env_var: oauth.token_env_var.clone(), + header_format: oauth.header_format.clone(), + }), } }) .collect(), @@ -346,12 +391,28 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(); + let oauth_credentials = policy.oauth_credentials.as_ref().and_then(|oauth| { + if !oauth.auto_refresh + && oauth.refresh_margin_seconds == 0 + && oauth.max_lifetime_seconds == 0 + { + None + } else { + Some(OAuthCredentialsDef { + auto_refresh: oauth.auto_refresh, + refresh_margin_seconds: oauth.refresh_margin_seconds, + max_lifetime_seconds: oauth.max_lifetime_seconds, + }) + } + }); + PolicyFile { version: policy.version, filesystem_policy, landlock, process, network_policies, + oauth_credentials, } } @@ -448,6 +509,7 @@ pub fn restrictive_default_policy() -> SandboxPolicy { run_as_group: "sandbox".into(), }), network_policies: HashMap::new(), + oauth_credentials: None, } } @@ -1009,6 +1071,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + oauth_credentials: None, }; assert!(validate_sandbox_policy(&policy).is_ok()); } @@ -1227,3 +1290,100 @@ network_policies: ); } } + +#[test] +fn parse_oauth_injection_config() { + let yaml = r#" +version: 1 +network_policies: + test: + name: test-oauth + endpoints: + - host: api.example.com + port: 443 + protocol: rest + oauth: + token_env_var: TEST_ACCESS_TOKEN + header_format: "Bearer {token}" + binaries: + - path: /usr/bin/curl +"#; + let policy = parse_sandbox_policy(yaml).expect("parse failed"); + let rule = policy + .network_policies + .get("test") + .expect("policy not found"); + let endpoint = rule.endpoints.first().expect("no endpoints"); + let oauth = endpoint.oauth.as_ref().expect("no oauth config"); + + assert_eq!(oauth.token_env_var, "TEST_ACCESS_TOKEN"); + assert_eq!(oauth.header_format, "Bearer {token}"); +} + +#[test] +fn round_trip_oauth_injection_config() { + let yaml = r#" +version: 1 +network_policies: + test: + name: test-oauth + endpoints: + - host: api.example.com + port: 443 + protocol: rest + oauth: + token_env_var: MY_TOKEN + header_format: "Bearer {token}" + binaries: + - path: /usr/bin/curl +"#; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let oauth1 = proto1.network_policies["test"].endpoints[0] + .oauth + .as_ref() + .unwrap(); + let oauth2 = proto2.network_policies["test"].endpoints[0] + .oauth + .as_ref() + .unwrap(); + + assert_eq!(oauth1.token_env_var, oauth2.token_env_var); + assert_eq!(oauth1.header_format, oauth2.header_format); +} + +#[test] +fn parse_vertex_example_policy() { + let yaml = std::fs::read_to_string("../../examples/vertex-ai/sandbox-policy.yaml") + .expect("failed to read example policy"); + let policy = parse_sandbox_policy(&yaml).expect("parse failed"); + + let rule = policy + .network_policies + .get("google_vertex") + .expect("google_vertex policy not found"); + assert!(!rule.endpoints.is_empty(), "should have endpoints"); + + // Check that aiplatform.googleapis.com endpoints have OAuth config + let vertex_endpoints: Vec<_> = rule + .endpoints + .iter() + .filter(|e| e.host.contains("aiplatform.googleapis.com")) + .collect(); + + assert!( + !vertex_endpoints.is_empty(), + "should have aiplatform endpoints" + ); + + for endpoint in vertex_endpoints { + let oauth = endpoint + .oauth + .as_ref() + .expect("aiplatform endpoint should have OAuth config"); + assert_eq!(oauth.token_env_var, "VERTEX_ACCESS_TOKEN"); + assert_eq!(oauth.header_format, "Bearer {token}"); + } +} diff --git a/crates/openshell-providers/Cargo.toml b/crates/openshell-providers/Cargo.toml index 41f9ed6c0..b2c4c9d07 100644 --- a/crates/openshell-providers/Cargo.toml +++ b/crates/openshell-providers/Cargo.toml @@ -14,5 +14,14 @@ repository.workspace = true openshell-core = { path = "../openshell-core" } thiserror = { workspace = true } +# Runtime token exchange dependencies +tokio = { workspace = true } +async-trait = { workspace = true } +chrono = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } + [lints] workspace = true diff --git a/crates/openshell-providers/src/lib.rs b/crates/openshell-providers/src/lib.rs index e2bcc0c09..b90dc64f0 100644 --- a/crates/openshell-providers/src/lib.rs +++ b/crates/openshell-providers/src/lib.rs @@ -9,6 +9,18 @@ mod providers; #[cfg(test)] mod test_helpers; +// Runtime credential system +pub mod runtime; +pub mod secret_store; +pub mod stores; +pub mod token_cache; + +// Re-export specific providers for direct use +pub mod vertex { + pub use crate::providers::vertex::*; +} + +use async_trait::async_trait; use std::collections::HashMap; use std::path::Path; @@ -17,6 +29,12 @@ pub use openshell_core::proto::Provider; pub use context::{DiscoveryContext, RealDiscoveryContext}; pub use discovery::discover_with_spec; +// Re-export runtime types +pub use runtime::{RuntimeError, RuntimeResult, TokenResponse}; +pub use secret_store::{SecretError, SecretResult, SecretStore}; +pub use stores::DatabaseStore; +pub use token_cache::TokenCache; + #[derive(Debug, thiserror::Error)] pub enum ProviderError { #[error("unsupported provider type: {0}")] @@ -42,6 +60,7 @@ pub struct ProviderDiscoverySpec { pub credential_env_vars: &'static [&'static str], } +#[async_trait] pub trait ProviderPlugin: Send + Sync { /// Canonical provider id (for example: "claude", "gitlab"). fn id(&self) -> &'static str; @@ -64,6 +83,22 @@ pub trait ProviderPlugin: Send + Sync { fn apply_to_sandbox(&self, _provider: &Provider) -> Result<(), ProviderError> { Ok(()) } + + /// Get a runtime token by fetching and interpreting secrets from storage. + /// + /// This is called during sandbox execution to exchange stored credentials + /// for access tokens. The provider knows how to interpret its credential format: + /// - Vertex: fetches VERTEX_ADC from store, exchanges for OAuth token + /// - Anthropic: fetches API key from store, returns it directly + /// - OpenAI: fetches API key from store, returns it directly + /// + /// Default implementation returns NotConfigured error - providers that need + /// runtime token exchange must implement this. + async fn get_runtime_token(&self, _store: &dyn SecretStore) -> RuntimeResult { + Err(RuntimeError::NotConfigured( + "This provider does not support runtime token exchange".to_string(), + )) + } } #[derive(Default)] @@ -86,6 +121,7 @@ impl ProviderRegistry { registry.register(providers::gitlab::GitlabProvider); registry.register(providers::github::GithubProvider); registry.register(providers::outlook::OutlookProvider); + registry.register(providers::vertex::VertexProvider::new()); registry } @@ -138,6 +174,7 @@ pub fn normalize_provider_type(input: &str) -> Option<&'static str> { "gitlab" | "glab" => Some("gitlab"), "github" | "gh" => Some("github"), "outlook" => Some("outlook"), + "vertex" => Some("vertex"), _ => None, } } diff --git a/crates/openshell-providers/src/providers/mod.rs b/crates/openshell-providers/src/providers/mod.rs index 6fe395135..19f9c54a5 100644 --- a/crates/openshell-providers/src/providers/mod.rs +++ b/crates/openshell-providers/src/providers/mod.rs @@ -12,3 +12,4 @@ pub mod nvidia; pub mod openai; pub mod opencode; pub mod outlook; +pub mod vertex; diff --git a/crates/openshell-providers/src/providers/vertex.rs b/crates/openshell-providers/src/providers/vertex.rs new file mode 100644 index 000000000..82a565e40 --- /dev/null +++ b/crates/openshell-providers/src/providers/vertex.rs @@ -0,0 +1,269 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + DiscoveredProvider, ProviderDiscoverySpec, ProviderError, ProviderPlugin, RealDiscoveryContext, + RuntimeError, RuntimeResult, SecretStore, TokenResponse, discover_with_spec, +}; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; + +pub struct VertexProvider { + client: Client, +} + +impl VertexProvider { + #[must_use] + pub fn new() -> Self { + Self { + client: Client::new(), + } + } + + /// Get the standard ADC file path + fn get_standard_adc_path() -> Option { + let home = std::env::var("HOME").ok()?; + Some(PathBuf::from(home).join(".config/gcloud/application_default_credentials.json")) + } + + /// Try to read ADC from standard gcloud location + fn read_adc_from_standard_path() -> Option { + let path = Self::get_standard_adc_path()?; + std::fs::read_to_string(path).ok() + } + + /// Validate ADC credentials by testing token exchange + /// This is synchronous and blocks during provider creation + fn validate_adc_sync(adc_json: &str) -> Result<(), ProviderError> { + // Parse ADC JSON + let adc: AdcCredentials = serde_json::from_str(adc_json).map_err(|e| { + ProviderError::UnsupportedProvider(format!( + "Invalid ADC format: {}. Expected Google Application Default Credentials JSON from 'gcloud auth application-default login'", + e + )) + })?; + + // Test token exchange - use current runtime if available, otherwise create one + let result = if let Ok(handle) = tokio::runtime::Handle::try_current() { + // Already in a runtime - use block_in_place to avoid nested runtime error + tokio::task::block_in_place(|| handle.block_on(Self::validate_adc_async(adc))) + } else { + // Not in a runtime - create one + let runtime = tokio::runtime::Runtime::new().map_err(|e| { + ProviderError::UnsupportedProvider(format!( + "Failed to create runtime for validation: {}", + e + )) + })?; + runtime.block_on(Self::validate_adc_async(adc)) + }; + + result + } + + /// Async helper for ADC validation + async fn validate_adc_async(adc: AdcCredentials) -> Result<(), ProviderError> { + let client = Client::new(); + let params = [ + ("client_id", adc.client_id.as_str()), + ("client_secret", adc.client_secret.as_str()), + ("refresh_token", adc.refresh_token.as_str()), + ("grant_type", "refresh_token"), + ]; + + let response = client + .post("https://oauth2.googleapis.com/token") + .form(¶ms) + .send() + .await + .map_err(|e| { + ProviderError::UnsupportedProvider(format!( + "Failed to connect to Google OAuth: {}. Check your internet connection.", + e + )) + })?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(ProviderError::UnsupportedProvider(format!( + "ADC credentials rejected by Google OAuth (status {}): {}. Your credentials may be expired or invalid. Run: gcloud auth application-default login", + status, body + ))); + } + + // Successfully exchanged for token + tracing::info!("✅ Verified Vertex ADC credentials with Google OAuth"); + Ok(()) + } + + /// Exchange ADC credentials for OAuth access token + async fn exchange_adc_for_token(&self, adc: AdcCredentials) -> RuntimeResult { + let params = [ + ("client_id", adc.client_id.as_str()), + ("client_secret", adc.client_secret.as_str()), + ("refresh_token", adc.refresh_token.as_str()), + ("grant_type", "refresh_token"), + ]; + + let response = self + .client + .post("https://oauth2.googleapis.com/token") + .form(¶ms) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(RuntimeError::AuthFailed(format!( + "OAuth token request failed with status {}: {}", + status, body + ))); + } + + let token_response: GoogleTokenResponse = response.json().await?; + + Ok(TokenResponse { + access_token: token_response.access_token.trim().to_string(), + token_type: token_response.token_type, + expires_in: token_response.expires_in, + metadata: HashMap::new(), + }) + } +} + +impl Default for VertexProvider { + fn default() -> Self { + Self::new() + } +} + +pub const SPEC: ProviderDiscoverySpec = ProviderDiscoverySpec { + id: "vertex", + credential_env_vars: &["ANTHROPIC_VERTEX_PROJECT_ID"], +}; + +// Additional config keys for Vertex AI +const VERTEX_CONFIG_KEYS: &[&str] = &["ANTHROPIC_VERTEX_REGION"]; + +/// ADC (Application Default Credentials) format from gcloud +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AdcCredentials { + client_id: String, + client_secret: String, + refresh_token: String, + #[serde(rename = "type")] + cred_type: String, +} + +/// Google OAuth token response +#[derive(Debug, Deserialize)] +struct GoogleTokenResponse { + access_token: String, + token_type: String, + expires_in: u64, +} + +#[async_trait] +impl ProviderPlugin for VertexProvider { + fn id(&self) -> &'static str { + SPEC.id + } + + fn discover_existing(&self) -> Result, ProviderError> { + let mut discovered = discover_with_spec(&SPEC, &RealDiscoveryContext)?; + + // Add region config if present + if let Some(ref mut provider) = discovered { + for &key in VERTEX_CONFIG_KEYS { + if let Ok(value) = std::env::var(key) { + provider.config.insert(key.to_string(), value); + } + } + + // Set CLAUDE_CODE_USE_VERTEX=1 to enable Vertex AI in claude CLI + // Must be in credentials (not config) to be injected into sandbox environment + provider + .credentials + .insert("CLAUDE_CODE_USE_VERTEX".to_string(), "1".to_string()); + + // Try to discover ADC credentials + // Priority: + // 1. VERTEX_ADC environment variable (explicit override) + // 2. Standard gcloud ADC path: ~/.config/gcloud/application_default_credentials.json + let adc_result = if let Ok(adc) = std::env::var("VERTEX_ADC") { + tracing::debug!("discovered VERTEX_ADC from environment variable"); + Some(adc) + } else if let Some(adc) = Self::read_adc_from_standard_path() { + tracing::debug!("discovered ADC from standard gcloud path"); + Some(adc) + } else { + None + }; + + match adc_result { + Some(adc_json) => { + // Validate ADC by testing token exchange with Google OAuth + Self::validate_adc_sync(&adc_json)?; + + provider + .credentials + .insert("VERTEX_ADC".to_string(), adc_json); + tracing::info!("✅ Validated and stored Vertex ADC credentials"); + } + None => { + return Err(ProviderError::UnsupportedProvider( + "Vertex ADC credentials not found. Run one of:\n \ + 1. gcloud auth application-default login (creates ~/.config/gcloud/application_default_credentials.json)\n \ + 2. export VERTEX_ADC=\"$(cat /path/to/adc.json)\"\n \ + 3. openshell provider create --name vertex --type vertex --credential VERTEX_ADC=\"$(cat /path/to/adc.json)\"".to_string() + )); + } + } + } + + Ok(discovered) + } + + fn credential_env_vars(&self) -> &'static [&'static str] { + SPEC.credential_env_vars + } + + async fn get_runtime_token(&self, store: &dyn SecretStore) -> RuntimeResult { + tracing::debug!("fetching runtime token for vertex provider"); + + // Get ADC from secret store + let adc_json = store.get("VERTEX_ADC").await?; + + // Parse ADC and exchange for OAuth token + let adc: AdcCredentials = serde_json::from_str(&adc_json) + .map_err(|e| RuntimeError::InvalidResponse(format!("Invalid ADC format: {}", e)))?; + + tracing::info!("exchanging ADC for OAuth token"); + self.exchange_adc_for_token(adc).await + } +} + +#[cfg(test)] +mod tests { + use super::SPEC; + use crate::discover_with_spec; + use crate::test_helpers::MockDiscoveryContext; + + #[test] + fn discovers_vertex_env_credentials() { + let ctx = + MockDiscoveryContext::new().with_env("ANTHROPIC_VERTEX_PROJECT_ID", "my-gcp-project"); + let discovered = discover_with_spec(&SPEC, &ctx) + .expect("discovery") + .expect("provider"); + assert_eq!( + discovered.credentials.get("ANTHROPIC_VERTEX_PROJECT_ID"), + Some(&"my-gcp-project".to_string()) + ); + } +} diff --git a/crates/openshell-providers/src/runtime.rs b/crates/openshell-providers/src/runtime.rs new file mode 100644 index 000000000..9f5a9f6be --- /dev/null +++ b/crates/openshell-providers/src/runtime.rs @@ -0,0 +1,62 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Runtime credential operations for providers. +//! +//! This module defines the runtime phase where providers fetch and exchange +//! credentials for access tokens during sandbox execution. + +use std::collections::HashMap; + +/// Standard response format for runtime token operations +#[derive(Debug, Clone)] +pub struct TokenResponse { + /// The actual token/secret value + pub access_token: String, + + /// Token type (e.g., "Bearer") + pub token_type: String, + + /// Seconds until expiration (from now) + pub expires_in: u64, + + /// Provider-specific metadata (e.g., project_id, region) + pub metadata: HashMap, +} + +/// Result type for runtime operations +pub type RuntimeResult = Result; + +/// Errors that can occur during runtime credential operations +#[derive(Debug, thiserror::Error)] +pub enum RuntimeError { + #[error("provider not configured: {0}")] + NotConfigured(String), + + #[error("network error: {0}")] + Network(String), + + #[error("authentication failed: {0}")] + AuthFailed(String), + + #[error("token expired")] + Expired, + + #[error("invalid response: {0}")] + InvalidResponse(String), + + #[error("secret store error: {0}")] + SecretStore(#[from] crate::secret_store::SecretError), +} + +impl From for RuntimeError { + fn from(e: reqwest::Error) -> Self { + RuntimeError::Network(e.to_string()) + } +} + +impl From for RuntimeError { + fn from(e: serde_json::Error) -> Self { + RuntimeError::InvalidResponse(e.to_string()) + } +} diff --git a/crates/openshell-providers/src/secret_store.rs b/crates/openshell-providers/src/secret_store.rs new file mode 100644 index 000000000..36dd98b03 --- /dev/null +++ b/crates/openshell-providers/src/secret_store.rs @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Generic secret storage interface. +//! +//! This module defines the storage layer for secrets/credentials. +//! Storage implementations are completely generic - they don't know about +//! provider-specific credential formats (ADC, API keys, etc.). +//! +//! The provider plugins (VertexProvider, AnthropicProvider, etc.) know how +//! to interpret the secrets retrieved from storage. + +use async_trait::async_trait; + +/// Result type for secret store operations +pub type SecretResult = Result; + +/// Errors that can occur during secret storage operations +#[derive(Debug, thiserror::Error)] +pub enum SecretError { + #[error("secret not found: {0}")] + NotFound(String), + + #[error("storage unavailable: {0}")] + Unavailable(String), + + #[error("access denied: {0}")] + AccessDenied(String), + + #[error("invalid format: {0}")] + InvalidFormat(String), + + #[error("network error: {0}")] + Network(String), +} + +/// Generic secret storage interface +/// +/// Implementations store and retrieve raw secret strings without interpreting them. +/// The provider plugins are responsible for interpreting the secret format. +#[async_trait] +pub trait SecretStore: Send + Sync { + /// Retrieve a secret by key + /// + /// Returns the raw secret string without interpretation. + async fn get(&self, key: &str) -> SecretResult; + + /// Check if the storage backend is available + /// + /// This should be a lightweight check (e.g., can we connect to the storage service?) + /// without actually retrieving secrets. + async fn health_check(&self) -> SecretResult<()>; + + /// Get a human-readable name for this storage backend + fn name(&self) -> &'static str; +} diff --git a/crates/openshell-providers/src/stores/database.rs b/crates/openshell-providers/src/stores/database.rs new file mode 100644 index 000000000..008d660e1 --- /dev/null +++ b/crates/openshell-providers/src/stores/database.rs @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Gateway database secret store. +//! +//! Fetches credentials from the provider credentials HashMap stored in the gateway database. +//! This is the primary secret storage mechanism for OpenShell. +//! +//! The gateway stores Provider records with credentials in `Provider.credentials` HashMap. +//! This store provides a clean abstraction over that storage. + +use crate::secret_store::{SecretError, SecretResult, SecretStore}; +use async_trait::async_trait; +use std::collections::HashMap; + +/// Gateway database secret store +/// +/// Wraps a provider's credentials HashMap from the database. +/// This is a simple in-memory wrapper - the actual persistence is handled +/// by the gateway's database layer. +pub struct DatabaseStore { + credentials: HashMap, +} + +impl DatabaseStore { + /// Create a new database store from provider credentials + #[must_use] + pub fn new(credentials: HashMap) -> Self { + Self { credentials } + } +} + +#[async_trait] +impl SecretStore for DatabaseStore { + async fn get(&self, key: &str) -> SecretResult { + tracing::debug!(key = key, "fetching secret from database store"); + + self.credentials.get(key).cloned().ok_or_else(|| { + SecretError::NotFound(format!("Credential '{}' not found in provider", key)) + }) + } + + async fn health_check(&self) -> SecretResult<()> { + // Database store is always available (in-memory) + Ok(()) + } + + fn name(&self) -> &'static str { + "database" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_database_store_get() { + let mut creds = HashMap::new(); + creds.insert("VERTEX_ADC".to_string(), "mock-adc-json".to_string()); + + let store = DatabaseStore::new(creds); + + let result = store.get("VERTEX_ADC").await.unwrap(); + assert_eq!(result, "mock-adc-json"); + } + + #[tokio::test] + async fn test_database_store_not_found() { + let store = DatabaseStore::new(HashMap::new()); + + let result = store.get("NONEXISTENT").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_database_store_health_check() { + let store = DatabaseStore::new(HashMap::new()); + let result = store.health_check().await; + assert!(result.is_ok()); + } +} diff --git a/crates/openshell-providers/src/stores/mod.rs b/crates/openshell-providers/src/stores/mod.rs new file mode 100644 index 000000000..d959c2c93 --- /dev/null +++ b/crates/openshell-providers/src/stores/mod.rs @@ -0,0 +1,8 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Secret store implementations. + +pub mod database; + +pub use database::DatabaseStore; diff --git a/crates/openshell-providers/src/token_cache.rs b/crates/openshell-providers/src/token_cache.rs new file mode 100644 index 000000000..a7a9016c1 --- /dev/null +++ b/crates/openshell-providers/src/token_cache.rs @@ -0,0 +1,318 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Token cache with automatic background refresh. +//! +//! This module provides a caching layer on top of provider plugins and secret stores that: +//! - Caches tokens to avoid repeated fetches +//! - Automatically refreshes tokens before they expire +//! - Runs a background task to proactively refresh tokens + +use crate::ProviderPlugin; +use crate::runtime::RuntimeResult; +use crate::secret_store::SecretStore; +use chrono::{DateTime, Duration, Utc}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Token cache entry with expiry tracking +#[derive(Debug, Clone)] +struct CachedToken { + access_token: String, + #[allow(dead_code)] + token_type: String, + expires_at: DateTime, + refresh_margin: Duration, +} + +impl CachedToken { + /// Check if token is still valid + fn is_valid(&self) -> bool { + Utc::now() < self.expires_at + } + + /// Check if token should be refreshed (within margin of expiry) + fn should_refresh(&self) -> bool { + Utc::now() + self.refresh_margin > self.expires_at + } +} + +/// Token cache with automatic background refresh +/// +/// This cache wraps a provider plugin and secret store: +/// 1. Caches tokens to avoid repeated network calls +/// 2. Returns cached token if still valid +/// 3. Fetches fresh token if cache miss or expired +/// 4. Runs background task to refresh tokens before expiry +pub struct TokenCache { + /// Provider plugin that knows how to interpret credentials + provider: Arc, + + /// Secret store that provides raw credentials + store: Arc, + + /// Cached tokens by provider name + tokens: Arc>>, + + /// Background refresh task handle + refresh_task: Option>, + + /// How many seconds before expiry to refresh + refresh_margin_seconds: i64, +} + +impl std::fmt::Debug for TokenCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TokenCache") + .field("provider_id", &self.provider.id()) + .field("store_name", &self.store.name()) + .field("refresh_margin_seconds", &self.refresh_margin_seconds) + .field("has_background_task", &self.refresh_task.is_some()) + .finish() + } +} + +impl TokenCache { + /// Create a new token cache + /// + /// # Arguments + /// * `provider` - The provider plugin to interpret credentials + /// * `store` - The secret store to fetch credentials from + /// * `refresh_margin_seconds` - Refresh tokens this many seconds before expiry (default: 300 = 5 min) + /// * `auto_refresh` - Enable background auto-refresh task (default: true) + pub fn new( + provider: Arc, + store: Arc, + refresh_margin_seconds: i64, + auto_refresh: bool, + ) -> Self { + let tokens = Arc::new(RwLock::new(HashMap::new())); + + // Conditionally start background refresh task based on auto_refresh flag + let refresh_task = if auto_refresh { + let tokens = tokens.clone(); + let provider = provider.clone(); + let store = store.clone(); + let margin = refresh_margin_seconds; + + Some(tokio::spawn(async move { + Self::auto_refresh_loop(tokens, provider, store, margin).await; + })) + } else { + tracing::info!("Auto-refresh disabled for token cache"); + None + }; + + Self { + provider, + store, + tokens, + refresh_task, + refresh_margin_seconds, + } + } + + /// Get a token for the specified provider + /// + /// Returns cached token if valid, otherwise fetches fresh token. + pub async fn get_token(&self, provider_name: &str) -> RuntimeResult { + let (token, _) = self.get_token_with_expiry(provider_name).await?; + Ok(token) + } + + /// Get a token with its expiry time. + /// + /// Returns (token, expires_in_seconds) where expires_in_seconds is the + /// remaining time until token expiration. + pub async fn get_token_with_expiry(&self, provider_name: &str) -> RuntimeResult<(String, u64)> { + // Check cache first + { + let tokens = self.tokens.read().await; + if let Some(cached) = tokens.get(provider_name) { + if cached.is_valid() { + let expires_in = (cached.expires_at - Utc::now()).num_seconds().max(0) as u64; + tracing::debug!( + provider = provider_name, + expires_at = %cached.expires_at, + expires_in = expires_in, + "returning cached token" + ); + return Ok((cached.access_token.clone(), expires_in)); + } + } + } + + // Cache miss or expired - fetch fresh token + tracing::info!(provider = provider_name, "fetching fresh token"); + let token = self.refresh_token(provider_name).await?; + + // Get the expiry time we just cached + let expires_in = { + let tokens = self.tokens.read().await; + if let Some(cached) = tokens.get(provider_name) { + (cached.expires_at - Utc::now()).num_seconds().max(0) as u64 + } else { + // Fallback - shouldn't happen since we just cached it + 3600 + } + }; + + Ok((token, expires_in)) + } + + /// Force refresh a token (bypasses cache) + async fn refresh_token(&self, provider_name: &str) -> RuntimeResult { + let response = self.provider.get_runtime_token(self.store.as_ref()).await?; + + let expires_at = Utc::now() + Duration::seconds(response.expires_in as i64); + let cached = CachedToken { + access_token: response.access_token.clone(), + token_type: response.token_type, + expires_at, + refresh_margin: Duration::seconds(self.refresh_margin_seconds), + }; + + tracing::info!( + provider = provider_name, + expires_at = %cached.expires_at, + "cached fresh token" + ); + + self.tokens + .write() + .await + .insert(provider_name.to_string(), cached); + + Ok(response.access_token) + } + + /// Background task that proactively refreshes tokens before expiry + async fn auto_refresh_loop( + tokens: Arc>>, + provider: Arc, + store: Arc, + margin_seconds: i64, + ) { + // For 60-minute tokens with 5-minute margin, we want to check every 55 minutes + // This minimizes wake-ups while ensuring we catch the refresh window + let check_interval_seconds = 3600 - margin_seconds; // Default: 3600 - 300 = 3300 (55 min) + + loop { + tokio::time::sleep(tokio::time::Duration::from_secs( + check_interval_seconds as u64, + )) + .await; + + // Find tokens that need refresh + let to_refresh: Vec = { + let tokens = tokens.read().await; + tokens + .iter() + .filter(|(_, token)| token.should_refresh()) + .map(|(name, _)| name.clone()) + .collect() + }; + + // Refresh each token + for provider_name in to_refresh { + tracing::info!(provider = provider_name, "background refresh triggered"); + + match provider.get_runtime_token(store.as_ref()).await { + Ok(response) => { + let expires_at = Utc::now() + Duration::seconds(response.expires_in as i64); + let cached = CachedToken { + access_token: response.access_token, + token_type: response.token_type, + expires_at, + refresh_margin: Duration::seconds(margin_seconds), + }; + + tokens.write().await.insert(provider_name.clone(), cached); + + tracing::info!( + provider = provider_name, + expires_at = %expires_at, + "background refresh succeeded" + ); + } + Err(e) => { + tracing::error!( + provider = provider_name, + error = %e, + "background refresh failed" + ); + } + } + } + } + } +} + +impl Drop for TokenCache { + fn drop(&mut self) { + if let Some(task) = self.refresh_task.take() { + task.abort(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::RuntimeResult; + use crate::{DatabaseStore, ProviderPlugin, SecretStore, TokenResponse}; + use async_trait::async_trait; + + struct MockProvider; + + #[async_trait] + impl ProviderPlugin for MockProvider { + fn id(&self) -> &'static str { + "mock" + } + + fn discover_existing( + &self, + ) -> Result, crate::ProviderError> { + Ok(None) + } + + async fn get_runtime_token( + &self, + _store: &dyn SecretStore, + ) -> RuntimeResult { + Ok(TokenResponse { + access_token: "mock-token".to_string(), + token_type: "Bearer".to_string(), + expires_in: 3600, + metadata: HashMap::new(), + }) + } + } + + #[tokio::test] + async fn test_cache_miss_fetches_token() { + let provider = Arc::new(MockProvider); + let store = Arc::new(DatabaseStore::new(HashMap::new())); + let cache = TokenCache::new(provider, store, 300, true); + + let token = cache.get_token("mock").await.unwrap(); + assert_eq!(token, "mock-token"); + } + + #[tokio::test] + async fn test_cache_hit_avoids_fetch() { + let provider = Arc::new(MockProvider); + let store = Arc::new(DatabaseStore::new(HashMap::new())); + let cache = TokenCache::new(provider, store, 300, true); + + // First call - cache miss + let token1 = cache.get_token("mock").await.unwrap(); + + // Second call - cache hit + let token2 = cache.get_token("mock").await.unwrap(); + + assert_eq!(token1, token2); + } +} diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index d1d7092c0..9b5d1a000 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -95,7 +95,7 @@ async fn send_backend_request( headers: Vec<(String, String)>, body: bytes::Bytes, ) -> Result { - let url = build_backend_url(&route.endpoint, path); + let url = build_backend_url(&route.endpoint, path, &route.model); let reqwest_method: reqwest::Method = method .parse() @@ -137,13 +137,24 @@ async fn send_backend_request( // Set the "model" field in the JSON body to the route's configured model so the // backend receives the correct model ID regardless of what the client sent. + // + // Exception: Vertex AI's :streamRawPredict endpoint expects the model in the URL + // path (already handled in build_backend_url), not in the request body. + let is_vertex_ai = route.endpoint.contains("aiplatform.googleapis.com"); + let body = match serde_json::from_slice::(&body) { Ok(mut json) => { if let Some(obj) = json.as_object_mut() { - obj.insert( - "model".to_string(), - serde_json::Value::String(route.model.clone()), - ); + if is_vertex_ai { + // Remove model field for Vertex AI (it's in the URL path) + obj.remove("model"); + } else { + // Insert/override model field for standard backends + obj.insert( + "model".to_string(), + serde_json::Value::String(route.model.clone()), + ); + } } bytes::Bytes::from(serde_json::to_vec(&json).unwrap_or_else(|_| body.to_vec())) } @@ -241,7 +252,7 @@ pub async fn verify_backend_endpoint( if mock::is_mock_route(route) { return Ok(ValidatedEndpoint { - url: build_backend_url(&route.endpoint, probe.path), + url: build_backend_url(&route.endpoint, probe.path, &route.model), protocol: probe.protocol.to_string(), }); } @@ -306,7 +317,7 @@ async fn try_validation_request( details, }, })?; - let url = build_backend_url(&route.endpoint, path); + let url = build_backend_url(&route.endpoint, path, &route.model); if response.status().is_success() { return Ok(ValidatedEndpoint { @@ -418,8 +429,23 @@ pub async fn proxy_to_backend_streaming( }) } -fn build_backend_url(endpoint: &str, path: &str) -> String { +fn build_backend_url(endpoint: &str, path: &str, model: &str) -> String { let base = endpoint.trim_end_matches('/'); + + // Special handling for Vertex AI + if base.contains("aiplatform.googleapis.com") && path.starts_with("/v1/messages") { + // Vertex AI uses a different path structure: + // https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/anthropic/models/{model}:streamRawPredict + // The base already has everything up to /models, so we append /{model}:streamRawPredict + let model_suffix = if model.is_empty() { + String::new() + } else { + format!("/{}", model) + }; + return format!("{}{}:streamRawPredict", base, model_suffix); + } + + // Deduplicate /v1 prefix for standard endpoints if base.ends_with("/v1") && (path == "/v1" || path.starts_with("/v1/")) { return format!("{base}{}", &path[3..]); } @@ -438,7 +464,7 @@ mod tests { #[test] fn build_backend_url_dedupes_v1_prefix() { assert_eq!( - build_backend_url("https://api.openai.com/v1", "/v1/chat/completions"), + build_backend_url("https://api.openai.com/v1", "/v1/chat/completions", "gpt-4"), "https://api.openai.com/v1/chat/completions" ); } @@ -446,15 +472,27 @@ mod tests { #[test] fn build_backend_url_preserves_non_versioned_base() { assert_eq!( - build_backend_url("https://api.anthropic.com", "/v1/messages"), + build_backend_url("https://api.anthropic.com", "/v1/messages", "claude-3"), "https://api.anthropic.com/v1/messages" ); } + #[test] + fn build_backend_url_handles_vertex_ai() { + assert_eq!( + build_backend_url( + "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/anthropic/models", + "/v1/messages", + "claude-3-5-sonnet-20241022" + ), + "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/anthropic/models/claude-3-5-sonnet-20241022:streamRawPredict" + ); + } + #[test] fn build_backend_url_handles_exact_v1_path() { assert_eq!( - build_backend_url("https://api.openai.com/v1", "/v1"), + build_backend_url("https://api.openai.com/v1", "/v1", "gpt-4"), "https://api.openai.com/v1" ); } diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 5503637ee..21f57bc9e 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -221,7 +221,7 @@ pub struct SettingsPollResult { pub config_revision: u64, pub policy_source: PolicySource, /// Effective settings keyed by name. - pub settings: std::collections::HashMap, + pub settings: HashMap, /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, } diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 880b6fd9e..1811c6795 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -53,12 +53,23 @@ pub enum EnforcementMode { Enforce, } +/// OAuth header injection configuration +#[derive(Debug, Clone)] +pub struct OAuthConfig { + /// Environment variable name containing the OAuth token + pub token_env_var: String, + /// Header value format template (use {token} as placeholder) + /// Default: "Bearer {token}" + pub header_format: String, +} + /// L7 configuration for an endpoint, extracted from policy data. #[derive(Debug, Clone)] pub struct L7EndpointConfig { pub protocol: L7Protocol, pub tls: TlsMode, pub enforcement: EnforcementMode, + pub oauth: Option, } /// Result of an L7 policy decision for a single request. @@ -112,10 +123,26 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { _ => EnforcementMode::Audit, }; + // Parse OAuth configuration if present + let oauth = match get_object_field(val, "oauth") { + Some(oauth_val) => { + let token_env_var = get_object_str(oauth_val, "token_env_var")?; + let header_format = get_object_str(oauth_val, "header_format") + .unwrap_or_else(|| "Bearer {token}".to_string()); + + Some(OAuthConfig { + token_env_var, + header_format, + }) + } + None => None, + }; + Some(L7EndpointConfig { protocol, tls, enforcement, + oauth, }) } @@ -132,6 +159,14 @@ pub fn parse_tls_mode(val: ®orus::Value) -> TlsMode { } /// Extract a string value from a regorus object. +fn get_object_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + fn get_object_str(val: ®orus::Value, key: &str) -> Option { let key_val = regorus::Value::String(key.into()); match val { diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index b2fb34b61..4ad170b6b 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -31,6 +31,8 @@ pub struct L7EvalContext { pub cmdline_paths: Vec, /// Supervisor-only placeholder resolver for outbound headers. pub(crate) secret_resolver: Option>, + /// OAuth header injection configuration from endpoint policy. + pub(crate) oauth_config: Option, } /// Run protocol-aware L7 inspection on a tunnel. @@ -215,6 +217,7 @@ where client, upstream, ctx.secret_resolver.as_deref(), + ctx.oauth_config.as_ref(), ) .await?; match outcome { @@ -388,9 +391,14 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. - let outcome = - crate::l7::rest::relay_http_request_with_resolver(&req, client, upstream, resolver) - .await?; + let outcome = crate::l7::rest::relay_http_request_with_resolver( + &req, + client, + upstream, + resolver, + ctx.oauth_config.as_ref(), + ) + .await?; match outcome { RelayOutcome::Reusable => {} // continue loop diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 0c136be79..c1e9754cd 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -12,7 +12,7 @@ use crate::secrets::rewrite_http_header_block; use miette::{IntoDiagnostic, Result, miette}; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers const RELAY_BUF_SIZE: usize = 8192; @@ -252,7 +252,103 @@ where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, { - relay_http_request_with_resolver(req, client, upstream, None).await + relay_http_request_with_resolver(req, client, upstream, None, None).await +} + +/// Get OAuth access token from environment or resolver. +/// +/// Returns `None` if the token is not found, which will be logged as a warning +/// by the caller to help debug OAuth configuration issues. +fn get_oauth_access_token( + token_env_var: &str, + resolver: Option<&crate::secrets::SecretResolver>, +) -> Option { + // Try environment variable first + if let Ok(token) = std::env::var(token_env_var) { + return Some(token.trim().to_string()); // Strip whitespace/newlines + } + + // Try resolver with placeholder + if let Some(resolver) = resolver { + let placeholder = format!("openshell:resolve:env:{}", token_env_var); + if let Some(token) = resolver.resolve_placeholder(&placeholder) { + return Some(token.trim().to_string()); // Strip whitespace/newlines + } + } + + None +} + +/// Inject or replace Authorization header in HTTP request using OAuth config +fn inject_oauth_header( + raw: &[u8], + resolver: Option<&crate::secrets::SecretResolver>, + oauth_config: &crate::l7::OAuthConfig, +) -> Result { + use crate::secrets::{RewriteResult, rewrite_http_header_block}; + + // Get the access token + let Some(access_token) = get_oauth_access_token(&oauth_config.token_env_var, resolver) else { + // No token available - log warning to help debug OAuth configuration issues + warn!( + token_env_var = %oauth_config.token_env_var, + "OAuth token not found in environment or resolver. Check that the token_env_var \ + in the sandbox policy matches the credential key from the provider. Falling back \ + to standard credential rewriting." + ); + return rewrite_http_header_block(raw, resolver); + }; + + info!( + token_env_var = %oauth_config.token_env_var, + "Injecting OAuth access token into Authorization header" + ); + + let header_end = raw + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map(|p| p + 4) + .unwrap_or(raw.len()); + + let header_str = String::from_utf8_lossy(&raw[..header_end]); + let mut lines: Vec<&str> = header_str.split("\r\n").collect(); + + // Find and remove existing Authorization header + lines.retain(|line| !line.to_ascii_lowercase().starts_with("authorization:")); + + let mut output = Vec::with_capacity(raw.len() + 100); + + // Write request line + if let Some(request_line) = lines.first() { + output.extend_from_slice(request_line.as_bytes()); + output.extend_from_slice(b"\r\n"); + } + + // Write Authorization header using the template from config + let header_value = oauth_config.header_format.replace("{token}", &access_token); + let auth_header = format!("Authorization: {}", header_value); + output.extend_from_slice(auth_header.as_bytes()); + output.extend_from_slice(b"\r\n"); + + // Write remaining headers (skip first line which is request line, skip empty lines at end) + for line in lines.iter().skip(1) { + if line.is_empty() { + break; + } + output.extend_from_slice(line.as_bytes()); + output.extend_from_slice(b"\r\n"); + } + + // End headers + output.extend_from_slice(b"\r\n"); + + // Copy body + output.extend_from_slice(&raw[header_end..]); + + Ok(RewriteResult { + rewritten: output, + redacted_target: None, + }) } pub(crate) async fn relay_http_request_with_resolver( @@ -260,20 +356,58 @@ pub(crate) async fn relay_http_request_with_resolver( client: &mut C, upstream: &mut U, resolver: Option<&crate::secrets::SecretResolver>, + oauth_config: Option<&crate::l7::OAuthConfig>, ) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, { + // Provider-specific request interception (Vertex AI OAuth workaround) + // + // Check if this request should be intercepted by a provider-specific handler. + // Currently only used by Vertex AI to intercept OAuth token exchange for + // Claude CLI compatibility. See `providers::vertex` module for details. + if req.action == "POST" && req.target == "/token" { + let header_str = String::from_utf8_lossy(&req.raw_header); + if let Some(host_line) = header_str + .lines() + .find(|line| line.to_ascii_lowercase().starts_with("host:")) + { + let host = host_line.split_once(':').map_or("", |(_, h)| h.trim()); + + // Check if Vertex provider should intercept this request + if crate::providers::vertex::should_intercept_oauth_request( + &req.action, + host, + &req.target, + ) { + crate::providers::vertex::log_oauth_interception("L7/TLS-terminated"); + let response = crate::providers::vertex::generate_fake_oauth_response(None); + + client.write_all(&response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(RelayOutcome::Consumed); + } + } + } let header_end = req .raw_header .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); - let rewrite_result = rewrite_http_header_block(&req.raw_header[..header_end], resolver) - .map_err(|e| miette!("credential injection failed: {e}"))?; - + // Inject OAuth header if configured for this endpoint + let rewrite_result = if let Some(oauth_cfg) = oauth_config { + // For OAuth-configured endpoints, inject/replace Authorization header + inject_oauth_header(&req.raw_header[..header_end], resolver, oauth_cfg) + .map_err(|e| miette!("OAuth header injection failed: {e}"))? + } else { + // For other requests, use standard credential rewriting + rewrite_http_header_block(&req.raw_header[..header_end], resolver) + .map_err(|e| miette!("credential injection failed: {e}"))? + }; + + // Rest of the function remains the same... upstream .write_all(&rewrite_result.rewritten) .await @@ -309,12 +443,9 @@ where if matches!(outcome, RelayOutcome::Upgraded { .. }) { let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); if !client_requested_upgrade(&header_str) { - warn!( - method = %req.action, - target = %req.target, - "upstream sent unsolicited 101 without client Upgrade request — closing connection" - ); - return Ok(RelayOutcome::Consumed); + return Err(miette!( + "upstream sent unsolicited 101 without client Upgrade request" + )); } } @@ -559,6 +690,7 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { /// /// Note: callers that receive `Upgraded` are responsible for switching to /// raw bidirectional relay and forwarding the overflow bytes. +#[allow(dead_code)] pub(crate) async fn relay_response_to_client( upstream: &mut U, client: &mut C, @@ -1606,15 +1738,21 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, None, + None, ), ) .await .expect("relay must not deadlock"); - let outcome = result.expect("relay should succeed"); + // Unsolicited 101 upgrade should return an error assert!( - matches!(outcome, RelayOutcome::Consumed), - "unsolicited 101 should be rejected as Consumed, got {outcome:?}" + result.is_err(), + "unsolicited 101 should be rejected with an error" + ); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("unsolicited 101"), + "error message should mention unsolicited 101, got: {err_msg}" ); upstream_task.await.expect("upstream task should complete"); @@ -1663,6 +1801,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, None, + None, ), ) .await @@ -1788,6 +1927,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, resolver.as_ref(), + None, ), ) .await @@ -1871,6 +2011,7 @@ mod tests { &req, &mut proxy_to_client, &mut proxy_to_upstream, + None, None, // <-- No resolver, as in the L4 raw tunnel path ), ) @@ -1960,6 +2101,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, resolver, + None, ), ) .await diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index b160cdefc..1d112fdad 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -17,6 +17,7 @@ pub mod opa; mod policy; mod process; pub mod procfs; +mod providers; pub mod proxy; mod sandbox; mod secrets; @@ -213,6 +214,12 @@ pub async fn run_sandbox( // Prepare filesystem: create and chown read_write directories prepare_filesystem(&policy)?; + // Create fake ADC file for Vertex AI if VERTEX_ADC is present + // This allows Claude CLI to work without requiring real credentials on disk + if provider_env.contains_key("VERTEX_ADC") { + create_fake_vertex_adc(&policy)?; + } + // Generate ephemeral CA and TLS state for HTTPS L7 inspection. // The CA cert is written to disk so sandbox processes can trust it. let (tls_state, ca_file_paths) = if matches!(policy.network.mode, NetworkMode::Proxy) { @@ -1448,6 +1455,94 @@ fn prepare_filesystem(_policy: &SandboxPolicy) -> Result<()> { Ok(()) } +/// Create fake ADC credentials file for Vertex AI provider. +/// +/// **Vertex AI + Claude CLI workaround:** +/// - Claude CLI requires ADC credentials on disk to work with Vertex AI +/// - We create fake ADC credentials here to satisfy Claude CLI's requirements +/// - When Claude CLI tries to exchange these fake credentials with Google OAuth, +/// the proxy intercepts the request (see rest.rs and proxy.rs oauth2.googleapis.com handling) +/// - The proxy returns a fake OAuth success, allowing Claude CLI to proceed +/// - Real OAuth tokens are injected via Authorization headers for actual API requests +/// +/// **Why not use real credentials:** +/// - Security: Avoid writing long-lived refresh tokens to disk in the sandbox +/// - Simplicity: Users don't need to run `gcloud auth` inside the sandbox +/// - Consistency: Token management is centralized in the gateway TokenCache +/// +/// **Related code:** +/// - OAuth interception: `crates/openshell-sandbox/src/l7/rest.rs` (relay_http_request_with_resolver) +/// - OAuth interception: `crates/openshell-sandbox/src/proxy.rs` (handle_forward_proxy) +/// - Token injection: `crates/openshell-sandbox/src/l7/rest.rs` (inject_oauth_header) +#[cfg(unix)] +fn create_fake_vertex_adc(policy: &SandboxPolicy) -> Result<()> { + use nix::unistd::{Group, User, chown}; + use std::fs; + use std::os::unix::fs::PermissionsExt; + + // Resolve sandbox user/group for ownership (match pattern from prepare_filesystem) + let user_name = match policy.process.run_as_user.as_deref() { + Some(name) if !name.is_empty() => Some(name), + _ => None, + }; + let group_name = match policy.process.run_as_group.as_deref() { + Some(name) if !name.is_empty() => Some(name), + _ => None, + }; + + let uid = user_name + .and_then(|name| User::from_name(name).ok().flatten()) + .map(|u| u.uid); + let gid = group_name + .and_then(|name| Group::from_name(name).ok().flatten()) + .map(|g| g.gid); + + // Get home directory from passwd entry, defaulting to /sandbox + let home_dir = user_name + .and_then(|name| User::from_name(name).ok().flatten()) + .map(|u| u.dir) + .unwrap_or_else(|| std::path::PathBuf::from("/sandbox")); + + let gcloud_dir = home_dir.join(".config/gcloud"); + let adc_path = gcloud_dir.join("application_default_credentials.json"); + + // Create directory + fs::create_dir_all(&gcloud_dir).into_diagnostic()?; + + // Write fake ADC file + let fake_adc = r#"{ + "client_id": "fake-client-id", + "client_secret": "fake-client-secret", + "refresh_token": "fake-refresh-token", + "type": "authorized_user" +}"#; + + fs::write(&adc_path, fake_adc).into_diagnostic()?; + + // Set file permissions to 600 (owner read/write only) + let mut perms = fs::metadata(&adc_path).into_diagnostic()?.permissions(); + perms.set_mode(0o600); + fs::set_permissions(&adc_path, perms).into_diagnostic()?; + + // Set ownership on directory and file + if let (Some(uid), Some(gid)) = (uid, gid) { + chown(&gcloud_dir, Some(uid), Some(gid)).into_diagnostic()?; + chown(&adc_path, Some(uid), Some(gid)).into_diagnostic()?; + } + + info!( + path = %adc_path.display(), + "Created fake Vertex ADC credentials file" + ); + + Ok(()) +} + +#[cfg(not(unix))] +fn create_fake_vertex_adc(_policy: &SandboxPolicy) -> Result<()> { + Ok(()) +} + /// Background loop that polls the server for policy updates. /// /// When a new version is detected, attempts to reload the OPA engine via diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index f1c0ad293..aa120abd4 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -703,6 +703,12 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy) -> String { if !e.allowed_ips.is_empty() { ep["allowed_ips"] = e.allowed_ips.clone().into(); } + if let Some(oauth) = &e.oauth { + ep["oauth"] = serde_json::json!({ + "token_env_var": oauth.token_env_var, + "header_format": oauth.header_format, + }); + } ep }) .collect(); @@ -802,6 +808,7 @@ mod tests { run_as_group: "sandbox".to_string(), }), network_policies, + oauth_credentials: None, } } @@ -1639,6 +1646,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + oauth_credentials: None, }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -2255,6 +2263,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + oauth_credentials: None, }; let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); @@ -2485,6 +2494,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + oauth_credentials: None, }; let engine = OpaEngine::from_proto(&proto).unwrap(); // Port 443 diff --git a/crates/openshell-sandbox/src/process.rs b/crates/openshell-sandbox/src/process.rs index b93d125ab..ca305e5ed 100644 --- a/crates/openshell-sandbox/src/process.rs +++ b/crates/openshell-sandbox/src/process.rs @@ -26,6 +26,18 @@ const SSH_HANDSHAKE_SECRET_ENV: &str = "OPENSHELL_SSH_HANDSHAKE_SECRET"; fn inject_provider_env(cmd: &mut Command, provider_env: &HashMap) { for (key, value) in provider_env { + // Filter out OAuth access tokens - these are only needed by the supervisor's + // proxy for header injection, not by agent processes (Claude CLI, bash, etc.) + if key.ends_with("_ACCESS_TOKEN") { + continue; + } + + // Filter out ADC credentials - agent processes use fake ADC file instead + // (created by supervisor based on VERTEX_ADC presence in provider_env) + if key == "VERTEX_ADC" { + continue; + } + cmd.env(key, value); } } diff --git a/crates/openshell-sandbox/src/providers/mod.rs b/crates/openshell-sandbox/src/providers/mod.rs new file mode 100644 index 000000000..737182d11 --- /dev/null +++ b/crates/openshell-sandbox/src/providers/mod.rs @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Provider-specific runtime behavior for the sandbox. +//! +//! This module contains provider-specific logic that runs within the sandbox +//! at request processing time. This is separate from the provider discovery +//! and credential management in the `openshell-providers` crate. + +pub mod vertex; diff --git a/crates/openshell-sandbox/src/providers/vertex.rs b/crates/openshell-sandbox/src/providers/vertex.rs new file mode 100644 index 000000000..d61a53663 --- /dev/null +++ b/crates/openshell-sandbox/src/providers/vertex.rs @@ -0,0 +1,102 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Vertex AI provider-specific sandbox runtime behavior. +//! +//! ## OAuth Token Interception for Claude CLI Compatibility +//! +//! This module implements a workaround to enable Claude CLI to work with Vertex AI +//! without requiring users to manually authenticate via `gcloud auth application-default login` +//! inside the sandbox. +//! +//! ### The Problem +//! +//! Claude CLI expects valid Application Default Credentials (ADC) from Google Cloud: +//! 1. Reads ADC file from ~/.config/gcloud/application_default_credentials.json +//! 2. Attempts to exchange refresh token with oauth2.googleapis.com +//! 3. Uses returned access token for Vertex AI API requests +//! +//! ### Our Solution +//! +//! We inject **fake** ADC credentials via `create_fake_vertex_adc()` and intercept +//! the token exchange: +//! +//! 1. **Fake ADC credentials** are written to the expected path +//! 2. Claude CLI reads these fake credentials +//! 3. Claude CLI sends POST /token to oauth2.googleapis.com +//! 4. **We intercept this request** and return a fake OAuth success response +//! 5. Claude CLI proceeds to make Vertex API requests +//! 6. **Real OAuth tokens** are injected via Authorization headers by the proxy +//! +//! ### Why This is Vertex-Specific +//! +//! - Only Vertex AI uses oauth2.googleapis.com for OAuth token exchange +//! - The fake token in the intercepted response is never actually used +//! - Real tokens come from the token cache (VERTEX_ACCESS_TOKEN environment variable) +//! - This workaround is specific to Google Cloud / Vertex AI authentication flow +//! +//! ### Related Code +//! +//! - ADC credential creation: `lib.rs::create_fake_vertex_adc()` +//! - OAuth header injection: `l7/rest.rs::inject_oauth_header()` +//! - Token caching: `openshell-providers::token_cache::TokenCache` + +use tracing::info; + +/// Check if this request should be intercepted for Vertex AI OAuth workaround. +/// +/// Returns `true` if: +/// - Method is POST +/// - Host is oauth2.googleapis.com +/// - Path is /token +/// +/// This is called from both L7 (TLS-terminated) and L4 (forward proxy) paths. +pub fn should_intercept_oauth_request(method: &str, host: &str, path: &str) -> bool { + method.to_ascii_uppercase() == "POST" + && host.to_ascii_lowercase() == "oauth2.googleapis.com" + && path == "/token" +} + +/// Generate a fake OAuth success response for intercepted token exchange. +/// +/// The access token in this response is a placeholder - it will never be used. +/// Real OAuth tokens are injected via Authorization headers by the proxy's +/// `inject_oauth_header()` function. +/// +/// # L7 Path (TLS-terminated) +/// +/// For requests processed via L7 inspection (rest.rs), we return a fake token +/// because Claude CLI needs *some* response to proceed. The actual token injection +/// happens later via `inject_oauth_header()`. +/// +/// # L4 Path (forward proxy) +/// +/// For requests that bypass L7 inspection (proxy.rs FORWARD path), we can optionally +/// inject the real cached token from VERTEX_ACCESS_TOKEN if available. This is +/// more correct but still a workaround - ideally all Vertex requests would go +/// through L7 inspection where OAuth header injection happens properly. +pub fn generate_fake_oauth_response(access_token: Option<&str>) -> Vec { + let token = access_token.unwrap_or("fake-token-will-be-replaced-by-proxy"); + + let response_body = format!( + r#"{{"access_token":"{}","token_type":"Bearer","expires_in":3600}}"#, + token + ); + + format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + response_body.len(), + response_body + ) + .into_bytes() +} + +/// Log that we're intercepting a Google OAuth token exchange. +/// +/// This is called from both rest.rs and proxy.rs to provide consistent logging. +pub fn log_oauth_interception(context: &str) { + info!( + context = context, + "Intercepting Google OAuth token exchange (Vertex AI workaround)" + ); +} diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 9e87450d4..de33fb4c9 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -129,7 +129,7 @@ impl ProxyHandle { /// The proxy uses OPA for network decisions with process-identity binding /// via `/proc/net/tcp`. All connections are evaluated through OPA policy. #[allow(clippy::too_many_arguments)] - pub async fn start_with_bind_addr( + pub(crate) async fn start_with_bind_addr( policy: &ProxyPolicy, bind_addr: Option, opa_engine: Arc, @@ -579,6 +579,7 @@ async fn handle_tcp_connection( .map(|p| p.to_string_lossy().into_owned()) .collect(), secret_resolver: secret_resolver.clone(), + oauth_config: l7_config.as_ref().and_then(|cfg| cfg.oauth.clone()), }; if effective_tls_skip { @@ -1761,7 +1762,28 @@ async fn handle_forward_proxy( }; let host_lc = host.to_ascii_lowercase(); - // 2. Reject HTTPS — must use CONNECT for TLS + // 2. Provider-specific request interception (Vertex AI OAuth workaround) + // + // Check if this request should be intercepted by a provider-specific handler. + // Currently only used by Vertex AI to intercept OAuth token exchange for + // Claude CLI compatibility. See `providers::vertex` module for details. + // + // **Context:** This is the L4 forward proxy path (HTTP without TLS termination). + // For requests that go through L7 inspection (TLS-terminated), interception happens + // in rest.rs via the same vertex provider module. + if crate::providers::vertex::should_intercept_oauth_request(method, &host_lc, &path) { + crate::providers::vertex::log_oauth_interception("L4/forward-proxy"); + + // For L4 path, we can inject the real cached token if available + let access_token = std::env::var("VERTEX_ACCESS_TOKEN").ok(); + let response = + crate::providers::vertex::generate_fake_oauth_response(access_token.as_deref()); + + respond(client, &response).await?; + return Ok(()); + } + + // 3. Reject HTTPS — must use CONNECT for TLS if scheme == "https" { info!( dst_host = %host_lc, @@ -1896,6 +1918,7 @@ async fn handle_forward_proxy( .map(|p| p.to_string_lossy().into_owned()) .collect(), secret_resolver: secret_resolver.clone(), + oauth_config: l7_config.oauth.clone(), }; let (target_path, query_params) = crate::l7::rest::parse_target_query(&path) diff --git a/crates/openshell-sandbox/src/sandbox/linux/netns.rs b/crates/openshell-sandbox/src/sandbox/linux/netns.rs index 27f4fc338..a5f19a3dc 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/netns.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/netns.rs @@ -383,6 +383,7 @@ impl NetworkNamespace { /// # Safety /// /// This function should only be called in a `pre_exec` context after fork. + #[allow(unsafe_code)] pub fn enter(&self) -> Result<()> { if let Some(fd) = self.ns_fd { debug!(namespace = %self.name, "Entering network namespace via setns"); diff --git a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs index e23447498..5537bbfce 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs @@ -25,6 +25,7 @@ use tracing::debug; /// Value of `SECCOMP_SET_MODE_FILTER` (linux/seccomp.h). const SECCOMP_SET_MODE_FILTER: u64 = 1; +#[allow(unsafe_code)] pub fn apply(policy: &SandboxPolicy) -> Result<()> { if matches!(policy.network.mode, NetworkMode::Allow) { return Ok(()); diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index a27537c91..ee355cfe7 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -10,6 +10,37 @@ const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; /// Public access to the placeholder prefix for fail-closed scanning in other modules. pub(crate) const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; +/// Credentials that should be injected as actual values into the sandbox environment +/// instead of being converted to placeholders. +/// +/// These credentials are needed by tools (like `claude` CLI) that read environment +/// variables directly rather than making HTTP requests through the proxy. +/// +/// **Security consideration**: These values are visible to all sandbox processes via +/// `/proc//environ`, unlike placeholder-based credentials which are only resolved +/// within HTTP requests. Only include credentials here when direct env var access is +/// required for tool compatibility. +/// +/// **VERTEX_ADC security warning**: +/// - VERTEX_ADC contains Google OAuth refresh tokens with long expiration (typically hours to days) +/// - These refresh tokens can be used to obtain new access tokens for the scoped GCP project +/// - Visible in `/proc//environ` to all processes in the sandbox +/// - Recommendation: Use ADC with least-privilege service account scopes (e.g., only Vertex AI access) +/// - Avoid using ADC from accounts with broad GCP permissions (Owner, Editor, etc.) +/// - Consider using Workload Identity Federation for production deployments instead of ADC +fn direct_inject_credentials() -> &'static [&'static str] { + &[ + // Vertex AI credentials for claude CLI + // NOTE: VERTEX_ADC is filtered out in process.rs - agent processes use + // fake ADC file created by supervisor instead of real credentials + "ANTHROPIC_VERTEX_PROJECT_ID", + "ANTHROPIC_VERTEX_REGION", + "CLAUDE_CODE_USE_VERTEX", + // NOTE: VERTEX_ACCESS_TOKEN is NOT in this list - it's accessed via + // the SecretResolver in the proxy to inject Authorization headers + ] +} + /// Characters that are valid in an env var key name (used to extract /// placeholder boundaries within concatenated strings like path segments). fn is_env_key_char(b: u8) -> bool { @@ -45,6 +76,7 @@ pub(crate) struct RewriteResult { /// A redacted version of the request target for logging. /// Contains `[CREDENTIAL]` in place of resolved credential values. /// `None` if the target was not modified. + #[allow(dead_code)] pub redacted_target: Option, } @@ -62,13 +94,26 @@ pub(crate) struct RewriteTargetResult { // --------------------------------------------------------------------------- #[derive(Debug, Clone, Default)] -pub struct SecretResolver { +pub(crate) struct SecretResolver { by_placeholder: HashMap, } impl SecretResolver { pub(crate) fn from_provider_env( provider_env: HashMap, + ) -> (HashMap, Option) { + Self::from_provider_env_with_direct_inject(provider_env, &direct_inject_credentials()) + } + + /// Create a resolver from provider environment with selective direct injection. + /// + /// Credentials matching keys in `direct_inject` are injected as actual values + /// into the child environment (for tools like `claude` CLI that need real env vars). + /// All other credentials are converted to `openshell:resolve:env:*` placeholders + /// that get resolved by the HTTP proxy. + pub(crate) fn from_provider_env_with_direct_inject( + provider_env: HashMap, + direct_inject: &[&str], ) -> (HashMap, Option) { if provider_env.is_empty() { return (HashMap::new(), None); @@ -78,12 +123,25 @@ impl SecretResolver { let mut by_placeholder = HashMap::with_capacity(provider_env.len()); for (key, value) in provider_env { - let placeholder = placeholder_for_env_key(&key); - child_env.insert(key, placeholder.clone()); - by_placeholder.insert(placeholder, value); + // Check if this credential should be injected directly + if direct_inject.contains(&key.as_str()) { + // Direct injection: put actual value in environment + child_env.insert(key, value); + } else { + // Placeholder: will be resolved by HTTP proxy + let placeholder = placeholder_for_env_key(&key); + child_env.insert(key, placeholder.clone()); + by_placeholder.insert(placeholder, value); + } } - (child_env, Some(Self { by_placeholder })) + let resolver = if by_placeholder.is_empty() { + None + } else { + Some(Self { by_placeholder }) + }; + + (child_env, resolver) } /// Resolve a placeholder string to the real secret value. diff --git a/crates/openshell-sandbox/src/ssh.rs b/crates/openshell-sandbox/src/ssh.rs index e3add8874..d3f419ea3 100644 --- a/crates/openshell-sandbox/src/ssh.rs +++ b/crates/openshell-sandbox/src/ssh.rs @@ -738,6 +738,18 @@ fn apply_child_env( } for (key, value) in provider_env { + // Filter out OAuth access tokens - these are only needed by the supervisor's + // proxy for header injection, not by agent processes (Claude CLI, bash, etc.) + if key.ends_with("_ACCESS_TOKEN") { + continue; + } + + // Filter out ADC credentials - agent processes use fake ADC file instead + // (created by supervisor based on VERTEX_ADC presence in provider_env) + if key == "VERTEX_ADC" { + continue; + } + cmd.env(key, value); } } diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 0308f30ff..dc6d29814 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -17,6 +17,7 @@ path = "src/main.rs" [dependencies] openshell-core = { path = "../openshell-core" } openshell-policy = { path = "../openshell-policy" } +openshell-providers = { path = "../openshell-providers" } openshell-router = { path = "../openshell-router" } # Async runtime diff --git a/crates/openshell-server/src/grpc.rs b/crates/openshell-server/src/grpc.rs index d7ef4ccf5..0426b95d3 100644 --- a/crates/openshell-server/src/grpc.rs +++ b/crates/openshell-server/src/grpc.rs @@ -929,18 +929,25 @@ impl OpenShell for OpenShellService { .spec .ok_or_else(|| Status::internal("sandbox has no spec"))?; - let environment = - resolve_provider_environment(self.state.store.as_ref(), &spec.providers).await?; + let (environment, metadata) = resolve_provider_environment( + self.state.store.as_ref(), + &self.state.token_caches, + &spec.providers, + spec.policy.as_ref(), + ) + .await?; info!( sandbox_id = %sandbox_id, provider_count = spec.providers.len(), env_count = environment.len(), + metadata_count = metadata.len(), "GetSandboxProviderEnvironment request completed successfully" ); Ok(Response::new(GetSandboxProviderEnvironmentResponse { environment, + oauth_metadata: metadata, })) } @@ -2515,7 +2522,7 @@ async fn require_no_global_policy(state: &ServerState) -> Result<(), Status> { } async fn merge_chunk_into_policy( - store: &crate::persistence::Store, + store: &Store, sandbox_id: &str, chunk: &DraftChunkRecord, ) -> Result<(i64, String), Status> { @@ -3237,7 +3244,7 @@ fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { /// Validate a `map` field: entry count, key length, value length. fn validate_string_map( - map: &std::collections::HashMap, + map: &HashMap, max_entries: usize, max_key_len: usize, max_value_len: usize, @@ -3640,15 +3647,48 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result /// collects credential key-value pairs. Returns a map of environment variables /// to inject into the sandbox. When duplicate keys appear across providers, the /// first provider's value wins. +/// +/// **OAuth Token Auto-Refresh:** +/// - Detects OAuth credential keys (VERTEX_ADC, etc.) +/// - Creates/reuses TokenCache for OAuth token management with auto-refresh +/// - Returns OAuth tokens that are auto-refreshed every ~55 minutes +/// - Returns metadata with expiry time and auto-refresh configuration async fn resolve_provider_environment( - store: &crate::persistence::Store, + store: &Store, + token_caches: &tokio::sync::Mutex>>, provider_names: &[String], -) -> Result, Status> { + policy: Option<&openshell_core::proto::SandboxPolicy>, +) -> Result< + ( + HashMap, + HashMap, + ), + Status, +> { + use openshell_core::proto::OAuthCredentialMetadata; + if provider_names.is_empty() { - return Ok(std::collections::HashMap::new()); + return Ok((HashMap::new(), HashMap::new())); } - let mut env = std::collections::HashMap::new(); + // Extract OAuth settings from policy (use defaults if not specified) + let (auto_refresh, refresh_margin_seconds, _max_lifetime_seconds) = policy + .and_then(|p| p.oauth_credentials.as_ref()) + .map(|oauth| { + ( + oauth.auto_refresh, + oauth.refresh_margin_seconds, + oauth.max_lifetime_seconds, + ) + }) + .unwrap_or(( + true, // Default: auto-refresh enabled + 300, // Default: 5 minutes before expiry + 86400, // Default: 24 hours max lifetime + )); + + let mut env = HashMap::new(); + let mut metadata = HashMap::new(); for name in provider_names { let provider = store @@ -3659,7 +3699,83 @@ async fn resolve_provider_environment( for (key, value) in &provider.credentials { if is_valid_env_key(key) { - env.entry(key.clone()).or_insert_with(|| value.clone()); + // Check if this credential should use OAuth token auto-refresh + if should_use_token_cache(&provider.r#type, key) { + match get_or_create_token_cache( + token_caches, + name, + &provider.r#type, + key, + value, + auto_refresh, + refresh_margin_seconds, + ) + .await + { + Ok(cache) => { + match cache.get_token_with_expiry(&provider.r#type).await { + Ok((oauth_token, expires_in)) => { + // Trim token to remove trailing newlines that break HTTP headers + let oauth_token = oauth_token.trim().to_string(); + + // For Vertex ADC: keep original JSON, create ACCESS_TOKEN with token + // Claude CLI needs the JSON file for ADC parsing + if provider.r#type == "vertex" && key == "VERTEX_ADC" { + // Keep original ADC JSON (supervisor needs it to create fake ADC file) + env.entry(key.clone()).or_insert_with(|| value.clone()); + // Add OAuth token as VERTEX_ACCESS_TOKEN for proxy injection + env.entry("VERTEX_ACCESS_TOKEN".to_string()) + .or_insert(oauth_token.clone()); + } else { + // For other credentials, replace with OAuth token + env.entry(key.clone()).or_insert(oauth_token.clone()); + } + + // Get config values from provider (with defaults) + let auto_refresh = provider + .config + .get("auto_refresh") + .and_then(|v| v.parse::().ok()) + .unwrap_or(false); // Default: disabled + + let refresh_margin_seconds = provider + .config + .get("refresh_margin_seconds") + .and_then(|v| v.parse::().ok()) + .unwrap_or(300); // Default: 5 minutes + + let max_lifetime_seconds = provider + .config + .get("max_lifetime_seconds") + .and_then(|v| v.parse::().ok()) + .unwrap_or(86400); // Default: 24 hours + + metadata.insert( + key.clone(), + OAuthCredentialMetadata { + expires_in: expires_in as i64, + auto_refresh, + refresh_margin_seconds, + max_lifetime_seconds, + }, + ); + } + Err(e) => { + return Err(Status::internal(format!( + "Failed to get OAuth token from cache: {e}" + ))); + } + } + } + Err(e) => { + return Err(Status::internal(format!( + "Failed to create token cache: {e}" + ))); + } + } + } else { + env.entry(key.clone()).or_insert_with(|| value.clone()); + } } else { warn!( provider_name = %name, @@ -3668,9 +3784,148 @@ async fn resolve_provider_environment( ); } } + + // Also inject config values as environment variables + // (e.g., ANTHROPIC_VERTEX_REGION from provider.config) + for (key, value) in &provider.config { + // Skip OAuth-specific config keys (these are metadata, not env vars) + if matches!( + key.as_str(), + "auto_refresh" | "refresh_margin_seconds" | "max_lifetime_seconds" + ) { + continue; + } + + if is_valid_env_key(key) { + env.entry(key.clone()).or_insert_with(|| value.clone()); + } else { + warn!( + provider_name = %name, + key = %key, + "skipping config with invalid env var key" + ); + } + } + } + + Ok((env, metadata)) +} + +/// Determine if a credential should use TokenCache for OAuth auto-refresh. +/// +/// This function identifies OAuth credentials that need token exchange and +/// auto-refresh. Add new provider types and credential keys here to enable +/// auto-refresh for additional OAuth providers. +/// +/// **Current supported providers:** +/// - **Vertex AI**: VERTEX_ADC (Google Application Default Credentials) +/// +/// **Future providers (examples):** +/// - **AWS Bedrock**: AWS_CREDENTIALS → STS token exchange +/// - **Azure OpenAI**: AZURE_CLIENT_SECRET → Azure AD token exchange +/// - **GitHub**: GITHUB_APP_PRIVATE_KEY → GitHub App JWT + installation token +fn should_use_token_cache(provider_type: &str, credential_key: &str) -> bool { + matches!( + (provider_type, credential_key), + ("vertex", "VERTEX_ADC") // Add more OAuth providers here + // | ("bedrock", "AWS_CREDENTIALS") + // | ("azure", "AZURE_CLIENT_SECRET") + // | ("github-app", "GITHUB_APP_PRIVATE_KEY") + ) +} + +/// Get or create a TokenCache for an OAuth provider. +/// +/// This function ensures that only one TokenCache exists per provider, stored in +/// ServerState. The TokenCache remains alive for the lifetime of the gateway, +/// allowing its background auto-refresh task to run indefinitely. +/// +/// **Benefits:** +/// - Single TokenCache per provider (no duplicate refresh tasks) +/// - Background refresh runs every 55 minutes (for 1-hour tokens) +/// - Tokens stay fresh without sandbox restarts +/// - Multiple sandboxes share the same cached token +/// +/// **Supports:** +/// - Vertex AI (ADC → OAuth token exchange) +/// - Future: AWS Bedrock, Azure OpenAI, GitHub Apps, etc. +async fn get_or_create_token_cache( + token_caches: &tokio::sync::Mutex>>, + provider_name: &str, + provider_type: &str, + credential_key: &str, + credential_value: &str, + auto_refresh: bool, + refresh_margin_seconds: i64, +) -> Result, String> { + use openshell_providers::{DatabaseStore, ProviderPlugin, TokenCache}; + use std::sync::Arc as StdArc; + + // Use a composite key for the cache: provider_name + credential_key + // This allows multiple OAuth credentials per provider if needed + let cache_key = format!("{provider_name}:{credential_key}"); + + let mut caches = token_caches.lock().await; + + // Check if cache already exists + if let Some(cache) = caches.get(&cache_key) { + tracing::debug!( + provider = provider_name, + credential_key = credential_key, + "reusing existing token cache" + ); + return Ok(cache.clone()); } - Ok(env) + // Create new TokenCache + tracing::info!( + provider = provider_name, + provider_type = provider_type, + credential_key = credential_key, + "creating new token cache with auto-refresh" + ); + + // Create provider plugin based on provider type + let provider_plugin: StdArc = match provider_type { + "vertex" => { + // Validate ADC JSON + let _: serde_json::Value = serde_json::from_str(credential_value) + .map_err(|e| format!("Invalid ADC JSON for Vertex: {e}"))?; + StdArc::new(openshell_providers::vertex::VertexProvider::new()) + } + // Future providers can be added here: + // "bedrock" => Arc::new(BedrockProvider::new()), + // "azure" => Arc::new(AzureProvider::new()), + _ => { + return Err(format!( + "Unsupported OAuth provider type for token cache: {provider_type}" + )); + } + }; + + // Create DatabaseStore with the credential + let mut creds = HashMap::new(); + creds.insert(credential_key.to_string(), credential_value.to_string()); + let store = StdArc::new(DatabaseStore::new(creds)); + + // Create TokenCache with policy-configured refresh settings + tracing::info!( + provider = provider_name, + auto_refresh = auto_refresh, + refresh_margin_seconds = refresh_margin_seconds, + "creating token cache with policy-configured settings" + ); + let cache = StdArc::new(TokenCache::new( + provider_plugin, + store, + refresh_margin_seconds, + auto_refresh, + )); + + // Store in ServerState to keep it alive + caches.insert(cache_key, cache.clone()); + + Ok(cache) } fn is_valid_env_key(key: &str) -> bool { @@ -4138,10 +4393,7 @@ fn redact_provider_credentials(mut provider: Provider) -> Provider { provider } -async fn create_provider_record( - store: &crate::persistence::Store, - mut provider: Provider, -) -> Result { +async fn create_provider_record(store: &Store, mut provider: Provider) -> Result { if provider.name.is_empty() { provider.name = generate_name(); } @@ -4176,10 +4428,7 @@ async fn create_provider_record( Ok(redact_provider_credentials(provider)) } -async fn get_provider_record( - store: &crate::persistence::Store, - name: &str, -) -> Result { +async fn get_provider_record(store: &Store, name: &str) -> Result { if name.is_empty() { return Err(Status::invalid_argument("name is required")); } @@ -4193,7 +4442,7 @@ async fn get_provider_record( } async fn list_provider_records( - store: &crate::persistence::Store, + store: &Store, limit: u32, offset: u32, ) -> Result, Status> { @@ -4218,9 +4467,9 @@ async fn list_provider_records( /// - Otherwise, upsert all incoming entries into `existing`. /// - Entries with an empty-string value are removed (delete semantics). fn merge_map( - mut existing: std::collections::HashMap, - incoming: std::collections::HashMap, -) -> std::collections::HashMap { + mut existing: HashMap, + incoming: HashMap, +) -> HashMap { if incoming.is_empty() { return existing; } @@ -4234,10 +4483,7 @@ fn merge_map( existing } -async fn update_provider_record( - store: &crate::persistence::Store, - provider: Provider, -) -> Result { +async fn update_provider_record(store: &Store, provider: Provider) -> Result { if provider.name.is_empty() { return Err(Status::invalid_argument("provider.name is required")); } @@ -4278,10 +4524,7 @@ async fn update_provider_record( Ok(redact_provider_credentials(updated)) } -async fn delete_provider_record( - store: &crate::persistence::Store, - name: &str, -) -> Result { +async fn delete_provider_record(store: &Store, name: &str) -> Result { if name.is_empty() { return Err(Status::invalid_argument("name is required")); } @@ -4872,8 +5115,12 @@ mod tests { #[tokio::test] async fn resolve_provider_env_empty_list_returns_empty() { let store = Store::connect("sqlite::memory:").await.unwrap(); - let result = resolve_provider_environment(&store, &[]).await.unwrap(); - assert!(result.is_empty()); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); + let (env, metadata) = resolve_provider_environment(&store, &token_caches, &[], None) + .await + .unwrap(); + assert!(env.is_empty()); + assert!(metadata.is_empty()); } #[tokio::test] @@ -4896,22 +5143,33 @@ mod tests { .collect(), }; create_provider_record(&store, provider).await.unwrap(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); - let result = resolve_provider_environment(&store, &["claude-local".to_string()]) - .await - .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("CLAUDE_API_KEY"), Some(&"sk-abc".to_string())); - // Config values should NOT be injected. - assert!(!result.contains_key("endpoint")); + let (env, _metadata) = resolve_provider_environment( + &store, + &token_caches, + &["claude-local".to_string()], + None, + ) + .await + .unwrap(); + assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); + assert_eq!(env.get("CLAUDE_API_KEY"), Some(&"sk-abc".to_string())); + // Config values are injected as environment variables + assert_eq!( + env.get("endpoint"), + Some(&"https://api.anthropic.com".to_string()) + ); } #[tokio::test] async fn resolve_provider_env_unknown_name_returns_error() { let store = Store::connect("sqlite::memory:").await.unwrap(); - let err = resolve_provider_environment(&store, &["nonexistent".to_string()]) - .await - .unwrap_err(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); + let err = + resolve_provider_environment(&store, &token_caches, &["nonexistent".to_string()], None) + .await + .unwrap_err(); assert_eq!(err.code(), Code::FailedPrecondition); assert!(err.message().contains("nonexistent")); } @@ -4933,13 +5191,19 @@ mod tests { config: HashMap::new(), }; create_provider_record(&store, provider).await.unwrap(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); - let result = resolve_provider_environment(&store, &["test-provider".to_string()]) - .await - .unwrap(); - assert_eq!(result.get("VALID_KEY"), Some(&"value".to_string())); - assert!(!result.contains_key("nested.api_key")); - assert!(!result.contains_key("bad-key")); + let (env, _metadata) = resolve_provider_environment( + &store, + &token_caches, + &["test-provider".to_string()], + None, + ) + .await + .unwrap(); + assert_eq!(env.get("VALID_KEY"), Some(&"value".to_string())); + assert!(!env.contains_key("nested.api_key")); + assert!(!env.contains_key("bad-key")); } #[tokio::test] @@ -4974,15 +5238,18 @@ mod tests { ) .await .unwrap(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); - let result = resolve_provider_environment( + let (env, _metadata) = resolve_provider_environment( &store, + &token_caches, &["claude-local".to_string(), "gitlab-local".to_string()], + None, ) .await .unwrap(); - assert_eq!(result.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); - assert_eq!(result.get("GITLAB_TOKEN"), Some(&"glpat-xyz".to_string())); + assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-abc".to_string())); + assert_eq!(env.get("GITLAB_TOKEN"), Some(&"glpat-xyz".to_string())); } #[tokio::test] @@ -5017,14 +5284,17 @@ mod tests { ) .await .unwrap(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); - let result = resolve_provider_environment( + let (env, _metadata) = resolve_provider_environment( &store, + &token_caches, &["provider-a".to_string(), "provider-b".to_string()], + None, ) .await .unwrap(); - assert_eq!(result.get("SHARED_KEY"), Some(&"first-value".to_string())); + assert_eq!(env.get("SHARED_KEY"), Some(&"first-value".to_string())); } /// Simulates the handler flow: persist a sandbox with providers, then resolve @@ -5075,9 +5345,11 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) - .await - .unwrap(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); + let (env, _metadata) = + resolve_provider_environment(&store, &token_caches, &spec.providers, None) + .await + .unwrap(); assert_eq!(env.get("ANTHROPIC_API_KEY"), Some(&"sk-test".to_string())); } @@ -5106,11 +5378,14 @@ mod tests { .unwrap() .unwrap(); let spec = loaded.spec.unwrap(); - let env = resolve_provider_environment(&store, &spec.providers) - .await - .unwrap(); + let token_caches = tokio::sync::Mutex::new(HashMap::new()); + let (env, metadata) = + resolve_provider_environment(&store, &token_caches, &spec.providers, None) + .await + .unwrap(); assert!(env.is_empty()); + assert!(metadata.is_empty()); } /// Handler returns not-found when sandbox doesn't exist. diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 0fb29bde5..5faa30518 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -237,7 +237,7 @@ fn resolve_provider_route(provider: &Provider) -> Result Result, + + /// Token caches for OAuth providers (e.g., Vertex AI). + /// Maps provider name to TokenCache with background auto-refresh. + /// Tokens are refreshed 5 minutes before expiry to prevent interruptions. + pub token_caches: tokio::sync::Mutex>>, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -102,6 +107,7 @@ impl ServerState { ssh_connections_by_token: Mutex::new(HashMap::new()), ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), + token_caches: tokio::sync::Mutex::new(HashMap::new()), } } } diff --git a/deploy/docker/Dockerfile.images b/deploy/docker/Dockerfile.images index 05149765b..2e58f56f4 100644 --- a/deploy/docker/Dockerfile.images +++ b/deploy/docker/Dockerfile.images @@ -86,10 +86,14 @@ RUN mkdir -p \ FROM rust-builder-skeleton AS rust-deps +# NOTE: cargo-target and sccache cache mounts are disabled because BuildKit +# persists them across all cleanup operations (docker system prune, mise clean, etc.), +# causing stale builds where code changes don't appear in compiled binaries. +# Removing these mounts ensures clean rebuilds at the cost of slower build times. +# --mount=type=cache,id=cargo-target-${TARGETARCH}-${CARGO_TARGET_CACHE_SCOPE},sharing=locked,target=/build/target \ +# --mount=type=cache,id=sccache-${TARGETARCH},sharing=locked,target=/tmp/sccache \ RUN --mount=type=cache,id=cargo-registry-${TARGETARCH},sharing=locked,target=/usr/local/cargo/registry \ --mount=type=cache,id=cargo-git-${TARGETARCH},sharing=locked,target=/usr/local/cargo/git \ - --mount=type=cache,id=cargo-target-${TARGETARCH}-${CARGO_TARGET_CACHE_SCOPE},sharing=locked,target=/build/target \ - --mount=type=cache,id=sccache-${TARGETARCH},sharing=locked,target=/tmp/sccache \ . cross-build.sh && cargo_cross_build --release -p openshell-server -p openshell-sandbox # --------------------------------------------------------------------------- @@ -151,10 +155,14 @@ FROM supervisor-workspace AS supervisor-builder ARG CARGO_CODEGEN_UNITS ARG EXTRA_CARGO_FEATURES="" +# NOTE: cargo-target and sccache cache mounts are disabled because BuildKit +# persists them across all cleanup operations (docker system prune, mise clean, etc.), +# causing stale builds where code changes don't appear in compiled binaries. +# Removing these mounts ensures clean rebuilds at the cost of slower build times. +# --mount=type=cache,id=cargo-target-${TARGETARCH}-${CARGO_TARGET_CACHE_SCOPE},sharing=locked,target=/build/target \ +# --mount=type=cache,id=sccache-${TARGETARCH},sharing=locked,target=/tmp/sccache \ RUN --mount=type=cache,id=cargo-registry-${TARGETARCH},sharing=locked,target=/usr/local/cargo/registry \ --mount=type=cache,id=cargo-git-${TARGETARCH},sharing=locked,target=/usr/local/cargo/git \ - --mount=type=cache,id=cargo-target-${TARGETARCH}-${CARGO_TARGET_CACHE_SCOPE},sharing=locked,target=/build/target \ - --mount=type=cache,id=sccache-${TARGETARCH},sharing=locked,target=/tmp/sccache \ . cross-build.sh && \ cargo_cross_build --release -p openshell-sandbox ${EXTRA_CARGO_FEATURES:+--features "$EXTRA_CARGO_FEATURES"} && \ mkdir -p /build/out && \ @@ -230,7 +238,7 @@ FROM quay.io/hummingbird/core-runtime:latest-builder AS cluster USER root RUN dnf install -y fedora-repos && \ - dnf install -y \ + dnf install -y --no-best --skip-broken \ ca-certificates \ iptables \ nftables \ diff --git a/docs/get-started/install-podman-macos.md b/docs/get-started/install-podman-macos.md index 3b744c026..abc0a3ac6 100644 --- a/docs/get-started/install-podman-macos.md +++ b/docs/get-started/install-podman-macos.md @@ -35,9 +35,7 @@ brew install podman mise bash scripts/setup-podman-macos.sh source scripts/podman.env mise run cluster:build:full -cargo build --release -p openshell-cli -mkdir -p ~/.local/bin -cp target/release/openshell ~/.local/bin/ +cargo install --path crates/openshell-cli --root ~/.local openshell sandbox create ``` @@ -53,7 +51,7 @@ brew install podman mise The `scripts/setup-podman-macos.sh` script automates Podman Machine configuration: -- Creates a dedicated `openshell` Podman machine (8 GB RAM, 4 CPUs) +- Creates a dedicated `openshell` Podman machine (12 GB RAM, 4 CPUs) - Configures cgroup delegation (required for the embedded k3s cluster) - Stops conflicting machines (only one can run at a time, with user confirmation) @@ -72,7 +70,9 @@ source scripts/podman.env This sets: - `CONTAINER_HOST` - Podman socket path - `OPENSHELL_CONTAINER_RUNTIME=podman` - Use Podman runtime -- `OPENSHELL_REGISTRY=127.0.0.1:5000/openshell` - Local registry for component images +- `OPENSHELL_IMAGE_REPO_BASE=127.0.0.1:5000/openshell` - Local registry for component images +- `OPENSHELL_REGISTRY_HOST=127.0.0.1:5000` - Registry host +- `OPENSHELL_REGISTRY_INSECURE=true` - Allow HTTP registry - `OPENSHELL_CLUSTER_IMAGE=localhost/openshell/cluster:dev` - Local cluster image To make these persistent, add to your shell profile (`~/.zshrc` or `~/.bashrc`): @@ -90,12 +90,13 @@ mise run cluster:build:full ``` This command: -- Builds the gateway image +- Builds the gateway and cluster images - Starts a local container registry at `127.0.0.1:5000` -- Builds the cluster image -- Pushes images to the local registry +- Pushes the gateway image to the local registry - Bootstraps a k3s cluster inside a Podman container -- Deploys the OpenShell gateway +- Deploys and starts the OpenShell gateway + +**Note:** This command builds the images AND starts the gateway in one step. The gateway will be running when the command completes. Or run the script directly: @@ -114,25 +115,54 @@ tasks/scripts/cluster-bootstrap.sh build For a release-optimized binary that works system-wide: ```console -cargo build --release -p openshell-cli -mkdir -p ~/.local/bin -cp target/release/openshell ~/.local/bin/ +cargo install --path crates/openshell-cli --root ~/.local ``` ## Create a Sandbox +The gateway is now running. Create a sandbox to test it: + ```console openshell sandbox create ``` +Verify the gateway is healthy: + +```console +openshell gateway info +``` + +## Rebuilding After Code Changes + +If you're developing OpenShell and need to test code changes, use the rebuild script: + +```console +bash scripts/rebuild-cluster.sh +``` + +This stops the cluster, removes the old image, rebuilds with your changes, and restarts. After rebuilding: +1. Recreate providers (gateway database was reset) +2. Reconfigure inference routing if needed +3. Recreate sandboxes + ## Cleanup -To remove all OpenShell resources and optionally the Podman machine: +### Quick Rebuild (Development) + +```console +bash scripts/rebuild-cluster.sh +``` + +Rebuilds the cluster with latest code changes. Use this during development. + +### Full Cleanup (Start Fresh) ```console bash cleanup-openshell-podman-macos.sh ``` +Removes all OpenShell resources and optionally the Podman machine. Use this to completely reset your installation. + ## Troubleshooting ### Environment variables not set @@ -154,11 +184,11 @@ openshell sandbox create ### Build fails with memory errors -Increase the Podman machine memory allocation: +Increase the Podman machine memory allocation (default is 12 GB): ```console podman machine stop openshell -podman machine set openshell --memory 8192 +podman machine set openshell --memory 16384 podman machine start openshell ``` diff --git a/docs/inference/configure.md b/docs/inference/configure.md index 78065689e..e13567135 100644 --- a/docs/inference/configure.md +++ b/docs/inference/configure.md @@ -100,6 +100,30 @@ This reads `ANTHROPIC_API_KEY` from your environment. :::: +::::{tab-item} Google Cloud Vertex AI + +```console +$ export ANTHROPIC_VERTEX_PROJECT_ID=your-gcp-project-id +$ export ANTHROPIC_VERTEX_REGION=us-east5 # Optional, defaults to us-central1 +$ openshell provider create --name vertex --type vertex --from-existing +``` + +This reads `ANTHROPIC_VERTEX_PROJECT_ID` and `ANTHROPIC_VERTEX_REGION` from your environment and automatically generates OAuth tokens from GCP Application Default Credentials. + +**Prerequisites:** +- Google Cloud project with Vertex AI API enabled and Claude models available +- Application Default Credentials configured: `gcloud auth application-default login` +- The `~/.config/gcloud/` directory must be uploaded to sandboxes for OAuth token refresh + +**Usage:** +- **Direct API calls:** Tools like `claude` CLI automatically use Vertex AI when `CLAUDE_CODE_USE_VERTEX=1` is set +- **Inference routing:** Configure `inference.local` to proxy requests to Vertex AI (see "Set Inference Routing" section below) + +**Model ID Format:** Use `@` separator for versions (e.g., `claude-sonnet-4-5@20250929`) + +:::: + + ::::: ## Set Inference Routing diff --git a/docs/sandboxes/manage-providers.md b/docs/sandboxes/manage-providers.md index 6d35766bf..716c16f5a 100644 --- a/docs/sandboxes/manage-providers.md +++ b/docs/sandboxes/manage-providers.md @@ -179,6 +179,7 @@ The following provider types are supported. | `nvidia` | `NVIDIA_API_KEY` | NVIDIA API Catalog | | `openai` | `OPENAI_API_KEY` | Any OpenAI-compatible endpoint. Set `--config OPENAI_BASE_URL` to point to the provider. Refer to {doc}`/inference/configure`. | | `opencode` | `OPENCODE_API_KEY`, `OPENROUTER_API_KEY`, `OPENAI_API_KEY` | opencode tool | +| `vertex` | `ANTHROPIC_VERTEX_PROJECT_ID`, `VERTEX_OAUTH_TOKEN`, `CLAUDE_CODE_USE_VERTEX` | Google Cloud Vertex AI with Claude models. Automatically generates OAuth tokens from GCP Application Default Credentials. Set `ANTHROPIC_VERTEX_REGION` (optional, defaults to `us-central1`) to control the region. | :::{tip} Use the `generic` type for any service not listed above. You define the @@ -193,6 +194,7 @@ The following providers have been tested with `inference.local`. Any provider th |---|---|---|---|---| | NVIDIA API Catalog | `nvidia-prod` | `nvidia` | `https://integrate.api.nvidia.com/v1` | `NVIDIA_API_KEY` | | Anthropic | `anthropic-prod` | `anthropic` | `https://api.anthropic.com` | `ANTHROPIC_API_KEY` | +| Google Vertex AI | `vertex` | `vertex` | Auto-configured per region | `ANTHROPIC_VERTEX_PROJECT_ID` (OAuth auto-generated) | | Baseten | `baseten` | `openai` | `https://inference.baseten.co/v1` | `OPENAI_API_KEY` | | Bitdeer AI | `bitdeer` | `openai` | `https://api-inference.bitdeer.ai/v1` | `OPENAI_API_KEY` | | Deepinfra | `deepinfra` | `openai` | `https://api.deepinfra.com/v1/openai` | `OPENAI_API_KEY` | diff --git a/examples/vertex-ai/OAUTH_PROVIDERS.md b/examples/vertex-ai/OAUTH_PROVIDERS.md new file mode 100644 index 000000000..ca9d2d8fe --- /dev/null +++ b/examples/vertex-ai/OAUTH_PROVIDERS.md @@ -0,0 +1,406 @@ +# Adding OAuth Auto-Refresh Support for New Providers + +The OpenShell gateway includes a generic OAuth token auto-refresh system that works for any provider implementing the `ProviderPlugin` trait with `get_runtime_token()`. + +## Current Supported Providers + +- **Vertex AI** (`vertex`): VERTEX_ADC → Google OAuth token exchange + +## Adding a New OAuth Provider + +### 1. Implement ProviderPlugin + +Create your provider in `crates/openshell-providers/src/providers/`: + +```rust +// crates/openshell-providers/src/providers/my_oauth_provider.rs +use crate::{ProviderPlugin, SecretStore, TokenResponse, RuntimeResult}; +use async_trait::async_trait; + +pub struct MyOAuthProvider { + client: reqwest::Client, +} + +impl MyOAuthProvider { + pub fn new() -> Self { + Self { + client: reqwest::Client::new(), + } + } +} + +#[async_trait] +impl ProviderPlugin for MyOAuthProvider { + fn id(&self) -> &'static str { + "my-oauth-provider" + } + + fn discover_existing(&self) -> Result, ProviderError> { + // Auto-discover credentials from environment/filesystem + // Store credentials in provider.credentials HashMap + Ok(None) + } + + async fn get_runtime_token(&self, store: &dyn SecretStore) -> RuntimeResult { + // Fetch credential from store + let credential = store.get("MY_OAUTH_CREDENTIAL").await?; + + // Exchange for OAuth token (e.g., AWS STS, Azure AD, etc.) + let token = self.exchange_for_token(&credential).await?; + + Ok(TokenResponse { + access_token: token, + token_type: "Bearer".to_string(), + expires_in: 3600, // 1 hour + metadata: HashMap::new(), + }) + } +} +``` + +### 2. Register Provider in Registry + +Add to `crates/openshell-providers/src/lib.rs`: + +```rust +impl ProviderRegistry { + pub fn new() -> Self { + let mut registry = Self::default(); + // ... existing providers + registry.register(providers::my_oauth_provider::MyOAuthProvider::new()); + registry + } +} +``` + +### 3. Enable TokenCache for Your Provider + +Update `crates/openshell-server/src/grpc.rs`: + +**Step 3a:** Add to `should_use_token_cache()`: + +```rust +fn should_use_token_cache(provider_type: &str, credential_key: &str) -> bool { + matches!( + (provider_type, credential_key), + ("vertex", "VERTEX_ADC") + | ("my-oauth-provider", "MY_OAUTH_CREDENTIAL") // ← Add this line + ) +} +``` + +**Step 3b:** Add to `get_or_create_token_cache()`: + +```rust +let provider_plugin: Arc = match provider_type { + "vertex" => { + let _: serde_json::Value = serde_json::from_str(credential_value)?; + Arc::new(openshell_providers::vertex::VertexProvider::new()) + } + "my-oauth-provider" => { + // Validate credential format if needed + Arc::new(openshell_providers::my_oauth_provider::MyOAuthProvider::new()) + } + _ => { + return Err(format!("Unsupported OAuth provider type: {provider_type}")); + } +}; +``` + +### 4. Export Provider Module + +Add to `crates/openshell-providers/src/lib.rs`: + +```rust +pub mod my_oauth_provider { + pub use crate::providers::my_oauth_provider::*; +} +``` + +### 5. Configure OAuth Header Injection + +Add OAuth header injection to your sandbox policy for endpoints that require it: + +```yaml +# sandbox-policy.yaml +version: 1 + +oauth_credentials: + auto_refresh: true + refresh_margin_seconds: 300 + +network_policies: + my_oauth_api: + name: my-oauth-api + endpoints: + - host: api.my-oauth-service.com + port: 443 + protocol: rest # Required for OAuth injection + access: full + oauth: + token_env_var: MY_OAUTH_TOKEN # Matches provider credential key + header_format: "Bearer {token}" # Or custom format +``` + +The `token_env_var` must match the credential key stored by your provider (e.g., `MY_OAUTH_CREDENTIAL` → token cached as `MY_OAUTH_TOKEN`). + +## OAuth Configuration + +OAuth auto-refresh behavior is configured **in the sandbox policy**, not at provider creation time. Provider creation is only for storing credentials. + +### Sandbox Policy Configuration + +Configure OAuth auto-refresh in your sandbox policy: + +```yaml +# sandbox-policy.yaml +version: 1 + +# OAuth credential auto-refresh configuration +oauth_credentials: + auto_refresh: true # Enable automatic token refresh + refresh_margin_seconds: 300 # Refresh 5 minutes before expiry + max_lifetime_seconds: 7200 # Maximum sandbox lifetime: 2 hours +``` + +### Configuration Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `auto_refresh` | bool | `false` | Enable automatic token refresh. **Must be explicitly enabled for security.** | +| `refresh_margin_seconds` | int64 | `300` | Refresh tokens this many seconds before expiry (e.g., 300 = 5 minutes). | +| `max_lifetime_seconds` | int64 | `86400` | Maximum sandbox lifetime in seconds. `-1` = infinite, `0` or unspecified = 24 hours, `>0` = custom limit. | + +**Security defaults:** +- `auto_refresh: false` - Disabled by default. Sandboxes must be explicitly configured for long-running operation. +- `max_lifetime_seconds: 86400` - 24-hour default limit prevents infinite-running sandboxes. + +### Provider Creation + +When creating a provider, only store the OAuth credential: + +```bash +openshell provider create vertex \ + --type vertex \ + --credential VERTEX_ADC=/path/to/adc.json +``` + +Auto-refresh configuration is handled in the sandbox policy, not at provider creation time. + +## OAuth Header Injection Configuration + +Configure automatic OAuth token injection for specific endpoints in your sandbox policy: + +```yaml +network_policies: + my_api: + name: my-api + endpoints: + - host: api.example.com + port: 443 + protocol: rest # Enable L7 HTTP inspection + access: full # Or use explicit rules + oauth: + token_env_var: MY_OAUTH_TOKEN # Environment variable containing token + header_format: "Bearer {token}" # Authorization header format +``` + +### OAuth Injection Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `token_env_var` | string | required | Environment variable name containing the OAuth token (e.g., `VERTEX_ACCESS_TOKEN`). The proxy resolves this from the gateway `SecretResolver`. | +| `header_format` | string | `"Bearer {token}"` | Authorization header value template. Use `{token}` as placeholder. | + +### How It Works + +When a sandbox makes an HTTP request to an endpoint with `oauth` configuration: + +1. **Policy Evaluation**: L7 proxy checks if endpoint has `oauth` field configured +2. **Token Resolution**: Proxy fetches token from environment variable via gateway `SecretResolver` +3. **Header Injection**: Proxy injects or replaces `Authorization` header using `header_format` template +4. **Request Forwarding**: Modified request forwarded to upstream with OAuth token + +**Example for different OAuth formats:** + +```yaml +# Standard Bearer token (default) +oauth: + token_env_var: GITHUB_TOKEN + header_format: "Bearer {token}" + +# Custom OAuth scheme +oauth: + token_env_var: CUSTOM_TOKEN + header_format: "OAuth {token}" + +# API key in custom header (non-standard but supported) +oauth: + token_env_var: API_KEY + header_format: "{token}" # Just the token, no prefix +``` + +## How Auto-Refresh Works + +### Architecture Overview + +OpenShell uses a **proxy-driven token refresh** model where fresh tokens are fetched on-demand rather than stored in the sandbox: + +``` +┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ Gateway │ │ Sandbox │ │ Upstream │ +│ TokenCache │ │ Proxy │ │ API Server │ +└──────────────┘ └──────────────┘ └──────────────┘ + │ │ │ + │ 1. Fetch fresh token │ │ + │◄───────────────────────│ │ + │ │ │ + │ 2. Return token │ │ + ├───────────────────────►│ │ + │ │ 3. Inject Authorization │ + │ │ header │ + │ ├────────────────────────►│ + │ │ │ + │ │ 4. Relay response │ + │ │◄────────────────────────│ + │ │ │ +``` + +### Gateway-Side Token Caching + +1. **Provider Creation**: User creates provider with OAuth credential (e.g., `VERTEX_ADC`) +2. **Gateway Startup**: Gateway creates `TokenCache` when first sandbox uses the provider +3. **Token Exchange**: `TokenCache` calls `get_runtime_token()` to exchange credential for OAuth token +4. **Caching**: Token cached in memory, valid for `expires_in` seconds +5. **Background Refresh** (when `auto_refresh: true`): Background task wakes periodically to refresh tokens +6. **Proactive Refresh**: Token refreshed N seconds before expiry (configurable via `refresh_margin_seconds`) +7. **Shared Cache**: All sandboxes using the same provider share the same `TokenCache` + +### Proxy-Driven Token Refresh + +When the sandbox makes an HTTP request to an OAuth-protected endpoint: + +1. **Policy Lookup**: Proxy checks if endpoint has `oauth` configuration in sandbox policy +2. **Token Fetch**: Proxy fetches fresh token from gateway `TokenCache` via `SecretResolver` +3. **Header Injection**: Proxy injects/replaces `Authorization` header using `header_format` template +4. **Request Forward**: Request forwarded to upstream with valid OAuth token +5. **Seamless Refresh**: Gateway's background task ensures tokens are always fresh + +**Key benefits:** +- Tokens never stored in sandbox (only fetched on-demand via gRPC) +- Gateway handles all token lifecycle management +- Sandbox proxy automatically uses latest token for each request +- No stale token failures even for long-running sandboxes + +## Token Refresh Timing + +For 1-hour OAuth tokens (3600 seconds): +- **Refresh margin**: 300 seconds (5 minutes) +- **Refresh interval**: 3600 - 300 = 3300 seconds (55 minutes) +- **Refresh trigger**: Token refreshed at T+55min (5 min before T+60min expiry) + +For custom token lifetimes: +- Adjust `refresh_margin_seconds` in `TokenCache::new(provider, store, refresh_margin_seconds)` +- Default: 300 seconds (5 minutes) +- Minimum recommended: 60 seconds (1 minute) + +## Example: AWS Bedrock Provider + +```rust +// crates/openshell-providers/src/providers/bedrock.rs +pub struct BedrockProvider { + client: reqwest::Client, +} + +#[async_trait] +impl ProviderPlugin for BedrockProvider { + fn id(&self) -> &'static str { + "bedrock" + } + + async fn get_runtime_token(&self, store: &dyn SecretStore) -> RuntimeResult { + // Fetch AWS credentials + let aws_access_key = store.get("AWS_ACCESS_KEY_ID").await?; + let aws_secret_key = store.get("AWS_SECRET_ACCESS_KEY").await?; + + // Exchange for STS session token + let sts_token = self.get_sts_session_token(&aws_access_key, &aws_secret_key).await?; + + Ok(TokenResponse { + access_token: sts_token, + token_type: "AWS4-HMAC-SHA256".to_string(), + expires_in: 3600, // 1 hour + metadata: HashMap::new(), + }) + } +} +``` + +Then enable in gateway: + +```rust +// should_use_token_cache() +("bedrock", "AWS_ACCESS_KEY_ID") | ("bedrock", "AWS_SECRET_ACCESS_KEY") + +// get_or_create_token_cache() +"bedrock" => Arc::new(openshell_providers::bedrock::BedrockProvider::new()) +``` + +## Testing + +Add test in `crates/openshell-providers/src/providers/my_oauth_provider.rs`: + +```rust +#[cfg(test)] +mod tests { + use super::*; + use crate::{DatabaseStore, TokenCache}; + use std::sync::Arc; + + #[tokio::test] + async fn test_token_exchange() { + let mut creds = HashMap::new(); + creds.insert("MY_OAUTH_CREDENTIAL".to_string(), "test-credential".to_string()); + let store = Arc::new(DatabaseStore::new(creds)); + + let provider = Arc::new(MyOAuthProvider::new()); + let cache = TokenCache::new( + provider, + store, + 300, // refresh_margin_seconds + true, // auto_refresh + ); + + let token = cache.get_token("my-oauth-provider").await.unwrap(); + assert!(!token.is_empty()); + } +} +``` + +## Security Considerations + +1. **Validate credentials** at provider creation time (in `discover_existing()`) +2. **Never log tokens** - only log token metadata (expiry time, etc.) +3. **Clear tokens on error** - TokenCache automatically handles cache invalidation +4. **Use HTTPS only** - All OAuth exchanges must use TLS +5. **Respect token expiry** - Always honor `expires_in` from OAuth provider +6. **Handle revocation** - Return `RuntimeError::AuthFailed` if token is revoked + +## Implemented Features + +- ✅ Gateway-side token caching with background refresh +- ✅ Proxy-driven token fetch (sandbox fetches fresh tokens on-demand from gateway) +- ✅ Generic OAuth header injection via endpoint-level `oauth` configuration +- ✅ Configurable refresh margin per provider (`refresh_margin_seconds`) +- ✅ Maximum sandbox lifetime limits (`max_lifetime_seconds`) +- ✅ Security-first defaults (`auto_refresh: false`) +- ✅ Policy-based OAuth configuration (no hardcoded provider logic) +- ✅ Support for custom `header_format` templates (Bearer, OAuth, custom schemes) + +## Future Enhancements + +- ⏳ Token persistence across gateway restarts (encrypted at-rest storage) +- ⏳ Multi-region token caching (edge deployments) +- ⏳ Token metrics and monitoring (expiry alerts, refresh failures) +- ⏳ Per-sandbox token refresh tracking (observability) +- ⏳ Token rotation support (graceful handling of multiple valid tokens) diff --git a/examples/vertex-ai/README.md b/examples/vertex-ai/README.md new file mode 100644 index 000000000..1f525ac98 --- /dev/null +++ b/examples/vertex-ai/README.md @@ -0,0 +1,760 @@ +# Google Cloud Vertex AI Example + +This example demonstrates how to use OpenShell with Google Cloud Vertex AI to run Claude models via GCP infrastructure. + +## Credential Provider Architecture + +OpenShell uses a **two-layer plugin architecture** for credential management: + +**Layer 1: SecretStore (where credentials live)** +- Generic interface for retrieving raw credentials +- Current implementation: **DatabaseStore** - stores ADC in gateway database +- Future implementations: OneCLI, Vault, GCP Secret Manager, etc. + +**Layer 2: ProviderPlugin (how to interpret credentials)** +- Provider-specific logic for exchanging credentials for tokens +- Current implementation: **VertexProvider** - exchanges ADC for OAuth tokens +- Future implementations: AnthropicProvider, OpenAIProvider, etc. + +**TokenCache (orchestration layer)** +- Wraps ProviderPlugin + SecretStore +- Caches tokens in memory +- Policy-configurable auto-refresh (default: 5 min before expiry) +- Background task spawned only when `oauth_credentials.auto_refresh: true` + +### Current Implementation + +``` +Provider Discovery + └─> ~/.config/gcloud/application_default_credentials.json + └─> Stored in gateway database (provider.credentials["VERTEX_ADC"]) + +Runtime Flow (Sandbox Startup) + └─> Gateway reads sandbox policy oauth_credentials config + └─> DatabaseStore.get("VERTEX_ADC") → ADC JSON + └─> VertexProvider.get_runtime_token(store) → OAuth token + └─> TokenCache(auto_refresh, refresh_margin) → caches token + └─> Sandbox receives VERTEX_ACCESS_TOKEN env var + +HTTP Request Flow (claude CLI → Vertex AI) + └─> Proxy intercepts request to aiplatform.googleapis.com + └─> Matches endpoint with oauth config from policy + └─> Fetches current token from gateway TokenCache + └─> Injects Authorization: Bearer + └─> Forwards to upstream with fresh token +``` + +**How it works:** + +1. **Provider Discovery** - `openshell provider create --name vertex --type vertex --from-existing` + - Auto-detects ADC from `~/.config/gcloud/application_default_credentials.json` + - Stores ADC JSON in gateway database (`provider.credentials["VERTEX_ADC"]`) + - Creates DatabaseStore wrapper around credentials HashMap + +2. **Runtime Token Exchange** - When sandbox starts + - Gateway reads sandbox policy `oauth_credentials` settings + - DatabaseStore fetches ADC from provider.credentials + - VertexProvider exchanges ADC for OAuth access token (valid 1 hour) + - TokenCache caches token in memory (conditionally spawns background task) + - Sandbox receives `VERTEX_ACCESS_TOKEN` as environment variable + +3. **Auto-Refresh** - Gateway background task (policy-configured) + - **Enabled when:** `oauth_credentials.auto_refresh: true` in sandbox policy + - **Refresh timing:** `oauth_credentials.refresh_margin_seconds` before expiry (default: 300 = 5 min) + - **Wake interval:** Token duration minus refresh margin (e.g., 55 min for 1-hour tokens) + - **Updates:** Gateway TokenCache in memory (shared across all sandboxes) + - **Disabled when:** `auto_refresh: false` or field omitted (default) + +4. **OAuth Header Injection** - Proxy fetches fresh tokens on each request + - **Configured via:** `oauth` field on endpoint in sandbox policy + - **Example:** `oauth: {token_env_var: VERTEX_ACCESS_TOKEN, header_format: "Bearer {token}"}` + - **Proxy behavior:** + 1. Intercepts requests matching endpoint host/port + 2. Reads token from environment variable (initial token) OR + 3. Resolves token via SecretResolver (fetches from gateway TokenCache) + 4. Injects/replaces `Authorization: Bearer ` header + 5. Uses fresh token from gateway if auto-refresh enabled + - **Key:** Proxy is responsible for fetching refreshed tokens, not sandbox + - Generic mechanism - works for any OAuth provider (Vertex, AWS Bedrock, Azure, etc.) + +**Security Model:** +- ✅ ADC stored in gateway database (encrypted at rest) +- ✅ OAuth tokens cached in memory only (cleared on restart) +- ✅ Sandboxes receive short-lived tokens (1 hour expiry) +- ✅ Tokens visible to sandbox processes but expire quickly +- ✅ Auto-refresh optional (policy-configured, disabled by default) + +**Future SecretStore Implementations:** + +Adding a new secret store only requires implementing the `SecretStore` trait: + +```rust +#[async_trait] +pub trait SecretStore: Send + Sync { + async fn get(&self, key: &str) -> SecretResult; + async fn health_check(&self) -> SecretResult<()>; + fn name(&self) -> &'static str; +} +``` + +Planned implementations: +- 🔜 **OneCliStore** - AES-256-GCM encrypted credential gateway +- 🔜 **GcpSecretManagerStore** - team secrets in GCP +- 🔜 **VaultStore** - HashiCorp Vault integration +- 🔜 **AwsSecretsManagerStore** - AWS-native secret storage +- 🔜 **BitwardenStore** - password manager integration + +**Note:** OS Keychain and GCP Workload Identity were considered but don't work for containerized gateway deployments (which is the primary use case). Network-based secret stores are the focus for future releases. + +## Quick Start + +### Auto-Discovery from ADC File (Recommended) + +OpenShell automatically discovers your Application Default Credentials from the standard gcloud location. + +**Prerequisites:** +- Google Cloud SDK (`gcloud`) installed +- Vertex AI API enabled in your GCP project + +**Setup:** + +```bash +# 1. Authenticate with Google Cloud +gcloud auth application-default login +# This creates: ~/.config/gcloud/application_default_credentials.json + +# 2. Configure environment +export ANTHROPIC_VERTEX_PROJECT_ID=your-gcp-project-id +export ANTHROPIC_VERTEX_REGION=us-east5 + +# 3. Create provider (auto-discovers ADC file) +openshell provider create --name vertex --type vertex --from-existing +# ✅ Stores ADC in gateway database + +# 4. Create sandbox +openshell sandbox create --name vertex-test \ + --provider vertex \ + --policy examples/vertex-ai/sandbox-policy.yaml + +# 5. Inside sandbox +claude # Automatically uses Vertex AI +``` + +**Complete Flow:** +``` +1. Provider Discovery (openshell provider create) + ~/.config/gcloud/application_default_credentials.json + ↓ (auto-detected & validated) + Gateway Database (provider.credentials["VERTEX_ADC"]) + +2. Sandbox Startup (openshell sandbox create) + Sandbox requests credentials from Gateway + ↓ (gRPC: GetSandboxProviderEnvironment with policy) + Gateway reads oauth_credentials from sandbox policy + ↓ (auto_refresh, refresh_margin_seconds, max_lifetime_seconds) + Gateway exchanges ADC for OAuth token + ↓ (POST https://oauth2.googleapis.com/token) + Gateway creates TokenCache with policy settings + ↓ (conditionally spawns background task if auto_refresh: true) + Gateway sends OAuth token to Sandbox + ↓ (VERTEX_ACCESS_TOKEN environment variable) + Sandbox stores token in memory + ↓ (accessible to proxy for header injection) + +3. HTTP Request (claude CLI → Vertex AI) + Claude CLI makes request to aiplatform.googleapis.com + ↓ (HTTP/HTTPS request) + Sandbox proxy intercepts request + ↓ (matches endpoint host:port from policy) + Proxy finds oauth config on endpoint + ↓ (oauth: {token_env_var: VERTEX_ACCESS_TOKEN, header_format: "Bearer {token}"}) + Proxy fetches current token + ↓ (tries env var first, then resolves from gateway TokenCache) + Proxy injects Authorization header + ↓ (Authorization: Bearer ) + Request forwarded to Vertex AI with real token + +4. Background Refresh (if auto_refresh: true) + Gateway TokenCache wakes up at scheduled interval + ↓ (e.g., every 55 minutes for 1-hour tokens) + Checks if token needs refresh (within margin of expiry) + ↓ (e.g., 5 minutes before expiration) + Re-exchanges ADC for fresh OAuth token + ↓ (POST https://oauth2.googleapis.com/token) + Updates cached token in gateway memory + ↓ (new expiry time, e.g., +1 hour) + Next proxy request fetches fresh token + ↓ (proxy gets updated token from gateway TokenCache) + Sandbox continues without restart + ↓ (proxy handles token refresh transparently) +``` + +### Manual Credential Injection + +If your ADC file is in a different location: + +```bash +# Option 1: Set environment variable +export VERTEX_ADC="$(cat /path/to/your/adc.json)" +openshell provider create --name vertex --type vertex --from-existing + +# Option 2: Inline credential +openshell provider create --name vertex --type vertex \ + --credential VERTEX_ADC="$(cat /path/to/your/adc.json)" +``` + +## What's Included + +- **`sandbox-policy.yaml`**: Network policy allowing Google OAuth and Vertex AI endpoints + - Supports major GCP regions (us-east5, us-central1, us-west1, europe-west1, europe-west4, asia-northeast1) + - Enables direct Claude CLI usage + - Enables `inference.local` routing + +## Security Model + +### Credential Storage + +**What OpenShell stores:** +- ✅ ADC files in gateway database (encrypted at rest) +- ✅ Provider metadata (project ID, region) + +**What OpenShell NEVER stores:** +- ❌ OAuth access tokens in database +- ❌ Credentials in sandboxes +- ❌ Credentials in plaintext + +**OAuth tokens:** +- Generated on-demand by gateway during sandbox startup +- Valid for ~1 hour (Google's default) +- Exchanged fresh on each sandbox creation +- Never persisted to disk + +**Sandboxes receive environment variables:** +```bash +# Inside sandbox environment (what processes see) +VERTEX_ADC='{"type":"...","project_id":"..."}' # ← Full ADC JSON (for Claude CLI to write to file) +VERTEX_ACCESS_TOKEN=ya29.c.a0Aa... # ← OAuth token (for proxy header injection) +ANTHROPIC_VERTEX_PROJECT_ID=your-project # ← Public metadata (direct value) +ANTHROPIC_VERTEX_REGION=us-east5 # ← Public metadata (direct value) +CLAUDE_CODE_USE_VERTEX=1 # ← Boolean flag (direct value) +``` + +**Security considerations:** +- `VERTEX_ADC`: Full ADC JSON visible to all processes (needed for Claude CLI auto-detection) +- `VERTEX_ACCESS_TOKEN`: OAuth token visible to all processes (short-lived, 1 hour expiry) +- Both are injected by gateway at sandbox startup, cleared when sandbox terminates +- OAuth tokens are refreshed in background when `oauth_credentials.auto_refresh: true` + +**On every HTTP request:** +1. OpenShell proxy intercepts request to `aiplatform.googleapis.com` +2. Matches endpoint configuration from policy (host:port) +3. Finds `oauth` config: `{token_env_var: VERTEX_ACCESS_TOKEN, header_format: "Bearer {token}"}` +4. **Proxy fetches current token:** + - First tries environment variable: `$VERTEX_ACCESS_TOKEN` (initial token) + - If auto-refresh enabled: resolves via SecretResolver (fetches from gateway TokenCache) + - Gets fresh token even after background refresh (no sandbox restart needed) +5. Injects/replaces `Authorization` header: `Authorization: Bearer ya29.c.a0Aa...` +6. Forwards request to Vertex AI with real OAuth token + +**Benefits:** +- **Proxy-driven refresh:** Proxy fetches fresh tokens from gateway on each request +- **No sandbox restart:** Background refresh updates gateway cache, proxy fetches automatically +- **Short-lived exposure:** Initial token in environment variable, but expires in 1 hour +- **Centralized management:** Gateway TokenCache manages refresh, sandboxes just consume +- **Secure storage:** ADC stored in gateway database (never exposed to untrusted networks) +- **Generic mechanism:** Works for any OAuth provider (AWS Bedrock, Azure OpenAI, etc.) + +### Token Auto-Refresh + +**By default**, OAuth tokens are **NOT** auto-refreshed. Sandboxes must restart after ~1 hour when tokens expire. + +**For long-running sandboxes**, enable auto-refresh in the **sandbox policy**: + +```yaml +# examples/vertex-ai/sandbox-policy.yaml +version: 1 + +# OAuth credential auto-refresh configuration +oauth_credentials: + auto_refresh: true # Enable automatic token refresh (default: false) + refresh_margin_seconds: 300 # Refresh 5 minutes before expiry (default: 300) + max_lifetime_seconds: 7200 # Maximum sandbox lifetime: 2 hours (default: 86400 = 24h, -1 = infinite) + +network_policies: + # ... rest of policy +``` + +**How it works:** +- Gateway reads `oauth_credentials` from sandbox policy at startup +- Creates TokenCache with configured settings in gateway memory +- Conditionally spawns background task only when `auto_refresh: true` +- Background task wakes up at `token_duration - refresh_margin_seconds` (e.g., 55 min for 1-hour tokens) +- Refreshes tokens proactively before expiration +- Updates TokenCache in gateway memory (shared across sandboxes) +- **Key:** Proxy fetches fresh token from gateway on each request (via SecretResolver) +- Sandbox receives initial token as environment variable at startup +- No sandbox restart needed - proxy transparently uses refreshed tokens + +**Configuration options:** + +| Field | Default | Description | +|-------|---------|-------------| +| `auto_refresh` | `false` | **Must be explicitly enabled.** Allows sandboxes to run longer than token lifetime. | +| `refresh_margin_seconds` | `300` | Refresh tokens 5 minutes before expiry. | +| `max_lifetime_seconds` | `86400` | Maximum sandbox lifetime. `-1` = infinite, `0` = 24h default, `>0` = custom. | + +**How gateway auto-refresh works:** + +**Without auto-refresh (default):** +``` +T+0:00 - Sandbox starts → Gateway exchanges ADC for OAuth token + ↓ (token valid for ~1 hour, cached in gateway TokenCache) +T+0:00 - Sandbox receives VERTEX_ACCESS_TOKEN environment variable (initial token) +T+0:30 - HTTP request → Proxy fetches token from env var + ↓ (injects Authorization: Bearer ) +T+1:00 - Token expires in gateway cache +T+1:01 - HTTP request → Proxy fetches expired token + ↓ (HTTP 401 Unauthorized from Vertex AI) + ↓ (sandbox must be restarted to get fresh token) +``` + +**With auto-refresh enabled (`oauth_credentials.auto_refresh: true`):** +``` +T+0:00 - Sandbox starts → Gateway exchanges ADC for OAuth token + ↓ (token valid for ~1 hour, background task spawned in gateway) +T+0:00 - Sandbox receives VERTEX_ACCESS_TOKEN environment variable (initial token) +T+0:30 - HTTP request → Proxy fetches token from env var + ↓ (injects Authorization: Bearer ) +T+0:55 - Gateway background refresh → Exchanges ADC for new token + ↓ (new token valid until T+1:55, updates gateway TokenCache) +T+1:00 - HTTP request → Proxy resolves token via SecretResolver + ↓ (fetches fresh token from gateway TokenCache) + ↓ (injects Authorization: Bearer ) + ↓ (seamless, no restart needed) +T+1:50 - Gateway background refresh → Exchanges for new token again + ↓ (continues until max_lifetime_seconds reached) +T+2:00 - Sandbox reaches max_lifetime (if configured) → self-terminates +``` + +**Features:** + +- ✅ **Gateway-side refresh:** TokenCache in gateway refreshes tokens in background +- ✅ **Proxy-driven fetch:** Proxy fetches fresh token from gateway on each request +- ✅ **Auto-refresh:** Background task spawned when `auto_refresh: true` in policy +- ✅ **Configurable timing:** `refresh_margin_seconds` (default: 300 = 5 min) +- ✅ **Lifetime limits:** `max_lifetime_seconds` (default: 86400 = 24h, -1 = infinite) +- ✅ **No restarts:** Proxy transparently uses refreshed tokens, no sandbox restart +- ✅ **Seamless updates:** Refresh happens before expiry, no service interruption + +## GKE Deployment + +### 1. Create GCP Service Account + +```bash +# Create service account for OpenShell gateway +gcloud iam service-accounts create openshell-gateway \ + --project=$ANTHROPIC_VERTEX_PROJECT_ID \ + --display-name="OpenShell Gateway" + +# Grant Vertex AI permissions +gcloud projects add-iam-policy-binding $ANTHROPIC_VERTEX_PROJECT_ID \ + --member="serviceAccount:openshell-gateway@${ANTHROPIC_VERTEX_PROJECT_ID}.iam.gserviceaccount.com" \ + --role="roles/aiplatform.user" +``` + +### 2. Configure Workload Identity + +```bash +# Link Kubernetes SA to GCP SA +gcloud iam service-accounts add-iam-policy-binding \ + openshell-gateway@${ANTHROPIC_VERTEX_PROJECT_ID}.iam.gserviceaccount.com \ + --role roles/iam.workloadIdentityUser \ + --member "serviceAccount:${ANTHROPIC_VERTEX_PROJECT_ID}.svc.id.goog[openshell/openshell-gateway]" +``` + +### 3. Deploy Gateway + +```yaml +# gateway-deployment.yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: openshell-gateway + namespace: openshell + annotations: + iam.gke.io/gcp-service-account: openshell-gateway@YOUR_PROJECT.iam.gserviceaccount.com +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: openshell-gateway + namespace: openshell +spec: + template: + spec: + serviceAccountName: openshell-gateway + containers: + - name: gateway + image: quay.io/itdove/gateway:dev + env: + - name: ANTHROPIC_VERTEX_PROJECT_ID + value: "your-gcp-project-id" + - name: ANTHROPIC_VERTEX_REGION + value: "us-east5" +``` + +```bash +kubectl apply -f gateway-deployment.yaml +``` + +### 4. Verify Workload Identity + +```bash +# Check that gateway can access GCP metadata service +kubectl exec -n openshell deployment/openshell-gateway -- \ + curl -H "Metadata-Flavor: Google" \ + http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token + +# Should return: +# {"access_token":"ya29.xxx","expires_in":3600,"token_type":"Bearer"} +``` + +## Advanced Configuration + +### Token Exchange On Demand + +OAuth tokens are exchanged fresh on each sandbox startup. This means: + +- **Short-lived credentials:** Tokens expire in ~1 hour +- **No background refresh:** Gateway exchanges tokens synchronously +- **Automatic retry:** Sandbox restart gets fresh token automatically +- **Network required:** Token exchange requires internet access during sandbox startup + +**For production deployments:** + +Consider using short-lived sandboxes (< 1 hour) to minimize credential exposure. This aligns with security best practices and ensures tokens never expire during active sessions. + +**For development workflows:** + +Long-running sandboxes (> 1 hour) will require restart to refresh tokens. Use `openshell sandbox restart ` when you see 401 Unauthorized errors. + +### Multiple Credential Storage (Future) + +**Current implementation:** + +ADC credentials are stored in the OpenShell gateway database. + +**Future feature - pluggable secret stores:** + +Support for external secret management: + +1. **GCP Secret Manager** - Team secrets (future) +2. **HashiCorp Vault** - Multi-cloud (future) +3. **GKE Workload Identity** - Keyless authentication (future) +4. **AWS Secrets Manager** - AWS deployments (future) + +These will allow enterprise deployments to avoid storing credentials in the OpenShell database entirely. + +## Troubleshooting + +### "ADC credentials rejected by Google OAuth" errors + +**Cause:** ADC credentials have expired or been revoked. + +Google Application Default Credentials (ADC) can expire after extended periods of inactivity (typically months). When this happens, token exchange will fail. + +**Solution:** + +```bash +# Re-authenticate with Google Cloud +gcloud auth application-default login + +# Update the provider with fresh credentials +openshell provider create --name vertex --type vertex --from-existing + +# Or delete and recreate +openshell provider delete vertex +openshell provider create --name vertex --type vertex --from-existing +``` + +**How to tell if credentials are expired:** +- Provider creation succeeds but sandbox requests fail with "invalid_grant" +- Error message: "ADC credentials rejected by Google OAuth (status 400)" + +**Prevention:** +- Credentials are validated when you create the provider +- If credentials expire later (days/weeks/months), re-run `gcloud auth application-default login` + +### "Vertex ADC credentials not found" errors + +**Cause:** No ADC file found during provider creation. + +**Solution:** + +```bash +# Generate ADC file +gcloud auth application-default login + +# Verify it was created +ls ~/.config/gcloud/application_default_credentials.json + +# Create provider +openshell provider create --name vertex --type vertex --from-existing +``` + +### "Authentication failed" errors (GKE/Cloud Run) + +**Cause:** Gateway cannot fetch tokens from GCP metadata service. + +**Solution:** + +1. **Verify Workload Identity is configured:** + ```bash + kubectl get sa openshell-gateway -n openshell -o yaml | grep iam.gke.io + # Should show: iam.gke.io/gcp-service-account: openshell-gateway@PROJECT.iam.gserviceaccount.com + ``` + +2. **Check gateway can access metadata service:** + ```bash + kubectl exec -n openshell deployment/openshell-gateway -- \ + curl -H "Metadata-Flavor: Google" \ + http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token + ``` + +3. **Verify GCP service account has permissions:** + ```bash + gcloud projects get-iam-policy $ANTHROPIC_VERTEX_PROJECT_ID \ + --flatten="bindings[].members" \ + --filter="bindings.members:serviceAccount:openshell-gateway@*" + # Should show: roles/aiplatform.user + ``` + +4. **Check gateway logs:** + ```bash + kubectl logs -n openshell deployment/openshell-gateway | grep -i "credential\|token\|workload" + ``` + +### "Project not found" errors + +**Cause:** Invalid or inaccessible GCP project ID. + +**Solution:** + +1. Verify project exists and you have access: + ```bash + gcloud projects describe $ANTHROPIC_VERTEX_PROJECT_ID + ``` + +2. Check Vertex AI API is enabled: + ```bash + gcloud services list --enabled --project=$ANTHROPIC_VERTEX_PROJECT_ID | grep aiplatform + ``` + +3. Enable if needed: + ```bash + gcloud services enable aiplatform.googleapis.com --project=$ANTHROPIC_VERTEX_PROJECT_ID + ``` + +### "Region not supported" errors + +**Cause:** Vertex AI endpoint for your region not in network policy. + +**Solution:** Add region to `sandbox-policy.yaml`: + +```yaml +- host: your-region-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" +``` + +Supported regions: us-central1, us-east5, us-west1, europe-west1, europe-west4, asia-northeast1, asia-southeast1 + +### Tokens not refreshing + +**Cause:** Auto-refresh not enabled in sandbox policy, or background task failing. + +**Solution:** + +1. **Verify auto-refresh is enabled in sandbox policy:** + ```yaml + # sandbox-policy.yaml must have: + oauth_credentials: + auto_refresh: true # Required for background refresh + refresh_margin_seconds: 300 # Optional (default: 300) + ``` + +2. **Check gateway logs for background refresh:** + ```bash + # Gateway logs should show (only when auto_refresh: true): + # "background refresh triggered" + # "background refresh succeeded" + kubectl logs -n openshell deployment/openshell-gateway | grep "refresh" + + # If you see "Auto-refresh disabled for token cache", check your policy + ``` + +3. **Verify no network issues:** + ```bash + # Test OAuth endpoint from gateway pod + kubectl exec -n openshell deployment/openshell-gateway -- \ + curl -v https://oauth2.googleapis.com/token + ``` + +4. **Check for errors in logs:** + ```bash + kubectl logs -n openshell deployment/openshell-gateway | grep -i "refresh failed" + ``` + +## Documentation + +For detailed setup instructions and configuration options, see: + +- [Credential Provider Plugin Architecture](../../credential-provider-plugin-architecture.md) +- [Provider Management](../../docs/sandboxes/manage-providers.md) +- [Inference Routing](../../docs/inference/configure.md) + +## Architecture + +### Two-Layer Plugin System + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Layer 1: Secret Store (where credentials live) │ +│ │ +│ ┌───────────────┐ ┌────────────────┐ ┌─────────────────┐ │ +│ │ OS Keychain │ │ Workload │ │ GCP Secret │ │ +│ │ macOS/Linux/ │ │ Identity │ │ Manager │ │ +│ │ Windows │ │ (GKE metadata) │ │ (team secrets) │ │ +│ └───────────────┘ └────────────────┘ └─────────────────┘ │ +│ │ │ │ │ +│ └──────────────────┴────────────────────┘ │ +│ │ │ +│ SecretStore trait │ +│ (generic get/health_check) │ +└─────────────────────────────┬───────────────────────────────┘ + │ Raw secret string + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Layer 2: Provider Plugin (how to interpret credentials) │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ VertexProvider │ │ +│ │ - Reads ADC JSON from store │ │ +│ │ - Exchanges for OAuth token │ │ +│ │ - Knows Google OAuth endpoint │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ │ +│ ProviderPlugin trait │ +│ (get_runtime_token method) │ +└─────────────────────────────┬───────────────────────────────┘ + │ TokenResponse + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ TokenCache (policy-configured, in gateway) │ +│ - Caches tokens in memory (~1 hour) │ +│ - Conditionally spawns background task │ +│ - Config: oauth_credentials from sandbox policy │ +│ - Wraps: ProviderPlugin + SecretStore │ +│ - Background refresh updates cache every 55 min │ +└─────────────────────────────┬───────────────────────────────┘ + │ Initial token at startup + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Sandbox Environment │ +│ VERTEX_ACCESS_TOKEN=ya29.c.a0Aa... (initial token) │ +└─────────────────────────────┬───────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ OpenShell Proxy (L7 HTTP Inspection) │ +│ 1. Intercepts HTTP to aiplatform.googleapis.com │ +│ 2. Matches endpoint with oauth config from policy │ +│ 3. Fetches token: │ +│ - First: tries $VERTEX_ACCESS_TOKEN (env var) │ +│ - Then: resolves via SecretResolver (from gateway) │ +│ 4. Gets fresh token from gateway TokenCache │ +│ 5. Injects Authorization: Bearer │ +│ 6. Forwards to Vertex AI │ +└─────────────────────────────┬───────────────────────────────┘ + │ HTTP with real token + ▼ + Vertex AI Endpoint +``` + +### Local Development Flow (OS Keychain) + +``` +macOS Keychain OpenShell Gateway + │ │ + │ 1. OsKeychainStore.get("vertex") │ + ├───────────────────────────────────>│ + │ │ + │ 2. Returns: ADC JSON │ + │<───────────────────────────────────┤ + │ │ + │ 3. VertexProvider.get_runtime_token(adc) + ├────────────────────────────────> + │ │ + │ Google OAuth + │ │ + │ 4. Returns: OAuth token │ + │<────────────────────────────────┤ + │ + │ 5. TokenCache stores + returns + │ + Sandbox gets token +``` + +### Production Flow (Workload Identity) + +``` +GCP Metadata Service OpenShell Gateway + │ │ + │ 1. WorkloadIdentityStore.get() │ + ├───────────────────────────────────>│ + │ │ + │ 2. Returns: OAuth token (JSON) │ + │<───────────────────────────────────┤ + │ │ + │ 3. VertexProvider.get_runtime_token() + │ Detects Workload Identity format + │ Returns token directly (no exchange) + │ + │ 4. TokenCache stores + returns + │ + Sandbox gets token +``` + +## Migration from ADC Upload Approach + +**Old approach (deprecated):** +```bash +# DON'T DO THIS - old method +openshell sandbox create --provider vertex \ + --upload ~/.config/gcloud/:.config/gcloud/ # ❌ No longer needed +``` + +**New approach:** +```bash +# DO THIS - credential provider plugins +openshell sandbox create --provider vertex # ✅ No upload flag +``` + +**Why the change:** +- ❌ Old: ADC credentials stored in sandbox filesystem +- ✅ New: Only short-lived OAuth tokens (1 hour expiry) +- ❌ Old: Manual token refresh needed (restart sandbox) +- ✅ New: Optional automatic background refresh (policy-configured) +- ❌ Old: Each sandbox manages tokens independently +- ✅ New: Centralized token management at gateway +- ❌ Old: Compromised sandbox = compromised long-lived credentials +- ✅ New: Compromised sandbox = short-lived token (max 1 hour) + +**If you're using the old approach:** +1. Remove `--upload ~/.config/gcloud/` from sandbox creation +2. Deploy gateway with Workload Identity (see GKE Deployment section) +3. Existing sandboxes will continue to work until recreated diff --git a/examples/vertex-ai/sandbox-policy.yaml b/examples/vertex-ai/sandbox-policy.yaml new file mode 100644 index 000000000..8a6dab4ad --- /dev/null +++ b/examples/vertex-ai/sandbox-policy.yaml @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Sandbox policy for Google Cloud Vertex AI +# +# This policy allows sandboxes to access Google Cloud endpoints required for +# Vertex AI with Anthropic Claude models. + +version: 1 + +# OAuth credential auto-refresh configuration (optional) +# Uncomment to enable auto-refresh for long-running sandboxes +# oauth_credentials: +# auto_refresh: true # Enable automatic token refresh (default: false) +# refresh_margin_seconds: 300 # Refresh 5 minutes before expiry (default: 300) +# max_lifetime_seconds: 7200 # Maximum sandbox lifetime: 2 hours (default: 86400 = 24h, -1 = infinite) + +network_policies: + google_vertex: + name: google-vertex + endpoints: + # Google OAuth endpoints (for ADC token exchange - not intercepted) + - host: oauth2.googleapis.com + port: 443 + - host: accounts.google.com + port: 443 + - host: www.googleapis.com + port: 443 + + # Vertex AI endpoints (global and regional) + # protocol: rest enables L7 HTTP inspection + # oauth: configures automatic OAuth token injection into Authorization headers + - host: aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + - host: us-east5-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + - host: us-central1-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + - host: us-west1-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + - host: europe-west1-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + - host: europe-west4-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + - host: asia-northeast1-aiplatform.googleapis.com + port: 443 + protocol: rest + access: full + oauth: + token_env_var: VERTEX_ACCESS_TOKEN + header_format: "Bearer {token}" + + binaries: + # Claude CLI for direct Vertex AI usage + - path: /usr/local/bin/claude + # Python for Anthropic SDK usage + - path: /usr/bin/python3 + # curl for testing + - path: /usr/bin/curl + + inference_local: + name: inference-local + endpoints: + # Local inference routing endpoint + - host: inference.local + port: 80 + binaries: + - path: /usr/bin/curl + - path: /usr/bin/python3 + + pypi: + name: pypi + endpoints: + # Python Package Index (PyPI) for pip install + - host: pypi.org + port: 443 + - host: files.pythonhosted.org + port: 443 + - host: "*.pythonhosted.org" + port: 443 + - host: pypi.python.org + port: 443 + binaries: + # Python executables (pip runs as Python subprocess) + - path: /usr/bin/python3 + - path: /usr/bin/python3.13 + - path: /usr/local/bin/python3 + - path: /usr/bin/python + # Pip executables + - path: /usr/local/bin/pip + - path: /usr/local/bin/pip3 + - path: /usr/bin/pip + - path: /usr/bin/pip3 + # Venv paths (pip installs to venv by default) + - path: /sandbox/.venv/bin/python3 + - path: /sandbox/.venv/bin/python3.13 + - path: /sandbox/.venv/bin/python + - path: /sandbox/.venv/bin/pip3 + - path: /sandbox/.venv/bin/pip + # UV Python installation (resolved symlink path) + - path: /sandbox/.uv/python/cpython-3.13.12-linux-aarch64-gnu/bin/python3.13 + - path: /sandbox/.uv/python/cpython-3.13-linux-aarch64-gnu/bin/python3 + - path: /sandbox/.uv/python/cpython-3.13-linux-aarch64-gnu/bin/python + # Testing tools + - path: /usr/bin/curl diff --git a/proto/openshell.proto b/proto/openshell.proto index 04f705020..abc199baa 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -429,10 +429,36 @@ message GetSandboxProviderEnvironmentRequest { string sandbox_id = 1; } +// Metadata for OAuth credentials with auto-refresh support. +message OAuthCredentialMetadata { + // Token expiry time in seconds (from OAuth response). + // Example: 3600 for 1-hour tokens. + int64 expires_in = 1; + + // Whether auto-refresh is enabled for this credential. + // Default: false (tokens expire after expires_in, sandbox limited to 1 hour). + bool auto_refresh = 2; + + // Seconds before expiry to trigger refresh (default: 300 = 5 minutes). + // Only used when auto_refresh = true. + int64 refresh_margin_seconds = 3; + + // Maximum sandbox lifetime in seconds when auto_refresh is enabled. + // -1 = infinite (use with caution!) + // 0 or unspecified = default (86400 = 24 hours) + // >0 = custom limit in seconds + int64 max_lifetime_seconds = 4; +} + // Get sandbox provider environment response. message GetSandboxProviderEnvironmentResponse { // Provider credential environment variables. map environment = 1; + + // Metadata for OAuth credentials (token expiry, auto-refresh config). + // Key matches credential key in environment (e.g., "VERTEX_ADC"). + // Only present for OAuth providers that support token auto-refresh. + map oauth_metadata = 2; } // --------------------------------------------------------------------------- diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 61948a527..8e3689194 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -17,6 +17,24 @@ message SandboxPolicy { ProcessPolicy process = 4; // Network access policies keyed by name (e.g. "claude_code", "gitlab"). map network_policies = 5; + // OAuth credential auto-refresh policy. + OAuthCredentialsPolicy oauth_credentials = 6; +} + +// OAuth credential auto-refresh policy configuration. +message OAuthCredentialsPolicy { + // Enable automatic token refresh for long-running sandboxes. + // When true, the sandbox will periodically re-fetch credentials from the gateway. + // Default: false (must be explicitly enabled for security). + bool auto_refresh = 1; + // Seconds before expiry to trigger refresh (e.g., 300 = 5 minutes). + // If 0 or unspecified, uses provider default (typically 300). + int64 refresh_margin_seconds = 2; + // Maximum sandbox lifetime in seconds. + // -1 = infinite (no limit) + // 0 or unspecified = 86400 (24 hours default) + // >0 = custom limit + int64 max_lifetime_seconds = 3; } // Filesystem access policy. @@ -53,6 +71,20 @@ message NetworkPolicyRule { repeated NetworkBinary binaries = 3; } +// OAuth header injection configuration for HTTP endpoints. +// Allows automatic injection of OAuth tokens from environment variables +// into Authorization headers for API requests. +message OAuthInjectionConfig { + // Environment variable name containing the OAuth token. + // The token will be resolved from the SecretResolver at request time. + // Example: "VERTEX_ACCESS_TOKEN", "AZURE_ACCESS_TOKEN" + string token_env_var = 1; + // Header value format template. Use {token} as placeholder. + // Example: "Bearer {token}" + // Default: "Bearer {token}" + string header_format = 2; +} + // A network endpoint (host + port) with optional L7 inspection config. message NetworkEndpoint { // Hostname or host glob pattern. Exact match is case-insensitive. @@ -85,6 +117,10 @@ message NetworkEndpoint { // If `port` is set and `ports` is empty, `port` is normalized to `ports: [port]`. // If both are set, `ports` takes precedence. repeated uint32 ports = 9; + // OAuth header injection configuration for this endpoint. + // When set, the proxy will inject or replace the Authorization header + // with the token from the specified environment variable. + OAuthInjectionConfig oauth = 10; } // An L7 policy rule (allow-only). diff --git a/scripts/bin/openshell b/scripts/bin/openshell index 13df76987..6d9ef9663 100755 --- a/scripts/bin/openshell +++ b/scripts/bin/openshell @@ -10,22 +10,8 @@ STATE_FILE="$PROJECT_ROOT/.cache/openshell-build.state" # Bash version compatibility helper # --------------------------------------------------------------------------- -# Read lines into an array variable (bash 3 & 4 compatible) -# Usage: read_lines_into_array array_name < <(command) -read_lines_into_array() { - local array_name=$1 - if ((BASH_VERSINFO[0] >= 4)); then - # Bash 4+: use mapfile (faster) - mapfile -t "$array_name" - else - # Bash 3: use while loop - local line - eval "$array_name=()" - while IFS= read -r line; do - eval "$array_name+=(\"\$line\")" - done - fi -} +# shellcheck source=tasks/scripts/lib/common.sh +source "$PROJECT_ROOT/tasks/scripts/lib/common.sh" # --------------------------------------------------------------------------- # Fingerprint-based rebuild check diff --git a/scripts/podman.env b/scripts/podman.env index 1e74a6b71..459627c0e 100644 --- a/scripts/podman.env +++ b/scripts/podman.env @@ -8,6 +8,11 @@ MACHINE_NAME="${PODMAN_MACHINE_NAME:-openshell}" +# Clear variables from other build workflows that would interfere with local development +unset IMAGE_TAG +unset TAG_LATEST +unset REGISTRY + # Get Podman socket path from the machine if command -v podman &>/dev/null; then SOCKET_PATH=$(podman machine inspect "${MACHINE_NAME}" --format '{{.ConnectionInfo.PodmanSocket.Path}}' 2>/dev/null) @@ -21,13 +26,19 @@ if command -v podman &>/dev/null; then export OPENSHELL_CONTAINER_RUNTIME=podman # Local development image registry - export OPENSHELL_REGISTRY="127.0.0.1:5000/openshell" + export OPENSHELL_IMAGE_REPO_BASE="127.0.0.1:5000/openshell" + export OPENSHELL_REGISTRY_HOST="127.0.0.1:5000" + export OPENSHELL_REGISTRY_NAMESPACE="openshell" + export OPENSHELL_REGISTRY_ENDPOINT="host.containers.internal:5000" + export OPENSHELL_REGISTRY_INSECURE="true" export OPENSHELL_CLUSTER_IMAGE="localhost/openshell/cluster:dev" echo "✓ Podman environment configured:" echo " CONTAINER_HOST=${CONTAINER_HOST}" echo " OPENSHELL_CONTAINER_RUNTIME=${OPENSHELL_CONTAINER_RUNTIME}" - echo " OPENSHELL_REGISTRY=${OPENSHELL_REGISTRY}" + echo " OPENSHELL_IMAGE_REPO_BASE=${OPENSHELL_IMAGE_REPO_BASE}" + echo " OPENSHELL_REGISTRY_HOST=${OPENSHELL_REGISTRY_HOST}" + echo " OPENSHELL_REGISTRY_INSECURE=${OPENSHELL_REGISTRY_INSECURE}" echo " OPENSHELL_CLUSTER_IMAGE=${OPENSHELL_CLUSTER_IMAGE}" fi else diff --git a/scripts/setup-podman-macos.sh b/scripts/setup-podman-macos.sh index 1538259f3..02fdf2343 100755 --- a/scripts/setup-podman-macos.sh +++ b/scripts/setup-podman-macos.sh @@ -9,7 +9,7 @@ set -euo pipefail MACHINE_NAME="${PODMAN_MACHINE_NAME:-openshell}" -MEMORY="${PODMAN_MEMORY:-8192}" +MEMORY="${PODMAN_MEMORY:-12288}" CPUS="${PODMAN_CPUS:-4}" echo "=== OpenShell Podman Setup for macOS ===" @@ -108,9 +108,9 @@ echo "Podman machine '${MACHINE_NAME}' is ready!" echo "" echo "Next steps:" echo " 1. Set up environment: source scripts/podman.env" -echo " 2. Build and deploy: mise run cluster:build:full" -echo " 3. Build CLI: cargo build --release -p openshell-cli" -echo " 4. Install CLI: cp target/release/openshell ~/.local/bin/" +echo " 2. Build and deploy cluster: mise run cluster:build:full" +echo " 3. Install CLI: cargo install --path crates/openshell-cli --root ~/.local" +echo " 4. Verify installation: openshell gateway info" echo "" echo "To make the environment persistent, add to your shell profile (~/.zshrc):" echo " source $(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)/scripts/podman.env" diff --git a/tasks/scripts/cluster-deploy-fast.sh b/tasks/scripts/cluster-deploy-fast.sh index 86fe9746d..3e22f8379 100755 --- a/tasks/scripts/cluster-deploy-fast.sh +++ b/tasks/scripts/cluster-deploy-fast.sh @@ -28,6 +28,9 @@ log_duration() { echo "${label} took $((end - start))s" } +# shellcheck source=lib/common.sh +source "$(dirname "$0")/lib/common.sh" + if ! $CONTAINER_RUNTIME ps -q --filter "name=^${CONTAINER_NAME}$" --filter "health=healthy" | grep -q .; then echo "Error: Cluster container '${CONTAINER_NAME}' is not running or not healthy." echo "Start the cluster first with: mise run cluster" @@ -86,7 +89,7 @@ fi declare -a changed_files=() detect_start=$(date +%s) -mapfile -t changed_files < <( +read_lines_into_array changed_files < <( { git diff --name-only git diff --name-only --cached diff --git a/tasks/scripts/docker-build-image.sh b/tasks/scripts/docker-build-image.sh index 38b200a2e..a76b01d12 100755 --- a/tasks/scripts/docker-build-image.sh +++ b/tasks/scripts/docker-build-image.sh @@ -212,11 +212,13 @@ if [[ "${CONTAINER_RUNTIME}" == "podman" ]]; then ARCH_ARGS+=(--build-arg "BUILDARCH=${TARGETARCH}") fi - # Filter OUTPUT_ARGS: Podman stores images locally by default (no --load) + # Filter OUTPUT_ARGS: Podman doesn't support --load or --push in build command PODMAN_OUTPUT_ARGS=() + PODMAN_SHOULD_PUSH=0 for arg in ${OUTPUT_ARGS[@]+"${OUTPUT_ARGS[@]}"}; do case "${arg}" in --load) ;; # implicit in Podman + --push) PODMAN_SHOULD_PUSH=1 ;; # push after build *) PODMAN_OUTPUT_ARGS+=("${arg}") ;; esac done @@ -227,6 +229,13 @@ if [[ "${CONTAINER_RUNTIME}" == "podman" ]]; then ${TLS_ARGS[@]+"${TLS_ARGS[@]}"} \ ${PODMAN_OUTPUT_ARGS[@]+"${PODMAN_OUTPUT_ARGS[@]}"} \ . + + # Push after build if requested (Podman doesn't support --push in build) + if [[ "${PODMAN_SHOULD_PUSH}" == "1" && "${IS_FINAL_IMAGE}" == "1" ]]; then + echo "Pushing ${IMAGE_NAME}:${IMAGE_TAG}..." + podman_local_tls_args "${IMAGE_NAME}" + podman push ${PODMAN_TLS_ARGS[@]+"${PODMAN_TLS_ARGS[@]}"} "${IMAGE_NAME}:${IMAGE_TAG}" + fi else # Docker: use buildx docker buildx build \ diff --git a/tasks/scripts/docker-publish-multiarch.sh b/tasks/scripts/docker-publish-multiarch.sh index f83a7c203..ca68a057a 100755 --- a/tasks/scripts/docker-publish-multiarch.sh +++ b/tasks/scripts/docker-publish-multiarch.sh @@ -27,8 +27,59 @@ fi if [[ "${CONTAINER_RUNTIME}" == "podman" ]]; then echo "Using Podman for multi-arch build (podman manifest)" + echo "Note: Podman builds platforms sequentially (slower than Docker buildx)" export DOCKER_BUILDER="" + + # Podman implements multi-arch via explicit manifest creation + per-platform + # builds. Cannot use docker-build-image.sh here because it builds single + # images, not manifests. Docker buildx handles multi-arch internally, so the + # Docker path below can delegate to docker-build-image.sh. + IFS=',' read -ra PLATFORM_ARRAY <<< "${PLATFORMS}" + + for component in gateway cluster; do + full_image="${REGISTRY}/${component}" + echo "" + echo "=== Building multi-arch ${component} image ===" + + # Create manifest list + podman manifest rm "${full_image}:${IMAGE_TAG}" 2>/dev/null || true + podman manifest create "${full_image}:${IMAGE_TAG}" + + # Build for each platform + for platform in "${PLATFORM_ARRAY[@]}"; do + arch="${platform##*/}" + case "${arch}" in + amd64) target_arch="amd64" ;; + arm64) target_arch="arm64" ;; + *) echo "Unsupported arch: ${arch}" >&2; exit 1 ;; + esac + + echo "Building ${component} for ${platform}..." + + # Package Helm chart for cluster builds + if [[ "${component}" == "cluster" ]]; then + mkdir -p deploy/docker/.build/charts + helm package deploy/helm/openshell -d deploy/docker/.build/charts/ >/dev/null + fi + + # Build with explicit TARGETARCH/BUILDARCH to avoid cross-compilation + # (QEMU emulation handles running the different architecture) + podman build --platform "${platform}" \ + --build-arg TARGETARCH="${target_arch}" \ + --build-arg BUILDARCH="${target_arch}" \ + --manifest "${full_image}:${IMAGE_TAG}" \ + -f deploy/docker/Dockerfile.images \ + --target "${component}" \ + . + done + + # Push manifest + echo "Pushing ${full_image}:${IMAGE_TAG}..." + podman manifest push "${full_image}:${IMAGE_TAG}" \ + "docker://${full_image}:${IMAGE_TAG}" + done else + # Docker: use buildx BUILDER_NAME=${DOCKER_BUILDER:-multiarch} if docker buildx inspect "${BUILDER_NAME}" >/dev/null 2>&1; then echo "Using existing buildx builder: ${BUILDER_NAME}" @@ -38,19 +89,21 @@ else docker buildx create --name "${BUILDER_NAME}" --use --bootstrap fi export DOCKER_BUILDER="${BUILDER_NAME}" -fi -export DOCKER_PLATFORM="${PLATFORMS}" -export DOCKER_PUSH=1 -export IMAGE_REGISTRY="${REGISTRY}" + export DOCKER_PLATFORM="${PLATFORMS}" + export DOCKER_PUSH=1 + export IMAGE_REGISTRY="${REGISTRY}" -echo "Building multi-arch gateway image..." -tasks/scripts/docker-build-image.sh gateway + echo "Building multi-arch gateway image..." + tasks/scripts/docker-build-image.sh gateway -echo -echo "Building multi-arch cluster image..." -tasks/scripts/docker-build-image.sh cluster + echo + echo "Building multi-arch cluster image..." + tasks/scripts/docker-build-image.sh cluster +fi -TAGS_TO_APPLY=("${EXTRA_TAGS[@]}") +# Build list of additional tags to apply (beyond IMAGE_TAG which is already set). +# Combines EXTRA_TAGS with optional "latest" tag without modifying EXTRA_TAGS. +TAGS_TO_APPLY=(${EXTRA_TAGS[@]+"${EXTRA_TAGS[@]}"}) if [[ "${TAG_LATEST}" == "true" ]]; then TAGS_TO_APPLY+=("latest") fi @@ -58,7 +111,7 @@ fi if [[ ${#TAGS_TO_APPLY[@]} -gt 0 ]]; then for component in gateway cluster; do full_image="${REGISTRY}/${component}" - for tag in "${TAGS_TO_APPLY[@]}"; do + for tag in ${TAGS_TO_APPLY[@]+"${TAGS_TO_APPLY[@]}"}; do [[ "${tag}" == "${IMAGE_TAG}" ]] && continue echo "Tagging ${full_image}:${tag}..." if [[ "${CONTAINER_RUNTIME}" == "podman" ]]; then diff --git a/tasks/scripts/lib/common.sh b/tasks/scripts/lib/common.sh new file mode 100644 index 000000000..42fc9e828 --- /dev/null +++ b/tasks/scripts/lib/common.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Common shell functions shared across task scripts + +# Read lines into an array variable (bash 3 & 4 compatible) +# Usage: read_lines_into_array array_name < <(command) +# +# Example: +# read_lines_into_array my_files < <(ls *.txt) +# for file in "${my_files[@]}"; do +# echo "$file" +# done +read_lines_into_array() { + local array_name=$1 + if ((BASH_VERSINFO[0] >= 4)); then + # Bash 4+: use mapfile (faster) + mapfile -t "$array_name" + else + # Bash 3: use while loop (macOS default bash is 3.x) + local line + eval "$array_name=()" + while IFS= read -r line; do + eval "$array_name+=(\"\$line\")" + done + fi +}