From 84bb4379488e6c18fe1a2a37d02e1ed6ca170377 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Sun, 12 Apr 2026 22:12:08 -0700 Subject: [PATCH 1/7] refactor(server): extract kubernetes compute driver Signed-off-by: Drew Newberry --- Cargo.lock | 26 +- architecture/gateway.md | 31 +- crates/openshell-core/build.rs | 1 + crates/openshell-core/src/proto/mod.rs | 14 + crates/openshell-driver-kubernetes/Cargo.toml | 37 + .../openshell-driver-kubernetes/src/config.rs | 16 + .../src/driver.rs} | 757 ++++++++---------- .../openshell-driver-kubernetes/src/grpc.rs | 112 +++ crates/openshell-driver-kubernetes/src/lib.rs | 10 + .../openshell-driver-kubernetes/src/main.rs | 88 ++ crates/openshell-server/Cargo.toml | 4 +- crates/openshell-server/src/compute/mod.rs | 511 ++++++++++++ crates/openshell-server/src/grpc/mod.rs | 2 +- crates/openshell-server/src/grpc/policy.rs | 4 +- crates/openshell-server/src/grpc/sandbox.rs | 193 +---- crates/openshell-server/src/lib.rs | 69 +- crates/openshell-server/src/sandbox_watch.rs | 115 --- crates/openshell-server/src/ssh_tunnel.rs | 29 +- proto/compute_driver.proto | 100 +++ proto/datamodel.proto | 70 -- proto/openshell.proto | 26 +- proto/sandbox.proto | 85 ++ 22 files changed, 1427 insertions(+), 873 deletions(-) create mode 100644 crates/openshell-driver-kubernetes/Cargo.toml create mode 100644 crates/openshell-driver-kubernetes/src/config.rs rename crates/{openshell-server/src/sandbox/mod.rs => openshell-driver-kubernetes/src/driver.rs} (80%) create mode 100644 crates/openshell-driver-kubernetes/src/grpc.rs create mode 100644 crates/openshell-driver-kubernetes/src/lib.rs create mode 100644 crates/openshell-driver-kubernetes/src/main.rs create mode 100644 crates/openshell-server/src/compute/mod.rs create mode 100644 proto/compute_driver.proto diff --git a/Cargo.lock b/Cargo.lock index e09c2583a..516665544 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3067,6 +3067,28 @@ dependencies = [ "url", ] +[[package]] +name = 
"openshell-driver-kubernetes" +version = "0.0.0" +dependencies = [ + "clap", + "futures", + "k8s-openapi", + "kube", + "kube-runtime", + "miette", + "openshell-core", + "prost", + "prost-types", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-subscriber", +] + [[package]] name = "openshell-ocsf" version = "0.0.0" @@ -3191,11 +3213,9 @@ dependencies = [ "hyper", "hyper-rustls", "hyper-util", - "k8s-openapi", - "kube", - "kube-runtime", "miette", "openshell-core", + "openshell-driver-kubernetes", "openshell-policy", "openshell-router", "petname", diff --git a/architecture/gateway.md b/architecture/gateway.md index 72574410d..53e547235 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -70,9 +70,10 @@ graph TD | Persistence | `crates/openshell-server/src/persistence/mod.rs` | `Store` enum (SQLite/Postgres), generic object CRUD, protobuf codec | | Persistence: SQLite | `crates/openshell-server/src/persistence/sqlite.rs` | `SqliteStore` with sqlx | | Persistence: Postgres | `crates/openshell-server/src/persistence/postgres.rs` | `PostgresStore` with sqlx | -| Sandbox K8s | `crates/openshell-server/src/sandbox/mod.rs` | `SandboxClient`, CRD creation/deletion, Kubernetes watcher, phase derivation | +| Compute runtime | `crates/openshell-server/src/compute/mod.rs` | `ComputeRuntime`, gateway-owned sandbox lifecycle orchestration over a compute backend | +| Compute driver: Kubernetes | `crates/openshell-driver-kubernetes/src/driver.rs` | Kubernetes CRD create/delete, endpoint resolution, watch stream, pod template translation | | Sandbox index | `crates/openshell-server/src/sandbox_index.rs` | `SandboxIndex` -- in-memory name/pod-to-id correlation | -| Watch bus | `crates/openshell-server/src/sandbox_watch.rs` | `SandboxWatchBus`, `PlatformEventBus`, Kubernetes event tailer | +| Watch bus | `crates/openshell-server/src/sandbox_watch.rs` | `SandboxWatchBus` -- in-memory broadcast for persisted sandbox 
updates | | Tracing bus | `crates/openshell-server/src/tracing_bus.rs` | `TracingLogBus` -- captures tracing events keyed by `sandbox_id` | Proto definitions consumed by the gateway: @@ -80,9 +81,10 @@ Proto definitions consumed by the gateway: | Proto file | Package | Defines | |------------|---------|---------| | `proto/openshell.proto` | `openshell.v1` | `OpenShell` service, sandbox/provider/SSH/watch messages | +| `proto/compute_driver.proto` | `openshell.compute.v1` | Internal `ComputeDriver` service, endpoint resolution, compute watch stream envelopes | | `proto/inference.proto` | `openshell.inference.v1` | `Inference` service: `SetClusterInference`, `GetClusterInference`, `GetInferenceBundle` | -| `proto/datamodel.proto` | `openshell.datamodel.v1` | `Sandbox`, `SandboxSpec`, `SandboxStatus`, `Provider`, `SandboxPhase` | -| `proto/sandbox.proto` | `openshell.sandbox.v1` | `SandboxPolicy`, `NetworkPolicyRule`, `SettingValue`, `EffectiveSetting`, `SettingScope`, `PolicySource`, `GetSandboxSettingsRequest/Response`, `GetGatewaySettingsRequest/Response` | +| `proto/datamodel.proto` | `openshell.datamodel.v1` | `Provider` | +| `proto/sandbox.proto` | `openshell.sandbox.v1` | Shared sandbox lifecycle types (`Sandbox`, `SandboxSpec`, `SandboxStatus`, `SandboxPhase`, `PlatformEvent`) plus policy/settings messages | ## Startup Sequence @@ -94,11 +96,10 @@ The gateway boots in `main()` (`crates/openshell-server/src/main.rs`) and procee 4. **Build `Config`** -- Assembles a `openshell_core::Config` from the parsed arguments. 5. **Call `run_server()`** (`crates/openshell-server/src/lib.rs`): 1. Connect to the persistence store (`Store::connect`), which auto-detects SQLite vs Postgres from the URL prefix and runs migrations. - 2. Create `SandboxClient` (initializes a `kube::Client` from in-cluster or kubeconfig). + 2. Create `ComputeRuntime` with the in-process Kubernetes compute backend (`KubernetesComputeDriver`). 3. 
Build `ServerState` (shared via `Arc` across all handlers). 4. **Spawn background tasks**: - - `spawn_sandbox_watcher` -- watches Kubernetes Sandbox CRDs and syncs state to the store. - - `spawn_kube_event_tailer` -- watches Kubernetes Events in the sandbox namespace and publishes them to the `PlatformEventBus`. + - `ComputeRuntime::spawn_watchers` -- consumes the compute-driver watch stream, updates persisted sandbox records, and republishes platform events. 5. Create `MultiplexService`. 6. Bind `TcpListener` on `config.bind_address`. 7. Optionally create `TlsAcceptor` from cert/key files. @@ -137,7 +138,7 @@ All handlers share an `Arc` (`crates/openshell-server/src/lib.rs`): pub struct ServerState { pub config: Config, pub store: Arc, - pub sandbox_client: SandboxClient, + pub compute: ComputeRuntime, pub sandbox_index: SandboxIndex, pub sandbox_watch_bus: SandboxWatchBus, pub tracing_log_bus: TracingLogBus, @@ -148,10 +149,10 @@ pub struct ServerState { ``` - **`store`** -- persistence backend (SQLite or Postgres) for all object types. -- **`sandbox_client`** -- Kubernetes client scoped to the sandbox namespace; creates/deletes CRDs and resolves pod IPs. -- **`sandbox_index`** -- in-memory bidirectional index mapping sandbox names and agent pod names to sandbox IDs. Used by the event tailer to correlate Kubernetes events. +- **`compute`** -- gateway-owned compute orchestration. Persists sandbox lifecycle transitions, validates create requests through the compute backend, resolves exec/SSH endpoints, and consumes the backend watch stream. +- **`sandbox_index`** -- in-memory bidirectional index mapping sandbox names and agent pod names to sandbox IDs. Updated from compute-driver sandbox snapshots. - **`sandbox_watch_bus`** -- `broadcast`-based notification bus keyed by sandbox ID. Producers call `notify(&id)` when the persisted sandbox record changes; consumers in `WatchSandbox` streams receive `()` signals and re-read the record. 
-- **`tracing_log_bus`** -- captures `tracing` events that include a `sandbox_id` field and republishes them as `SandboxLogLine` messages. Maintains a per-sandbox tail buffer (default 200 entries). Also contains a nested `PlatformEventBus` for Kubernetes events. +- **`tracing_log_bus`** -- captures `tracing` events that include a `sandbox_id` field and republishes them as `SandboxLogLine` messages. Maintains a per-sandbox tail buffer (default 200 entries). Also contains a nested `PlatformEventBus` for compute-driver platform events. - **`settings_mutex`** -- serializes settings mutations (global and sandbox) to prevent read-modify-write races. Held for the duration of any setting set/delete or global policy set/delete operation. See [Gateway Settings Channel](gateway-settings.md#global-policy-lifecycle). ## Protocol Multiplexing @@ -499,15 +500,15 @@ The Helm chart template is at `deploy/helm/openshell/templates/statefulset.yaml` ### Sandbox CRD Management -`SandboxClient` (`crates/openshell-server/src/sandbox/mod.rs`) manages `agents.x-k8s.io/v1alpha1/Sandbox` CRDs. +`KubernetesComputeDriver` (`crates/openshell-driver-kubernetes/src/driver.rs`) manages `agents.x-k8s.io/v1alpha1/Sandbox` CRDs behind the gateway's compute interface. -- **Create**: Translates a `Sandbox` proto into a Kubernetes `DynamicObject` with labels (`openshell.ai/sandbox-id`, `openshell.ai/managed-by: openshell`) and a spec that includes the pod template, environment variables, and gateway-required env vars (`OPENSHELL_SANDBOX_ID`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SSH_LISTEN_ADDR`, etc.). When callers do not provide custom `volumeClaimTemplates`, the server injects a default `workspace` PVC and mounts it at `/sandbox` so the default sandbox home/workdir survives pod rescheduling. 
+- **Create**: Translates a shared `openshell.sandbox.v1.Sandbox` message into a Kubernetes `DynamicObject` with labels (`openshell.ai/sandbox-id`, `openshell.ai/managed-by: openshell`) and a spec that includes the pod template, environment variables, and gateway-required env vars (`OPENSHELL_SANDBOX_ID`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SSH_LISTEN_ADDR`, etc.). When callers do not provide custom `volumeClaimTemplates`, the driver injects a default `workspace` PVC and mounts it at `/sandbox` so the default sandbox home/workdir survives pod rescheduling. - **Delete**: Calls the Kubernetes API to delete the CRD by name. Returns `false` if already gone (404). - **Pod IP resolution**: `agent_pod_ip()` fetches the agent pod and reads `status.podIP`. ### Sandbox Watcher -`spawn_sandbox_watcher()` (`crates/openshell-server/src/sandbox/mod.rs`) runs a Kubernetes watcher on `Sandbox` CRDs and processes three event types: +The Kubernetes driver emits `WatchSandboxes` events through `proto/compute_driver.proto`. `ComputeRuntime` consumes that stream and applies the resulting snapshots to the store. - **Applied**: Extracts the sandbox ID from labels (or falls back to name prefix stripping), reads the CRD status, derives the phase, and upserts the sandbox record in the store. Notifies the watch bus. - **Deleted**: Removes the sandbox record from the store and the index. Notifies the watch bus. @@ -530,7 +531,7 @@ All other `Ready=False` reasons are treated as terminal failures (`Error` phase) ### Kubernetes Event Tailer -`spawn_kube_event_tailer()` (`crates/openshell-server/src/sandbox_watch.rs`) watches all Kubernetes `Event` objects in the sandbox namespace and correlates them to sandbox IDs using `SandboxIndex`: +The Kubernetes driver also watches namespace-scoped Kubernetes `Event` objects and correlates them to sandbox IDs before emitting them as compute-driver platform events: - Events involving `kind: Sandbox` are correlated by sandbox name. 
- Events involving `kind: Pod` are correlated by agent pod name. diff --git a/crates/openshell-core/build.rs b/crates/openshell-core/build.rs index f44cdc75f..c89a03483 100644 --- a/crates/openshell-core/build.rs +++ b/crates/openshell-core/build.rs @@ -32,6 +32,7 @@ fn main() -> Result<(), Box> { "../../proto/openshell.proto", "../../proto/datamodel.proto", "../../proto/sandbox.proto", + "../../proto/compute_driver.proto", "../../proto/inference.proto", "../../proto/test.proto", ]; diff --git a/crates/openshell-core/src/proto/mod.rs b/crates/openshell-core/src/proto/mod.rs index d8d382a31..2644cb39a 100644 --- a/crates/openshell-core/src/proto/mod.rs +++ b/crates/openshell-core/src/proto/mod.rs @@ -42,6 +42,19 @@ pub mod sandbox { } } +#[allow( + clippy::all, + clippy::pedantic, + clippy::nursery, + unused_qualifications, + rust_2018_idioms +)] +pub mod compute { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/openshell.compute.v1.rs")); + } +} + #[allow( clippy::all, clippy::pedantic, @@ -66,6 +79,7 @@ pub mod inference { } } +pub use compute::v1::*; pub use datamodel::v1::*; pub use inference::v1::*; pub use openshell::*; diff --git a/crates/openshell-driver-kubernetes/Cargo.toml b/crates/openshell-driver-kubernetes/Cargo.toml new file mode 100644 index 000000000..5e247dc77 --- /dev/null +++ b/crates/openshell-driver-kubernetes/Cargo.toml @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-driver-kubernetes" +description = "Kubernetes compute driver for OpenShell" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true + +[[bin]] +name = "openshell-driver-kubernetes" +path = "src/main.rs" + +[dependencies] +openshell-core = { path = "../openshell-core" } + +tokio = { workspace = true } +tonic = { workspace = true, features = ["transport"] } +prost = { workspace = true } +prost-types = { workspace = true } +futures = { workspace = true } +tokio-stream = { workspace = true } +kube = { workspace = true } +kube-runtime = { workspace = true } +k8s-openapi = { workspace = true } +serde_json = { workspace = true } +clap = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +thiserror = { workspace = true } +miette = { workspace = true } + +[lints] +workspace = true diff --git a/crates/openshell-driver-kubernetes/src/config.rs b/crates/openshell-driver-kubernetes/src/config.rs new file mode 100644 index 000000000..3ce98eae8 --- /dev/null +++ b/crates/openshell-driver-kubernetes/src/config.rs @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +#[derive(Debug, Clone)] +pub struct KubernetesComputeConfig { + pub namespace: String, + pub default_image: String, + pub image_pull_policy: String, + pub grpc_endpoint: String, + pub ssh_listen_addr: String, + pub ssh_port: u16, + pub ssh_handshake_secret: String, + pub ssh_handshake_skew_secs: u64, + pub client_tls_secret_name: String, + pub host_gateway_ip: String, +} diff --git a/crates/openshell-server/src/sandbox/mod.rs b/crates/openshell-driver-kubernetes/src/driver.rs similarity index 80% rename from crates/openshell-server/src/sandbox/mod.rs rename to crates/openshell-driver-kubernetes/src/driver.rs index a5d7dc071..85d781045 100644 --- a/crates/openshell-server/src/sandbox/mod.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -1,25 +1,52 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Kubernetes sandbox integration. +//! Kubernetes compute driver. 
-use crate::persistence::{ObjectId, ObjectName, ObjectType, Store}; -use futures::{StreamExt, TryStreamExt}; -use k8s_openapi::api::core::v1::{Node, Pod}; +use crate::config::KubernetesComputeConfig; +use futures::{Stream, StreamExt, TryStreamExt}; +use k8s_openapi::api::core::v1::{Event as KubeEventObj, Node, Pod}; use kube::api::{Api, ApiResource, DeleteParams, ListParams, PostParams}; use kube::core::gvk::GroupVersionKind; use kube::core::{DynamicObject, ObjectMeta}; use kube::runtime::watcher::{self, Event}; use kube::{Client, Error as KubeError}; use openshell_core::proto::{ - Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, SandboxTemplate, + GetCapabilitiesResponse, PlatformEvent, ResolveSandboxEndpointResponse, Sandbox, + SandboxCondition, SandboxEndpoint, SandboxPhase, SandboxSpec, SandboxStatus, SandboxTemplate, + WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, + WatchSandboxesSandboxEvent, }; use std::collections::BTreeMap; use std::net::IpAddr; -use std::sync::Arc; +use std::pin::Pin; use std::time::Duration; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, info, warn}; +pub type WatchStream = + Pin> + Send>>; + +#[derive(Debug, thiserror::Error)] +pub enum KubernetesDriverError { + #[error("sandbox already exists")] + AlreadyExists, + #[error("{0}")] + Precondition(String), + #[error("{0}")] + Message(String), +} + +impl KubernetesDriverError { + fn from_kube(err: KubeError) -> Self { + match err { + KubeError::Api(api) if api.code == 409 => Self::AlreadyExists, + other => Self::Message(other.to_string()), + } + } +} + /// Timeout for individual Kubernetes API calls (create, delete, get). /// This prevents gRPC handlers from blocking indefinitely when the k8s /// API server is unreachable or slow. 
@@ -73,120 +100,98 @@ const WORKSPACE_DEFAULT_STORAGE: &str = "2Gi"; const WORKSPACE_SENTINEL: &str = ".workspace-initialized"; #[derive(Clone)] -pub struct SandboxClient { +pub struct KubernetesComputeDriver { client: Client, - namespace: String, - default_image: String, - /// Kubernetes `imagePullPolicy` for sandbox containers. When empty the - /// field is omitted from the pod spec and Kubernetes applies its default. - image_pull_policy: String, - grpc_endpoint: String, - ssh_listen_addr: String, - ssh_handshake_secret: String, - ssh_handshake_skew_secs: u64, - /// When non-empty, sandbox pods get this K8s secret mounted for mTLS to the server. - client_tls_secret_name: String, - /// When non-empty, sandbox pods get `hostAliases` entries mapping - /// `host.docker.internal` and `host.openshell.internal` to this IP. - host_gateway_ip: String, + config: KubernetesComputeConfig, } -impl std::fmt::Debug for SandboxClient { +impl std::fmt::Debug for KubernetesComputeDriver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SandboxClient") - .field("namespace", &self.namespace) - .field("default_image", &self.default_image) - .field("grpc_endpoint", &self.grpc_endpoint) + f.debug_struct("KubernetesComputeDriver") + .field("namespace", &self.config.namespace) + .field("default_image", &self.config.default_image) + .field("grpc_endpoint", &self.config.grpc_endpoint) .finish() } } -impl SandboxClient { - pub async fn new( - namespace: String, - default_image: String, - image_pull_policy: String, - grpc_endpoint: String, - ssh_listen_addr: String, - ssh_handshake_secret: String, - ssh_handshake_skew_secs: u64, - client_tls_secret_name: String, - host_gateway_ip: String, - ) -> Result { - let mut config = match kube::Config::incluster() { +impl KubernetesComputeDriver { + pub async fn new(config: KubernetesComputeConfig) -> Result { + let mut kube_config = match kube::Config::incluster() { Ok(c) => c, Err(_) => kube::Config::infer() 
.await .map_err(kube::Error::InferConfig)?, }; - config.connect_timeout = Some(Duration::from_secs(10)); - config.read_timeout = Some(Duration::from_secs(30)); - config.write_timeout = Some(Duration::from_secs(30)); - let client = Client::try_from(config)?; - Ok(Self { - client, - namespace, - default_image, - image_pull_policy, - grpc_endpoint, - ssh_listen_addr, - ssh_handshake_secret, - ssh_handshake_skew_secs, - client_tls_secret_name, - host_gateway_ip, + kube_config.connect_timeout = Some(Duration::from_secs(10)); + kube_config.read_timeout = Some(Duration::from_secs(30)); + kube_config.write_timeout = Some(Duration::from_secs(30)); + let client = Client::try_from(kube_config)?; + Ok(Self { client, config }) + } + + pub async fn capabilities(&self) -> Result { + Ok(GetCapabilitiesResponse { + driver_name: "kubernetes".to_string(), + driver_version: openshell_core::VERSION.to_string(), + default_image: self.config.default_image.clone(), + supports_gpu: self.has_gpu_capacity().await.unwrap_or(false), }) } pub fn default_image(&self) -> &str { - &self.default_image + &self.config.default_image } pub fn namespace(&self) -> &str { - &self.namespace + &self.config.namespace } pub fn ssh_listen_addr(&self) -> &str { - &self.ssh_listen_addr - } - - pub fn ssh_handshake_secret(&self) -> &str { - &self.ssh_handshake_secret + &self.config.ssh_listen_addr } pub const fn ssh_handshake_skew_secs(&self) -> u64 { - self.ssh_handshake_skew_secs + self.config.ssh_handshake_skew_secs } - pub fn api(&self) -> Api { + fn api(&self) -> Api { let gvk = GroupVersionKind::gvk(SANDBOX_GROUP, SANDBOX_VERSION, SANDBOX_KIND); let resource = ApiResource::from_gvk(&gvk); - Api::namespaced_with(self.client.clone(), &self.namespace, &resource) + Api::namespaced_with(self.client.clone(), &self.config.namespace, &resource) } - pub async fn validate_gpu_support(&self) -> Result<(), tonic::Status> { + async fn has_gpu_capacity(&self) -> Result { let nodes: Api = Api::all(self.client.clone()); - 
let node_list = nodes.list(&ListParams::default()).await.map_err(|err| { - tonic::Status::internal(format!("check GPU node capacity failed: {err}")) - })?; - - let has_gpu_capacity = node_list.items.into_iter().any(|node| { + let node_list = nodes.list(&ListParams::default()).await?; + Ok(node_list.items.into_iter().any(|node| { node.status .and_then(|status| status.allocatable) .and_then(|allocatable| allocatable.get(GPU_RESOURCE_NAME).cloned()) .is_some_and(|quantity| quantity.0 != "0") - }); + })) + } - if !has_gpu_capacity { + pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { + let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); + if gpu_requested + && !self.has_gpu_capacity().await.map_err(|err| { + tonic::Status::internal(format!("check GPU node capacity failed: {err}")) + })? + { return Err(tonic::Status::failed_precondition( "GPU sandbox requested, but the active gateway has no allocatable GPUs. Please refer to documentation and use `openshell doctor` commands to inspect GPU support and gateway configuration.", )); } - Ok(()) } - pub async fn agent_pod_ip(&self, pod_name: &str) -> Result, KubeError> { - let api: Api = Api::namespaced(self.client.clone(), &self.namespace); + fn ssh_handshake_secret(&self) -> &str { + &self.config.ssh_handshake_secret + } + + async fn agent_pod_ip(&self, pod_name: &str) -> Result, KubeError> { + let api: Api = Api::namespaced(self.client.clone(), &self.config.namespace); match api.get(pod_name).await { Ok(pod) => { let ip = pod @@ -200,12 +205,12 @@ impl SandboxClient { } } - pub async fn create(&self, sandbox: &Sandbox) -> Result { + pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, sandbox_name = %name, - namespace = %self.namespace, + namespace = %self.config.namespace, "Creating sandbox in Kubernetes" ); @@ -214,34 +219,34 @@ impl SandboxClient { let 
mut obj = DynamicObject::new(name, &resource); obj.metadata = ObjectMeta { name: Some(name.to_string()), - namespace: Some(self.namespace.clone()), + namespace: Some(self.config.namespace.clone()), labels: Some(sandbox_labels(sandbox)), ..Default::default() }; obj.data = sandbox_to_k8s_spec( sandbox.spec.as_ref(), - &self.default_image, - &self.image_pull_policy, + &self.config.default_image, + &self.config.image_pull_policy, &sandbox.id, &sandbox.name, - &self.grpc_endpoint, + &self.config.grpc_endpoint, self.ssh_listen_addr(), self.ssh_handshake_secret(), self.ssh_handshake_skew_secs(), - &self.client_tls_secret_name, - &self.host_gateway_ip, + &self.config.client_tls_secret_name, + &self.config.host_gateway_ip, ); let api = self.api(); match tokio::time::timeout(KUBE_API_TIMEOUT, api.create(&PostParams::default(), &obj)).await { - Ok(Ok(result)) => { + Ok(Ok(_result)) => { info!( sandbox_id = %sandbox.id, sandbox_name = %name, "Sandbox created in Kubernetes successfully" ); - Ok(result) + Ok(()) } Ok(Err(err)) => { warn!( @@ -250,7 +255,7 @@ impl SandboxClient { error = %err, "Failed to create sandbox in Kubernetes" ); - Err(err) + Err(KubernetesDriverError::from_kube(err)) } Err(_elapsed) => { warn!( @@ -259,23 +264,18 @@ impl SandboxClient { timeout_secs = KUBE_API_TIMEOUT.as_secs(), "Timed out creating sandbox in Kubernetes" ); - Err(KubeError::Api(kube::core::ErrorResponse { - status: "Failure".to_string(), - message: format!( - "timed out after {}s waiting for Kubernetes API", - KUBE_API_TIMEOUT.as_secs() - ), - reason: "Timeout".to_string(), - code: 504, - })) + Err(KubernetesDriverError::Message(format!( + "timed out after {}s waiting for Kubernetes API", + KUBE_API_TIMEOUT.as_secs() + ))) } } } - pub async fn delete(&self, name: &str) -> Result { + pub async fn delete_sandbox(&self, name: &str) -> Result { info!( sandbox_name = %name, - namespace = %self.namespace, + namespace = %self.config.namespace, "Deleting sandbox from Kubernetes" ); @@ -297,7 
+297,7 @@ impl SandboxClient { error = %err, "Failed to delete sandbox from Kubernetes" ); - Err(err) + Err(err.to_string()) } Err(_elapsed) => { warn!( @@ -305,364 +305,302 @@ impl SandboxClient { timeout_secs = KUBE_API_TIMEOUT.as_secs(), "Timed out deleting sandbox from Kubernetes" ); - Err(KubeError::Api(kube::core::ErrorResponse { - status: "Failure".to_string(), - message: format!( - "timed out after {}s waiting for Kubernetes API", - KUBE_API_TIMEOUT.as_secs() - ), - reason: "Timeout".to_string(), - code: 504, - })) + Err(format!( + "timed out after {}s waiting for Kubernetes API", + KUBE_API_TIMEOUT.as_secs() + )) } } } -} - -impl ObjectType for Sandbox { - fn object_type() -> &'static str { - "sandbox" - } -} + pub async fn resolve_sandbox_endpoint( + &self, + sandbox: &Sandbox, + ) -> Result { + if let Some(status) = sandbox.status.as_ref() + && !status.agent_pod.is_empty() + { + match self.agent_pod_ip(&status.agent_pod).await { + Ok(Some(ip)) => { + return Ok(ResolveSandboxEndpointResponse { + endpoint: Some(SandboxEndpoint { + target: Some(openshell_core::proto::sandbox_endpoint::Target::Ip( + ip.to_string(), + )), + port: u32::from(self.config.ssh_port), + }), + }); + } + Ok(None) => { + return Err("sandbox agent pod IP is not available".to_string()); + } + Err(err) => { + return Err(format!("failed to resolve agent pod IP: {err}")); + } + } + } -impl ObjectId for Sandbox { - fn object_id(&self) -> &str { - &self.id - } -} + if sandbox.name.is_empty() { + return Err("sandbox has no name".to_string()); + } -impl ObjectName for Sandbox { - fn object_name(&self) -> &str { - &self.name + Ok(ResolveSandboxEndpointResponse { + endpoint: Some(SandboxEndpoint { + target: Some(openshell_core::proto::sandbox_endpoint::Target::Host( + format!( + "{}.{}.svc.cluster.local", + sandbox.name, self.config.namespace + ), + )), + port: u32::from(self.config.ssh_port), + }), + }) } -} -pub fn spawn_sandbox_watcher( - store: Arc, - client: SandboxClient, - index: 
crate::sandbox_index::SandboxIndex, - watch_bus: crate::sandbox_watch::SandboxWatchBus, - tracing_log_bus: crate::tracing_bus::TracingLogBus, -) { - let namespace = client.namespace().to_string(); - info!(namespace = %namespace, "Starting sandbox watcher"); - - tokio::spawn(async move { - let api = client.api(); - let mut stream = watcher::watcher(api, watcher::Config::default()).boxed(); - - loop { - match stream.try_next().await { - Ok(Some(event)) => match event { - Event::Applied(obj) => { - let obj_name = obj.metadata.name.clone().unwrap_or_default(); - debug!(sandbox_name = %obj_name, "Received Applied event from Kubernetes"); - if let Err(err) = - handle_applied(&store, &client, &index, &watch_bus, obj).await - { - warn!(sandbox_name = %obj_name, error = %err, "Failed to apply sandbox update"); + pub async fn watch_sandboxes(&self) -> Result { + let namespace = self.config.namespace.clone(); + let sandbox_api = self.api(); + let event_api: Api = Api::namespaced(self.client.clone(), &namespace); + let mut sandbox_stream = watcher::watcher(sandbox_api, watcher::Config::default()).boxed(); + let mut event_stream = watcher::watcher(event_api, watcher::Config::default()).boxed(); + let (tx, rx) = mpsc::channel(256); + + tokio::spawn(async move { + let mut sandbox_name_to_id = std::collections::HashMap::::new(); + let mut agent_pod_to_id = std::collections::HashMap::::new(); + + loop { + tokio::select! 
{ + result = sandbox_stream.try_next() => match result { + Ok(Some(Event::Applied(obj))) => { + match sandbox_from_object(&namespace, obj) { + Ok(sandbox) => { + update_indexes(&mut sandbox_name_to_id, &mut agent_pod_to_id, &sandbox); + let event = WatchSandboxesEvent { + payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::Sandbox( + WatchSandboxesSandboxEvent { sandbox: Some(sandbox) } + )), + }; + if tx.send(Ok(event)).await.is_err() { + break; + } + } + Err(err) => { + if tx.send(Err(KubernetesDriverError::Message(err))).await.is_err() { + break; + } + } + } } - } - Event::Deleted(obj) => { - let obj_name = obj.metadata.name.clone().unwrap_or_default(); - debug!(sandbox_name = %obj_name, "Received Deleted event from Kubernetes"); - if let Err(err) = - handle_deleted(&store, &index, &watch_bus, &tracing_log_bus, obj).await - { - warn!(sandbox_name = %obj_name, error = %err, "Failed to delete sandbox record"); + Ok(Some(Event::Deleted(obj))) => { + match sandbox_id_from_object(&obj) { + Ok(sandbox_id) => { + remove_indexes(&mut sandbox_name_to_id, &mut agent_pod_to_id, &sandbox_id); + let event = WatchSandboxesEvent { + payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::Deleted( + WatchSandboxesDeletedEvent { sandbox_id } + )), + }; + if tx.send(Ok(event)).await.is_err() { + break; + } + } + Err(err) => { + if tx.send(Err(KubernetesDriverError::Message(err))).await.is_err() { + break; + } + } + } } - } - Event::Restarted(objs) => { - info!( - count = objs.len(), - "Sandbox watcher restarted, re-syncing sandboxes" - ); - for obj in objs { - let obj_name = obj.metadata.name.clone().unwrap_or_default(); - if let Err(err) = - handle_applied(&store, &client, &index, &watch_bus, obj).await - { - warn!(sandbox_name = %obj_name, error = %err, "Failed to apply sandbox update during resync"); + Ok(Some(Event::Restarted(objs))) => { + for obj in objs { + match sandbox_from_object(&namespace, obj) { + Ok(sandbox) => { + update_indexes(&mut 
sandbox_name_to_id, &mut agent_pod_to_id, &sandbox); + let event = WatchSandboxesEvent { + payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::Sandbox( + WatchSandboxesSandboxEvent { sandbox: Some(sandbox) } + )), + }; + if tx.send(Ok(event)).await.is_err() { + return; + } + } + Err(err) => { + if tx.send(Err(KubernetesDriverError::Message(err))).await.is_err() { + return; + } + } + } + } + } + Ok(None) => { + let _ = tx.send(Err(KubernetesDriverError::Message( + "sandbox watcher stream ended unexpectedly".to_string() + ))).await; + break; + } + Err(err) => { + let _ = tx.send(Err(KubernetesDriverError::Message(err.to_string()))).await; + break; + } + }, + result = event_stream.try_next() => match result { + Ok(Some(Event::Applied(obj))) => { + if let Some((sandbox_id, event)) = map_kube_event_to_platform( + &sandbox_name_to_id, + &agent_pod_to_id, + &obj, + ) { + let event = WatchSandboxesEvent { + payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::PlatformEvent( + WatchSandboxesPlatformEvent { sandbox_id, event: Some(event) } + )), + }; + if tx.send(Ok(event)).await.is_err() { + break; + } } } + Ok(Some(Event::Deleted(_))) => {} + Ok(Some(Event::Restarted(_))) => { + debug!(namespace = %namespace, "Kubernetes event watcher restarted"); + } + Ok(None) => { + let _ = tx.send(Err(KubernetesDriverError::Message( + "kubernetes event watcher stream ended".to_string() + ))).await; + break; + } + Err(err) => { + let _ = tx.send(Err(KubernetesDriverError::Message(err.to_string()))).await; + break; + } } - }, - Ok(None) => { - warn!("Sandbox watcher stream ended unexpectedly"); - break; - } - Err(err) => { - warn!(error = %err, "Sandbox watcher error"); } } - } - }); -} + }); -/// Interval between store-vs-k8s reconciliation sweeps. 
-const RECONCILE_INTERVAL: Duration = Duration::from_secs(60); - -/// How long a sandbox can stay in `Provisioning` in the store without a -/// corresponding Kubernetes resource before it is considered orphaned and -/// removed. -const ORPHAN_GRACE_PERIOD: Duration = Duration::from_secs(300); - -/// Periodically reconcile the store against Kubernetes to clean up orphaned -/// sandbox records. A record is orphaned when it exists in the store but -/// has no corresponding Kubernetes `Sandbox` CR — typically because the -/// k8s create timed out or the gRPC handler was cancelled. -pub fn spawn_store_reconciler( - store: Arc, - client: SandboxClient, - index: crate::sandbox_index::SandboxIndex, - watch_bus: crate::sandbox_watch::SandboxWatchBus, - tracing_log_bus: crate::tracing_bus::TracingLogBus, -) { - tokio::spawn(async move { - // Wait for initial startup to settle before running the first sweep. - tokio::time::sleep(RECONCILE_INTERVAL).await; - - loop { - if let Err(e) = - reconcile_orphaned_sandboxes(&store, &client, &index, &watch_bus, &tracing_log_bus) - .await - { - warn!(error = %e, "Store reconciliation sweep failed"); - } - tokio::time::sleep(RECONCILE_INTERVAL).await; - } - }); + Ok(Box::pin(ReceiverStream::new(rx))) + } } -/// Single reconciliation sweep: list all sandboxes in the store that are -/// still `Provisioning`, check if they have a corresponding k8s resource, -/// and remove any that have been orphaned beyond the grace period. 
-async fn reconcile_orphaned_sandboxes( - store: &Store, - client: &SandboxClient, - index: &crate::sandbox_index::SandboxIndex, - watch_bus: &crate::sandbox_watch::SandboxWatchBus, - tracing_log_bus: &crate::tracing_bus::TracingLogBus, -) -> Result<(), String> { - let records = store - .list(Sandbox::object_type(), 500, 0) - .await - .map_err(|e| e.to_string())?; - - let api = client.api(); - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as i64; - - for record in records { - let sandbox: Sandbox = match prost::Message::decode(record.payload.as_slice()) { - Ok(s) => s, - Err(e) => { - warn!(error = %e, "Failed to decode sandbox record during reconciliation"); - continue; - } - }; - - // Only check sandboxes that are still provisioning — these are the - // ones at risk of being orphaned. - if sandbox.phase != SandboxPhase::Provisioning as i32 { - continue; - } +fn sandbox_labels(sandbox: &Sandbox) -> BTreeMap { + let mut labels = BTreeMap::new(); + labels.insert(SANDBOX_ID_LABEL.to_string(), sandbox.id.clone()); + labels.insert( + SANDBOX_MANAGED_LABEL.to_string(), + SANDBOX_MANAGED_VALUE.to_string(), + ); + labels +} - // Check how old this record is using the store's created_at_ms. - let age_ms = now_ms.saturating_sub(record.created_at_ms); - if age_ms < ORPHAN_GRACE_PERIOD.as_millis() as i64 { - continue; - } +fn sandbox_id_from_object(obj: &DynamicObject) -> Result { + if let Some(labels) = obj.metadata.labels.as_ref() + && let Some(id) = labels.get(SANDBOX_ID_LABEL) + { + return Ok(id.clone()); + } - // Check if a corresponding k8s resource exists. - match tokio::time::timeout(KUBE_API_TIMEOUT, api.get(&sandbox.name)).await { - Ok(Ok(_)) => { - // k8s resource exists — not orphaned. - continue; - } - Ok(Err(KubeError::Api(err))) if err.code == 404 => { - // k8s resource does not exist — orphaned store entry. 
- info!( - sandbox_id = %sandbox.id, - sandbox_name = %sandbox.name, - age_secs = age_ms / 1000, - "Removing orphaned sandbox from store (no corresponding k8s resource)" - ); - if let Err(e) = store.delete(Sandbox::object_type(), &sandbox.id).await { - warn!(sandbox_id = %sandbox.id, error = %e, "Failed to remove orphaned sandbox"); - } - index.remove_sandbox(&sandbox.id); - watch_bus.notify(&sandbox.id); - tracing_log_bus.remove(&sandbox.id); - tracing_log_bus.platform_event_bus.remove(&sandbox.id); - watch_bus.remove(&sandbox.id); - } - Ok(Err(err)) => { - // k8s API error — skip this record and try again next cycle. - debug!( - sandbox_id = %sandbox.id, - error = %err, - "Skipping orphan check due to k8s API error" - ); - } - Err(_elapsed) => { - debug!( - sandbox_id = %sandbox.id, - "Skipping orphan check due to k8s API timeout" - ); - } - } + let name = obj.metadata.name.clone().unwrap_or_default(); + if let Some(id) = name.strip_prefix("sandbox-") { + return Ok(id.to_string()); } - Ok(()) + Err("sandbox id not found on object".to_string()) } -async fn handle_applied( - store: &Store, - client: &SandboxClient, - index: &crate::sandbox_index::SandboxIndex, - watch_bus: &crate::sandbox_watch::SandboxWatchBus, - obj: DynamicObject, -) -> Result<(), String> { +fn sandbox_from_object(namespace: &str, obj: DynamicObject) -> Result { let id = sandbox_id_from_object(&obj)?; let name = obj.metadata.name.clone().unwrap_or_default(); let namespace = obj .metadata .namespace .clone() - .unwrap_or_else(|| client.namespace().to_string()); + .unwrap_or_else(|| namespace.to_string()); let deletion_timestamp = obj.metadata.deletion_timestamp.is_some(); - - let existing = store - .get_message::(&id) - .await - .map_err(|e| e.to_string())?; - - let mut status = status_from_object(&obj); - rewrite_user_facing_conditions( - &mut status, - existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), - ); + let status = status_from_object(&obj); let phase = derive_phase(&status, 
deletion_timestamp); - // If the record doesn't exist yet, the `create_sandbox` handler may - // still be in-flight (it creates the k8s resource first, then writes - // to the store). Build a minimal placeholder but never overwrite an - // existing record's `spec` — only the `create_sandbox` handler sets it. - let mut sandbox = existing.unwrap_or_else(|| Sandbox { - id: id.clone(), - name: name.clone(), + Ok(Sandbox { + id, + name, namespace, spec: None, - status: None, - phase: SandboxPhase::Unknown as i32, + status, + phase: phase as i32, ..Default::default() - }); + }) +} - // Log phase transitions - let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - if old_phase != phase { - info!( - sandbox_id = %id, - sandbox_name = %name, - old_phase = ?old_phase, - new_phase = ?phase, - "Sandbox phase changed" - ); +fn update_indexes( + sandbox_name_to_id: &mut std::collections::HashMap, + agent_pod_to_id: &mut std::collections::HashMap, + sandbox: &Sandbox, +) { + if !sandbox.name.is_empty() { + sandbox_name_to_id.insert(sandbox.name.clone(), sandbox.id.clone()); } - - // Log error conditions with details - if phase == SandboxPhase::Error - && let Some(ref status) = status + if let Some(status) = sandbox.status.as_ref() + && !status.agent_pod.is_empty() { - for condition in &status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_condition(condition) - { - warn!( - sandbox_id = %id, - sandbox_name = %name, - reason = %condition.reason, - message = %condition.message, - "Sandbox failed to become ready" - ); - } - } - } - - // Log when sandbox becomes ready - if phase == SandboxPhase::Ready && old_phase != SandboxPhase::Ready { - info!( - sandbox_id = %id, - sandbox_name = %name, - "Sandbox is now ready" - ); + agent_pod_to_id.insert(status.agent_pod.clone(), sandbox.id.clone()); } - - sandbox.status = status; - sandbox.phase = phase as i32; - - 
index.update_from_sandbox(&sandbox); - - store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; - - watch_bus.notify(&id); - Ok(()) } -async fn handle_deleted( - store: &Store, - index: &crate::sandbox_index::SandboxIndex, - watch_bus: &crate::sandbox_watch::SandboxWatchBus, - tracing_log_bus: &crate::tracing_bus::TracingLogBus, - obj: DynamicObject, -) -> Result<(), String> { - let id = sandbox_id_from_object(&obj)?; - let deleted = store - .delete(Sandbox::object_type(), &id) - .await - .map_err(|e| e.to_string())?; - debug!(sandbox_id = %id, deleted, "Deleted sandbox record"); - index.remove_sandbox(&id); - watch_bus.notify(&id); - - // Clean up bus entries to prevent unbounded memory growth. - tracing_log_bus.remove(&id); - tracing_log_bus.platform_event_bus.remove(&id); - watch_bus.remove(&id); - - Ok(()) +fn remove_indexes( + sandbox_name_to_id: &mut std::collections::HashMap, + agent_pod_to_id: &mut std::collections::HashMap, + sandbox_id: &str, +) { + sandbox_name_to_id.retain(|_, value| value != sandbox_id); + agent_pod_to_id.retain(|_, value| value != sandbox_id); } -fn sandbox_labels(sandbox: &Sandbox) -> BTreeMap { - let mut labels = BTreeMap::new(); - labels.insert(SANDBOX_ID_LABEL.to_string(), sandbox.id.clone()); - labels.insert( - SANDBOX_MANAGED_LABEL.to_string(), - SANDBOX_MANAGED_VALUE.to_string(), - ); - labels -} +fn map_kube_event_to_platform( + sandbox_name_to_id: &std::collections::HashMap, + agent_pod_to_id: &std::collections::HashMap, + obj: &KubeEventObj, +) -> Option<(String, PlatformEvent)> { + let involved = obj.involved_object.clone(); + let involved_kind = involved.kind.unwrap_or_default(); + let involved_name = involved.name.unwrap_or_default(); + + let sandbox_id = match involved_kind.as_str() { + "Sandbox" => sandbox_name_to_id.get(&involved_name).cloned()?, + "Pod" => sandbox_name_to_id + .get(&involved_name) + .cloned() + .or_else(|| agent_pod_to_id.get(&involved_name).cloned())?, + _ => return None, + }; -fn 
sandbox_id_from_object(obj: &DynamicObject) -> Result { - if let Some(labels) = obj.metadata.labels.as_ref() - && let Some(id) = labels.get(SANDBOX_ID_LABEL) - { - return Ok(id.clone()); - } + let ts = obj + .last_timestamp + .as_ref() + .or(obj.first_timestamp.as_ref()) + .map_or(0, |t| t.0.timestamp_millis()); - let name = obj.metadata.name.clone().unwrap_or_default(); - if let Some(id) = name.strip_prefix("sandbox-") { - return Ok(id.to_string()); + let mut metadata = std::collections::HashMap::new(); + metadata.insert("involved_kind".to_string(), involved_kind); + metadata.insert("involved_name".to_string(), involved_name); + if let Some(ns) = &obj.involved_object.namespace { + metadata.insert("namespace".to_string(), ns.clone()); + } + if let Some(count) = obj.count { + metadata.insert("count".to_string(), count.to_string()); } - Err("sandbox id not found on object".to_string()) + Some(( + sandbox_id, + PlatformEvent { + timestamp_ms: ts, + source: "kubernetes".to_string(), + r#type: obj.type_.clone().unwrap_or_default(), + reason: obj.reason.clone().unwrap_or_default(), + message: obj.message.clone().unwrap_or_default(), + metadata, + }, + )) } /// Path where the supervisor binary is mounted inside the agent container. @@ -1364,6 +1302,7 @@ fn condition_from_value(value: &serde_json::Value) -> Option { }) } +#[cfg_attr(not(test), allow(dead_code))] fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); if !gpu_requested { diff --git a/crates/openshell-driver-kubernetes/src/grpc.rs b/crates/openshell-driver-kubernetes/src/grpc.rs new file mode 100644 index 000000000..15589a9e3 --- /dev/null +++ b/crates/openshell-driver-kubernetes/src/grpc.rs @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use futures::{Stream, StreamExt}; +use openshell_core::proto::compute_driver_server::ComputeDriver; +use openshell_core::proto::{ + ComputeCreateSandboxRequest, ComputeCreateSandboxResponse, ComputeDeleteSandboxRequest, + ComputeDeleteSandboxResponse, GetCapabilitiesRequest, GetCapabilitiesResponse, + ResolveSandboxEndpointRequest, ResolveSandboxEndpointResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesEvent, WatchSandboxesRequest, +}; +use std::pin::Pin; +use tonic::{Request, Response, Status}; + +use crate::KubernetesComputeDriver; + +#[derive(Debug, Clone)] +pub struct ComputeDriverService { + driver: KubernetesComputeDriver, +} + +impl ComputeDriverService { + #[must_use] + pub fn new(driver: KubernetesComputeDriver) -> Self { + Self { driver } + } +} + +#[tonic::async_trait] +impl ComputeDriver for ComputeDriverService { + async fn get_capabilities( + &self, + _request: Request, + ) -> Result, Status> { + self.driver + .capabilities() + .await + .map(Response::new) + .map_err(Status::internal) + } + + async fn validate_sandbox_create( + &self, + request: Request, + ) -> Result, Status> { + let sandbox = request + .into_inner() + .sandbox + .ok_or_else(|| Status::invalid_argument("sandbox is required"))?; + self.driver.validate_sandbox_create(&sandbox).await?; + Ok(Response::new(ValidateSandboxCreateResponse {})) + } + + async fn create_sandbox( + &self, + request: Request, + ) -> Result, Status> { + let sandbox = request + .into_inner() + .sandbox + .ok_or_else(|| Status::invalid_argument("sandbox is required"))?; + self.driver + .create_sandbox(&sandbox) + .await + .map_err(|err| Status::internal(err.to_string()))?; + Ok(Response::new(ComputeCreateSandboxResponse {})) + } + + async fn delete_sandbox( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let deleted = self + .driver + .delete_sandbox(&request.sandbox_name) + .await + 
.map_err(Status::internal)?; + Ok(Response::new(ComputeDeleteSandboxResponse { deleted })) + } + + async fn resolve_sandbox_endpoint( + &self, + request: Request, + ) -> Result, Status> { + let sandbox = request + .into_inner() + .sandbox + .ok_or_else(|| Status::invalid_argument("sandbox is required"))?; + self.driver + .resolve_sandbox_endpoint(&sandbox) + .await + .map(Response::new) + .map_err(Status::internal) + } + + type WatchSandboxesStream = + Pin> + Send + 'static>>; + + async fn watch_sandboxes( + &self, + _request: Request, + ) -> Result, Status> { + let stream = self + .driver + .watch_sandboxes() + .await + .map_err(Status::internal)?; + let stream = stream.map(|item| item.map_err(|err| Status::internal(err.to_string()))); + Ok(Response::new(Box::pin(stream))) + } +} diff --git a/crates/openshell-driver-kubernetes/src/lib.rs b/crates/openshell-driver-kubernetes/src/lib.rs new file mode 100644 index 000000000..54149fe83 --- /dev/null +++ b/crates/openshell-driver-kubernetes/src/lib.rs @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub mod config; +pub mod driver; +pub mod grpc; + +pub use config::KubernetesComputeConfig; +pub use driver::{KubernetesComputeDriver, KubernetesDriverError}; +pub use grpc::ComputeDriverService; diff --git a/crates/openshell-driver-kubernetes/src/main.rs b/crates/openshell-driver-kubernetes/src/main.rs new file mode 100644 index 000000000..32160a6dd --- /dev/null +++ b/crates/openshell-driver-kubernetes/src/main.rs @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use clap::Parser; +use miette::{IntoDiagnostic, Result}; +use std::net::SocketAddr; +use tracing::info; +use tracing_subscriber::EnvFilter; + +use openshell_core::VERSION; +use openshell_core::proto::compute_driver_server::ComputeDriverServer; +use openshell_driver_kubernetes::{ + ComputeDriverService, KubernetesComputeConfig, KubernetesComputeDriver, +}; + +#[derive(Parser, Debug)] +#[command(name = "openshell-driver-kubernetes")] +#[command(version = VERSION)] +struct Args { + #[arg( + long, + env = "OPENSHELL_COMPUTE_DRIVER_BIND", + default_value = "127.0.0.1:50061" + )] + bind_address: SocketAddr, + + #[arg(long, env = "OPENSHELL_LOG_LEVEL", default_value = "info")] + log_level: String, + + #[arg(long, env = "OPENSHELL_SANDBOX_NAMESPACE", default_value = "default")] + sandbox_namespace: String, + + #[arg(long, env = "OPENSHELL_SANDBOX_IMAGE")] + sandbox_image: Option, + + #[arg(long, env = "OPENSHELL_SANDBOX_IMAGE_PULL_POLICY")] + sandbox_image_pull_policy: Option, + + #[arg(long, env = "OPENSHELL_GRPC_ENDPOINT")] + grpc_endpoint: Option, + + #[arg(long, env = "OPENSHELL_SANDBOX_SSH_PORT", default_value_t = 2222)] + sandbox_ssh_port: u16, + + #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SECRET")] + ssh_handshake_secret: String, + + #[arg(long, env = "OPENSHELL_SSH_HANDSHAKE_SKEW_SECS", default_value_t = 300)] + ssh_handshake_skew_secs: u64, + + #[arg(long, env = "OPENSHELL_CLIENT_TLS_SECRET_NAME")] + client_tls_secret_name: Option, + + #[arg(long, env = "OPENSHELL_HOST_GATEWAY_IP")] + host_gateway_ip: Option, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)), + ) + .init(); + + let driver = KubernetesComputeDriver::new(KubernetesComputeConfig { + namespace: args.sandbox_namespace, + default_image: args.sandbox_image.unwrap_or_default(), + 
image_pull_policy: args.sandbox_image_pull_policy.unwrap_or_default(), + grpc_endpoint: args.grpc_endpoint.unwrap_or_default(), + ssh_listen_addr: format!("0.0.0.0:{}", args.sandbox_ssh_port), + ssh_port: args.sandbox_ssh_port, + ssh_handshake_secret: args.ssh_handshake_secret, + ssh_handshake_skew_secs: args.ssh_handshake_skew_secs, + client_tls_secret_name: args.client_tls_secret_name.unwrap_or_default(), + host_gateway_ip: args.host_gateway_ip.unwrap_or_default(), + }) + .await + .into_diagnostic()?; + + info!(address = %args.bind_address, "Starting Kubernetes compute driver"); + tonic::transport::Server::builder() + .add_service(ComputeDriverServer::new(ComputeDriverService::new(driver))) + .serve(args.bind_address) + .await + .into_diagnostic() +} diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 0308f30ff..678eaf2de 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -16,6 +16,7 @@ path = "src/main.rs" [dependencies] openshell-core = { path = "../openshell-core" } +openshell-driver-kubernetes = { path = "../openshell-driver-kubernetes" } openshell-policy = { path = "../openshell-policy" } openshell-router = { path = "../openshell-router" } @@ -63,9 +64,6 @@ serde_json = { workspace = true } tokio-stream = { workspace = true } sqlx = { workspace = true } reqwest = { workspace = true } -kube = { workspace = true } -kube-runtime = { workspace = true } -k8s-openapi = { workspace = true } uuid = { workspace = true } hmac = "0.12" sha2 = "0.10" diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs new file mode 100644 index 000000000..da92862a5 --- /dev/null +++ b/crates/openshell-server/src/compute/mod.rs @@ -0,0 +1,511 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Gateway-owned compute orchestration over a pluggable compute backend. + +use crate::grpc::policy::{SANDBOX_SETTINGS_OBJECT_TYPE, sandbox_settings_id}; +use crate::persistence::{ObjectId, ObjectName, ObjectType, Store}; +use crate::sandbox_index::SandboxIndex; +use crate::sandbox_watch::SandboxWatchBus; +use crate::tracing_bus::TracingLogBus; +use futures::{Stream, StreamExt}; +use openshell_core::proto::{ + ResolveSandboxEndpointResponse, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, + SandboxStatus, SshSession, WatchSandboxesEvent, +}; +use openshell_driver_kubernetes::{ + KubernetesComputeConfig, KubernetesComputeDriver, KubernetesDriverError, +}; +use prost::Message; +use std::fmt; +use std::net::IpAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; +use tonic::Status; +use tracing::{info, warn}; + +type ComputeWatchStream = + Pin> + Send>>; + +#[derive(Debug, thiserror::Error)] +pub enum ComputeError { + #[error("sandbox already exists")] + AlreadyExists, + #[error("{0}")] + Precondition(String), + #[error("{0}")] + Message(String), +} + +impl From for ComputeError { + fn from(value: KubernetesDriverError) -> Self { + match value { + KubernetesDriverError::AlreadyExists => Self::AlreadyExists, + KubernetesDriverError::Precondition(message) => Self::Precondition(message), + KubernetesDriverError::Message(message) => Self::Message(message), + } + } +} + +pub enum ResolvedEndpoint { + Ip(IpAddr, u16), + Host(String, u16), +} + +#[tonic::async_trait] +pub trait ComputeBackend: fmt::Debug + Send + Sync { + fn default_image(&self) -> &str; + async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status>; + async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), ComputeError>; + async fn delete_sandbox(&self, sandbox_name: &str) -> Result; + async fn resolve_sandbox_endpoint( + &self, + sandbox: &Sandbox, + ) -> Result; + async fn watch_sandboxes(&self) -> Result; +} + +#[derive(Debug)] +pub struct 
InProcessKubernetesBackend { + driver: KubernetesComputeDriver, +} + +impl InProcessKubernetesBackend { + #[must_use] + pub fn new(driver: KubernetesComputeDriver) -> Self { + Self { driver } + } +} + +#[tonic::async_trait] +impl ComputeBackend for InProcessKubernetesBackend { + fn default_image(&self) -> &str { + self.driver.default_image() + } + + async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { + self.driver.validate_sandbox_create(sandbox).await + } + + async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), ComputeError> { + self.driver + .create_sandbox(sandbox) + .await + .map_err(Into::into) + } + + async fn delete_sandbox(&self, sandbox_name: &str) -> Result { + self.driver + .delete_sandbox(sandbox_name) + .await + .map_err(ComputeError::Message) + } + + async fn resolve_sandbox_endpoint( + &self, + sandbox: &Sandbox, + ) -> Result { + let response = self + .driver + .resolve_sandbox_endpoint(sandbox) + .await + .map_err(ComputeError::Message)?; + resolved_endpoint_from_response(&response) + } + + async fn watch_sandboxes(&self) -> Result { + let stream = self + .driver + .watch_sandboxes() + .await + .map_err(ComputeError::Message)?; + Ok(Box::pin(stream.map(|item| item.map_err(Into::into)))) + } +} + +#[derive(Clone)] +pub struct ComputeRuntime { + backend: Arc, + store: Arc, + sandbox_index: SandboxIndex, + sandbox_watch_bus: SandboxWatchBus, + tracing_log_bus: TracingLogBus, +} + +impl fmt::Debug for ComputeRuntime { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ComputeRuntime").finish_non_exhaustive() + } +} + +impl ComputeRuntime { + pub async fn new_kubernetes( + config: KubernetesComputeConfig, + store: Arc, + sandbox_index: SandboxIndex, + sandbox_watch_bus: SandboxWatchBus, + tracing_log_bus: TracingLogBus, + ) -> Result { + let driver = KubernetesComputeDriver::new(config) + .await + .map_err(|err| ComputeError::Message(err.to_string()))?; + Ok(Self { + backend: 
Arc::new(InProcessKubernetesBackend::new(driver)), + store, + sandbox_index, + sandbox_watch_bus, + tracing_log_bus, + }) + } + + #[must_use] + pub fn default_image(&self) -> &str { + self.backend.default_image() + } + + pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { + self.backend.validate_sandbox_create(sandbox).await + } + + pub async fn create_sandbox(&self, sandbox: Sandbox) -> Result { + let existing = self + .store + .get_message_by_name::(&sandbox.name) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; + if existing.is_some() { + return Err(Status::already_exists(format!( + "sandbox '{}' already exists", + sandbox.name + ))); + } + + self.sandbox_index.update_from_sandbox(&sandbox); + self.store + .put_message(&sandbox) + .await + .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + + match self.backend.create_sandbox(&sandbox).await { + Ok(()) => { + self.sandbox_watch_bus.notify(&sandbox.id); + Ok(sandbox) + } + Err(ComputeError::AlreadyExists) => { + let _ = self.store.delete(Sandbox::object_type(), &sandbox.id).await; + self.sandbox_index.remove_sandbox(&sandbox.id); + Err(Status::already_exists("sandbox already exists")) + } + Err(ComputeError::Precondition(message)) => { + let _ = self.store.delete(Sandbox::object_type(), &sandbox.id).await; + self.sandbox_index.remove_sandbox(&sandbox.id); + Err(Status::failed_precondition(message)) + } + Err(err) => { + let _ = self.store.delete(Sandbox::object_type(), &sandbox.id).await; + self.sandbox_index.remove_sandbox(&sandbox.id); + Err(Status::internal(format!("create sandbox failed: {err}"))) + } + } + } + + pub async fn delete_sandbox(&self, name: &str) -> Result { + let sandbox = self + .store + .get_message_by_name::(name) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; + + let Some(mut sandbox) = sandbox else { + return Err(Status::not_found("sandbox not found")); + }; + + let 
id = sandbox.id.clone(); + sandbox.phase = SandboxPhase::Deleting as i32; + self.store + .put_message(&sandbox) + .await + .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + self.sandbox_index.update_from_sandbox(&sandbox); + self.sandbox_watch_bus.notify(&id); + + if let Ok(records) = self.store.list(SshSession::object_type(), 1000, 0).await { + for record in records { + if let Ok(session) = SshSession::decode(record.payload.as_slice()) + && session.sandbox_id == id + && let Err(e) = self + .store + .delete(SshSession::object_type(), &session.id) + .await + { + warn!( + session_id = %session.id, + error = %e, + "Failed to delete SSH session during sandbox cleanup" + ); + } + } + } + + if let Err(e) = self + .store + .delete(SANDBOX_SETTINGS_OBJECT_TYPE, &sandbox_settings_id(&id)) + .await + { + warn!( + sandbox_id = %id, + error = %e, + "Failed to delete sandbox settings during cleanup" + ); + } + + let deleted = self + .backend + .delete_sandbox(&sandbox.name) + .await + .map_err(|err| Status::internal(format!("delete sandbox failed: {err}")))?; + + if !deleted && let Err(e) = self.store.delete(Sandbox::object_type(), &id).await { + warn!(sandbox_id = %id, error = %e, "Failed to clean up store after delete"); + } + + self.cleanup_sandbox_state(&id); + Ok(deleted) + } + + pub async fn resolve_sandbox_endpoint( + &self, + sandbox: &Sandbox, + ) -> Result { + self.backend + .resolve_sandbox_endpoint(sandbox) + .await + .map_err(|err| match err { + ComputeError::Precondition(message) => Status::failed_precondition(message), + other => Status::internal(other.to_string()), + }) + } + + pub fn spawn_watchers(&self) { + let runtime = Arc::new(self.clone()); + tokio::spawn(async move { + runtime.watch_loop().await; + }); + } + + async fn watch_loop(self: Arc) { + loop { + let mut stream = match self.backend.watch_sandboxes().await { + Ok(stream) => stream, + Err(err) => { + warn!(error = %err, "Compute driver watch stream failed to start"); + 
tokio::time::sleep(Duration::from_secs(2)).await; + continue; + } + }; + + let mut restart = false; + while let Some(item) = stream.next().await { + match item { + Ok(event) => { + if let Err(err) = self.apply_watch_event(event).await { + warn!(error = %err, "Failed to apply compute driver event"); + } + } + Err(err) => { + warn!(error = %err, "Compute driver watch stream errored"); + restart = true; + break; + } + } + } + + if !restart { + warn!("Compute driver watch stream ended unexpectedly"); + } + tokio::time::sleep(Duration::from_secs(2)).await; + } + } + + async fn apply_watch_event(&self, event: WatchSandboxesEvent) -> Result<(), String> { + use openshell_core::proto::watch_sandboxes_event::Payload; + + match event.payload { + Some(Payload::Sandbox(sandbox)) => { + if let Some(sandbox) = sandbox.sandbox { + self.apply_sandbox_update(sandbox).await?; + } + } + Some(Payload::Deleted(deleted)) => { + self.apply_deleted(&deleted.sandbox_id).await?; + } + Some(Payload::PlatformEvent(platform_event)) => { + if let Some(event) = platform_event.event { + self.tracing_log_bus.platform_event_bus.publish( + &platform_event.sandbox_id, + openshell_core::proto::SandboxStreamEvent { + payload: Some( + openshell_core::proto::sandbox_stream_event::Payload::Event(event), + ), + }, + ); + } + } + None => {} + } + Ok(()) + } + + async fn apply_sandbox_update(&self, incoming: Sandbox) -> Result<(), String> { + let existing = self + .store + .get_message::(&incoming.id) + .await + .map_err(|e| e.to_string())?; + + let mut status = incoming.status.clone(); + rewrite_user_facing_conditions( + &mut status, + existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), + ); + + let phase = SandboxPhase::try_from(incoming.phase).unwrap_or(SandboxPhase::Unknown); + let mut sandbox = existing.unwrap_or_else(|| Sandbox { + id: incoming.id.clone(), + name: incoming.name.clone(), + namespace: incoming.namespace.clone(), + spec: None, + status: None, + phase: SandboxPhase::Unknown as i32, 
+ ..Default::default() + }); + + let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + if old_phase != phase { + info!( + sandbox_id = %incoming.id, + sandbox_name = %incoming.name, + old_phase = ?old_phase, + new_phase = ?phase, + "Sandbox phase changed" + ); + } + + if phase == SandboxPhase::Error + && let Some(ref status) = status + { + for condition in &status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && is_terminal_failure_condition(condition) + { + warn!( + sandbox_id = %incoming.id, + sandbox_name = %incoming.name, + reason = %condition.reason, + message = %condition.message, + "Sandbox failed to become ready" + ); + } + } + } + + sandbox.name = incoming.name; + sandbox.namespace = incoming.namespace; + sandbox.status = status; + sandbox.phase = phase as i32; + + self.sandbox_index.update_from_sandbox(&sandbox); + self.store + .put_message(&sandbox) + .await + .map_err(|e| e.to_string())?; + self.sandbox_watch_bus.notify(&sandbox.id); + Ok(()) + } + + async fn apply_deleted(&self, sandbox_id: &str) -> Result<(), String> { + let _ = self + .store + .delete(Sandbox::object_type(), sandbox_id) + .await + .map_err(|e| e.to_string())?; + self.sandbox_index.remove_sandbox(sandbox_id); + self.sandbox_watch_bus.notify(sandbox_id); + self.cleanup_sandbox_state(sandbox_id); + Ok(()) + } + + fn cleanup_sandbox_state(&self, sandbox_id: &str) { + self.tracing_log_bus.remove(sandbox_id); + self.tracing_log_bus.platform_event_bus.remove(sandbox_id); + self.sandbox_watch_bus.remove(sandbox_id); + } +} + +impl ObjectType for Sandbox { + fn object_type() -> &'static str { + "sandbox" + } +} + +impl ObjectId for Sandbox { + fn object_id(&self) -> &str { + &self.id + } +} + +impl ObjectName for Sandbox { + fn object_name(&self) -> &str { + &self.name + } +} + +fn resolved_endpoint_from_response( + response: &ResolveSandboxEndpointResponse, +) -> Result { + let endpoint = response + 
.endpoint + .as_ref() + .ok_or_else(|| ComputeError::Message("compute driver returned no endpoint".to_string()))?; + let port = u16::try_from(endpoint.port) + .map_err(|_| ComputeError::Message("compute driver returned invalid port".to_string()))?; + + match endpoint.target.as_ref() { + Some(openshell_core::proto::sandbox_endpoint::Target::Ip(ip)) => ip + .parse() + .map(|ip| ResolvedEndpoint::Ip(ip, port)) + .map_err(|e| ComputeError::Message(format!("invalid endpoint IP: {e}"))), + Some(openshell_core::proto::sandbox_endpoint::Target::Host(host)) => { + Ok(ResolvedEndpoint::Host(host.clone(), port)) + } + None => Err(ComputeError::Message( + "compute driver returned endpoint without target".to_string(), + )), + } +} + +fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { + let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + if !gpu_requested { + return; + } + + if let Some(status) = status { + for condition in &mut status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && condition.reason.eq_ignore_ascii_case("Unschedulable") + { + condition.message = "GPU sandbox could not be scheduled on the active gateway. Another GPU sandbox may already be using the available GPU, or the gateway may not currently be able to satisfy GPU placement. Please refer to documentation and use `openshell doctor` commands to inspect GPU support and gateway configuration.".to_string(); + } + } + } +} + +fn is_terminal_failure_condition(condition: &SandboxCondition) -> bool { + let reason = condition.reason.to_ascii_lowercase(); + let transient_reasons = ["reconcilererror", "dependenciesnotready"]; + !transient_reasons.contains(&reason.as_str()) +} diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 59d2ea9fd..af60897d1 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -3,7 +3,7 @@ //! 
gRPC service implementation. -mod policy; +pub(crate) mod policy; mod provider; mod sandbox; mod validation; diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index a1639d1ce..c8f14b0dd 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -53,7 +53,7 @@ const GLOBAL_SETTINGS_OBJECT_TYPE: &str = "gateway_settings"; const GLOBAL_SETTINGS_ID: &str = "gateway_settings:global"; const GLOBAL_SETTINGS_NAME: &str = "global"; /// Internal object type for durable sandbox-scoped settings. -pub(super) const SANDBOX_SETTINGS_OBJECT_TYPE: &str = "sandbox_settings"; +pub(crate) const SANDBOX_SETTINGS_OBJECT_TYPE: &str = "sandbox_settings"; /// Reserved settings key used to store global policy payload. const POLICY_SETTING_KEY: &str = "policy"; /// Sentinel `sandbox_id` used to store global policy revisions. @@ -1930,7 +1930,7 @@ pub(super) async fn save_global_settings( } /// Derive a distinct settings record ID from a sandbox UUID. 
-pub(super) fn sandbox_settings_id(sandbox_id: &str) -> String { +pub(crate) fn sandbox_settings_id(sandbox_id: &str) -> String { format!("settings:{sandbox_id}") } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 39c34b89d..8e5930826 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -70,18 +70,7 @@ pub(super) async fn handle_create_sandbox( let mut spec = spec; let template = spec.template.get_or_insert_with(SandboxTemplate::default); if template.image.is_empty() { - template.image = state.sandbox_client.default_image().to_string(); - } - - if spec.gpu { - state - .sandbox_client - .validate_gpu_support() - .await - .map_err(|status| { - warn!(error = %status, "Rejecting GPU sandbox request"); - status - })?; + template.image = state.compute.default_image().to_string(); } // Ensure process identity defaults to "sandbox" when missing or @@ -109,61 +98,20 @@ pub(super) async fn handle_create_sandbox( ..Default::default() }; - // Reject duplicate names early. - let existing = state - .store - .get_message_by_name::(&name) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - if existing.is_some() { - return Err(Status::already_exists(format!( - "sandbox '{name}' already exists" - ))); - } - - // Persist to the store FIRST so the sandbox watcher always finds - // the record with `spec` populated. - state.sandbox_index.update_from_sandbox(&sandbox); - state - .store - .put_message(&sandbox) + .compute + .validate_sandbox_create(&sandbox) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; - - // Now create the Kubernetes resource. 
- match state.sandbox_client.create(&sandbox).await { - Ok(_) => {} - Err(kube::Error::Api(err)) if err.code == 409 => { - let _ = state.store.delete("sandbox", &id).await; - state.sandbox_index.remove_sandbox(&id); - warn!( - sandbox_id = %id, - sandbox_name = %name, - "Sandbox already exists in Kubernetes" - ); - return Err(Status::already_exists("sandbox already exists")); - } - Err(err) => { - let _ = state.store.delete("sandbox", &id).await; - state.sandbox_index.remove_sandbox(&id); - warn!( - sandbox_id = %id, - sandbox_name = %name, - error = %err, - "CreateSandbox request failed" - ); - return Err(Status::internal(format!( - "create sandbox in kubernetes failed: {err}" - ))); - } - } + .map_err(|status| { + warn!(error = %status, "Rejecting sandbox create request"); + status + })?; - state.sandbox_watch_bus.notify(&id); + let sandbox = state.compute.create_sandbox(sandbox).await?; info!( - sandbox_id = %id, - sandbox_name = %name, + sandbox_id = %sandbox.id, + sandbox_name = %sandbox.name, "CreateSandbox request completed successfully" ); Ok(Response::new(SandboxResponse { @@ -224,92 +172,8 @@ pub(super) async fn handle_delete_sandbox( return Err(Status::invalid_argument("name is required")); } - let sandbox = state - .store - .get_message_by_name::(&name) - .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - - let Some(mut sandbox) = sandbox else { - return Err(Status::not_found("sandbox not found")); - }; - - let id = sandbox.id.clone(); - - sandbox.phase = SandboxPhase::Deleting as i32; - state - .store - .put_message(&sandbox) - .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; - - state.sandbox_index.update_from_sandbox(&sandbox); - state.sandbox_watch_bus.notify(&id); - - // Clean up SSH sessions associated with this sandbox. 
- if let Ok(records) = state.store.list(SshSession::object_type(), 1000, 0).await { - for record in records { - if let Ok(session) = SshSession::decode(record.payload.as_slice()) - && session.sandbox_id == id - && let Err(e) = state - .store - .delete(SshSession::object_type(), &session.id) - .await - { - warn!( - session_id = %session.id, - error = %e, - "Failed to delete SSH session during sandbox cleanup" - ); - } - } - } - - // Clean up sandbox-scoped settings record. - if let Err(e) = state - .store - .delete( - super::policy::SANDBOX_SETTINGS_OBJECT_TYPE, - &super::policy::sandbox_settings_id(&id), - ) - .await - { - warn!( - sandbox_id = %id, - error = %e, - "Failed to delete sandbox settings during cleanup" - ); - } - - let deleted = match state.sandbox_client.delete(&sandbox.name).await { - Ok(deleted) => deleted, - Err(err) => { - warn!( - sandbox_id = %id, - sandbox_name = %sandbox.name, - error = %err, - "DeleteSandbox request failed" - ); - return Err(Status::internal(format!( - "delete sandbox in kubernetes failed: {err}" - ))); - } - }; - - if !deleted && let Err(e) = state.store.delete(Sandbox::object_type(), &id).await { - warn!(sandbox_id = %id, error = %e, "Failed to clean up store after delete"); - } - - // Clean up bus entries to prevent unbounded memory growth. 
- state.tracing_log_bus.remove(&id); - state.tracing_log_bus.platform_event_bus.remove(&id); - state.sandbox_watch_bus.remove(&id); - - info!( - sandbox_id = %id, - sandbox_name = %sandbox.name, - "DeleteSandbox request completed successfully" - ); + let deleted = state.compute.delete_sandbox(&name).await?; + info!(sandbox_name = %name, "DeleteSandbox request completed successfully"); Ok(Response::new(DeleteSandboxResponse { deleted })) } @@ -724,37 +588,10 @@ async fn resolve_sandbox_exec_target( state: &ServerState, sandbox: &Sandbox, ) -> Result<(String, u16), Status> { - if let Some(status) = sandbox.status.as_ref() - && !status.agent_pod.is_empty() - { - match state.sandbox_client.agent_pod_ip(&status.agent_pod).await { - Ok(Some(ip)) => { - return Ok((ip.to_string(), state.config.sandbox_ssh_port)); - } - Ok(None) => { - return Err(Status::failed_precondition( - "sandbox agent pod IP is not available", - )); - } - Err(err) => { - return Err(Status::internal(format!( - "failed to resolve agent pod IP: {err}" - ))); - } - } - } - - if sandbox.name.is_empty() { - return Err(Status::failed_precondition("sandbox has no name")); + match state.compute.resolve_sandbox_endpoint(sandbox).await? { + crate::compute::ResolvedEndpoint::Ip(ip, port) => Ok((ip.to_string(), port)), + crate::compute::ResolvedEndpoint::Host(host, port) => Ok((host, port)), } - - Ok(( - format!( - "{}.{}.svc.cluster.local", - sandbox.name, state.config.sandbox_namespace - ), - state.config.sandbox_ssh_port, - )) } /// Shell-escape a value for embedding in a POSIX shell command. diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index e827b3628..c9ff7704c 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -10,12 +10,12 @@ //! 
- mTLS support mod auth; +mod compute; mod grpc; mod http; mod inference; mod multiplex; mod persistence; -mod sandbox; mod sandbox_index; mod sandbox_watch; mod ssh_tunnel; @@ -30,13 +30,14 @@ use std::sync::{Arc, Mutex}; use tokio::net::TcpListener; use tracing::{debug, error, info}; +use compute::ComputeRuntime; pub use grpc::OpenShellService; pub use http::{health_router, http_router}; pub use multiplex::{MultiplexService, MultiplexedService}; +use openshell_driver_kubernetes::KubernetesComputeConfig; use persistence::Store; -use sandbox::{SandboxClient, spawn_sandbox_watcher, spawn_store_reconciler}; use sandbox_index::SandboxIndex; -use sandbox_watch::{SandboxWatchBus, spawn_kube_event_tailer}; +use sandbox_watch::SandboxWatchBus; pub use tls::TlsAcceptor; use tracing_bus::TracingLogBus; @@ -49,8 +50,8 @@ pub struct ServerState { /// Persistence store. pub store: Arc, - /// Kubernetes sandbox client. - pub sandbox_client: SandboxClient, + /// Compute orchestration over the configured driver. + pub compute: ComputeRuntime, /// In-memory sandbox correlation index. 
pub sandbox_index: SandboxIndex, @@ -87,7 +88,7 @@ impl ServerState { pub fn new( config: Config, store: Arc, - sandbox_client: SandboxClient, + compute: ComputeRuntime, sandbox_index: SandboxIndex, sandbox_watch_bus: SandboxWatchBus, tracing_log_bus: TracingLogBus, @@ -95,7 +96,7 @@ impl ServerState { Self { config, store, - sandbox_client, + compute, sandbox_index, sandbox_watch_bus, tracing_log_bus, @@ -124,48 +125,40 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul )); } - let store = Store::connect(database_url).await?; - let sandbox_client = SandboxClient::new( - config.sandbox_namespace.clone(), - config.sandbox_image.clone(), - config.sandbox_image_pull_policy.clone(), - config.grpc_endpoint.clone(), - format!("0.0.0.0:{}", config.sandbox_ssh_port), - config.ssh_handshake_secret.clone(), - config.ssh_handshake_skew_secs, - config.client_tls_secret_name.clone(), - config.host_gateway_ip.clone(), - ) - .await - .map_err(|e| Error::execution(format!("failed to create kubernetes client: {e}")))?; - let store = Arc::new(store); + let store = Arc::new(Store::connect(database_url).await?); let sandbox_index = SandboxIndex::new(); let sandbox_watch_bus = SandboxWatchBus::new(); + let compute = ComputeRuntime::new_kubernetes( + KubernetesComputeConfig { + namespace: config.sandbox_namespace.clone(), + default_image: config.sandbox_image.clone(), + image_pull_policy: config.sandbox_image_pull_policy.clone(), + grpc_endpoint: config.grpc_endpoint.clone(), + ssh_listen_addr: format!("0.0.0.0:{}", config.sandbox_ssh_port), + ssh_port: config.sandbox_ssh_port, + ssh_handshake_secret: config.ssh_handshake_secret.clone(), + ssh_handshake_skew_secs: config.ssh_handshake_skew_secs, + client_tls_secret_name: config.client_tls_secret_name.clone(), + host_gateway_ip: config.host_gateway_ip.clone(), + }, + store.clone(), + sandbox_index.clone(), + sandbox_watch_bus.clone(), + tracing_log_bus.clone(), + ) + .await + .map_err(|e| 
Error::execution(format!("failed to create compute runtime: {e}")))?; let state = Arc::new(ServerState::new( config.clone(), store.clone(), - sandbox_client, + compute, sandbox_index, sandbox_watch_bus, tracing_log_bus, )); - spawn_sandbox_watcher( - store.clone(), - state.sandbox_client.clone(), - state.sandbox_index.clone(), - state.sandbox_watch_bus.clone(), - state.tracing_log_bus.clone(), - ); - spawn_store_reconciler( - store.clone(), - state.sandbox_client.clone(), - state.sandbox_index.clone(), - state.sandbox_watch_bus.clone(), - state.tracing_log_bus.clone(), - ); - spawn_kube_event_tailer(state.clone()); + state.compute.spawn_watchers(); ssh_tunnel::spawn_session_reaper(store.clone(), std::time::Duration::from_secs(3600)); // Create the multiplexed service diff --git a/crates/openshell-server/src/sandbox_watch.rs b/crates/openshell-server/src/sandbox_watch.rs index 6b5ec8f1a..73cc4bf26 100644 --- a/crates/openshell-server/src/sandbox_watch.rs +++ b/crates/openshell-server/src/sandbox_watch.rs @@ -6,17 +6,8 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use futures::{StreamExt, TryStreamExt}; -use k8s_openapi::api::core::v1::Event as KubeEventObj; -use kube::Client; -use kube::api::Api; -use kube::runtime::watcher::{self, Event}; -use openshell_core::proto::{PlatformEvent, SandboxStreamEvent}; use tokio::sync::broadcast; use tonic::Status; -use tracing::{debug, warn}; - -use crate::ServerState; /// Broadcast bus of sandbox updates keyed by sandbox id. /// @@ -68,112 +59,6 @@ impl SandboxWatchBus { } } -/// Spawn a background Kubernetes Event tailer. -/// -/// This tailer publishes platform events (sourced from Kubernetes) into per-sandbox broadcast streams. 
-pub fn spawn_kube_event_tailer(state: Arc) { - tokio::spawn(async move { - let client = match Client::try_default().await { - Ok(c) => c, - Err(e) => { - warn!(error = %e, "Failed to create kube client for event tailer"); - return; - } - }; - - let ns = state.config.sandbox_namespace.clone(); - let api: Api = Api::namespaced(client, &ns); - - // We don't have a stable label to select Events by sandbox id. - // Instead, we watch all Events in the namespace and dispatch using the in-memory index. - // This is best-effort and efficient enough for typical sandbox counts. - let mut stream = watcher::watcher(api, watcher::Config::default()).boxed(); - - loop { - match stream.try_next().await { - Ok(Some(Event::Applied(obj))) => { - if let Some((sandbox_id, evt)) = map_kube_event_to_platform(&state, &obj) { - state - .tracing_log_bus - .platform_event_bus - .publish(&sandbox_id, evt); - } - } - Ok(Some(Event::Deleted(_))) => {} - Ok(Some(Event::Restarted(_))) => { - debug!(namespace = %ns, "Kubernetes event watcher restarted"); - } - Ok(None) => { - warn!(namespace = %ns, "Kubernetes event watcher stream ended"); - break; - } - Err(err) => { - warn!(namespace = %ns, error = %err, "Kubernetes event watcher error"); - } - } - } - }); -} - -fn map_kube_event_to_platform( - state: &ServerState, - obj: &KubeEventObj, -) -> Option<(String, SandboxStreamEvent)> { - let involved = obj.involved_object.clone(); - let involved_kind = involved.kind.unwrap_or_default(); - let involved_name = involved.name.unwrap_or_default(); - - let sandbox_id = match involved_kind.as_str() { - "Sandbox" => state - .sandbox_index - .sandbox_id_for_sandbox_name(&involved_name)?, - "Pod" => { - // The sandbox controller creates pods with the same name as the sandbox, - // so try looking up by sandbox name first, then fall back to agent_pod index. - state - .sandbox_index - .sandbox_id_for_sandbox_name(&involved_name) - .or_else(|| state.sandbox_index.sandbox_id_for_agent_pod(&involved_name))? 
- } - _ => return None, - }; - - let ts = obj - .last_timestamp - .as_ref() - .or(obj.first_timestamp.as_ref()) - .map_or(0, |t| t.0.timestamp_millis()); - - // Build metadata map with Kubernetes-specific details - let mut metadata = HashMap::new(); - metadata.insert("involved_kind".to_string(), involved_kind); - metadata.insert("involved_name".to_string(), involved_name); - if let Some(ns) = &obj.involved_object.namespace { - metadata.insert("namespace".to_string(), ns.clone()); - } - if let Some(count) = obj.count { - metadata.insert("count".to_string(), count.to_string()); - } - - let evt = PlatformEvent { - timestamp_ms: ts, - source: "kubernetes".to_string(), - r#type: obj.type_.clone().unwrap_or_default(), - reason: obj.reason.clone().unwrap_or_default(), - message: obj.message.clone().unwrap_or_default(), - metadata, - }; - - Some(( - sandbox_id, - SandboxStreamEvent { - payload: Some(openshell_core::proto::sandbox_stream_event::Payload::Event( - evt, - )), - }, - )) -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs index 899f413b6..536513ccd 100644 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ b/crates/openshell-server/src/ssh_tunnel.rs @@ -100,25 +100,18 @@ async fn ssh_connect( return StatusCode::PRECONDITION_FAILED.into_response(); } - let connect_target = if let Some(status) = sandbox.status.as_ref() - && !status.agent_pod.is_empty() - { - match state.sandbox_client.agent_pod_ip(&status.agent_pod).await { - Ok(Some(ip)) => ConnectTarget::Ip(SocketAddr::new(ip, state.config.sandbox_ssh_port)), - Ok(None) => return StatusCode::BAD_GATEWAY.into_response(), - Err(err) => { - warn!(error = %err, "Failed to resolve agent pod IP"); - return StatusCode::BAD_GATEWAY.into_response(); - } + let connect_target = match state.compute.resolve_sandbox_endpoint(&sandbox).await { + Ok(crate::compute::ResolvedEndpoint::Ip(ip, port)) => { + 
ConnectTarget::Ip(SocketAddr::new(ip, port)) + } + Ok(crate::compute::ResolvedEndpoint::Host(host, port)) => ConnectTarget::Host(host, port), + Err(status) if status.code() == tonic::Code::FailedPrecondition => { + return StatusCode::PRECONDITION_FAILED.into_response(); + } + Err(err) => { + warn!(error = %err, "Failed to resolve sandbox endpoint"); + return StatusCode::BAD_GATEWAY.into_response(); } - } else if !sandbox.name.is_empty() { - let service_host = format!( - "{}.{}.svc.cluster.local", - sandbox.name, state.config.sandbox_namespace - ); - ConnectTarget::Host(service_host, state.config.sandbox_ssh_port) - } else { - return StatusCode::PRECONDITION_FAILED.into_response(); }; // Enforce per-token concurrent connection limit. { diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto new file mode 100644 index 000000000..9005ca23b --- /dev/null +++ b/proto/compute_driver.proto @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package openshell.compute.v1; + +import "sandbox.proto"; + +// Internal compute-driver contract used by the gateway. +service ComputeDriver { + // Report driver capabilities and defaults. + rpc GetCapabilities(GetCapabilitiesRequest) returns (GetCapabilitiesResponse); + + // Validate a sandbox before create-time provisioning. + rpc ValidateSandboxCreate(ValidateSandboxCreateRequest) + returns (ValidateSandboxCreateResponse); + + // Provision platform resources for a sandbox. + rpc CreateSandbox(ComputeCreateSandboxRequest) returns (ComputeCreateSandboxResponse); + + // Tear down platform resources for a sandbox. + rpc DeleteSandbox(ComputeDeleteSandboxRequest) returns (ComputeDeleteSandboxResponse); + + // Resolve the current endpoint for sandbox exec/SSH transport. 
+ rpc ResolveSandboxEndpoint(ResolveSandboxEndpointRequest) + returns (ResolveSandboxEndpointResponse); + + // Stream sandbox observations from the platform. + rpc WatchSandboxes(WatchSandboxesRequest) returns (stream WatchSandboxesEvent); +} + +message GetCapabilitiesRequest {} + +message GetCapabilitiesResponse { + string driver_name = 1; + string driver_version = 2; + string default_image = 3; + bool supports_gpu = 4; +} + +message ValidateSandboxCreateRequest { + openshell.sandbox.v1.Sandbox sandbox = 1; +} + +message ValidateSandboxCreateResponse {} + +message ComputeCreateSandboxRequest { + openshell.sandbox.v1.Sandbox sandbox = 1; +} + +message ComputeCreateSandboxResponse {} + +message ComputeDeleteSandboxRequest { + string sandbox_id = 1; + string sandbox_name = 2; +} + +message ComputeDeleteSandboxResponse { + bool deleted = 1; +} + +message ResolveSandboxEndpointRequest { + openshell.sandbox.v1.Sandbox sandbox = 1; +} + +message SandboxEndpoint { + oneof target { + string ip = 1; + string host = 2; + } + uint32 port = 3; +} + +message ResolveSandboxEndpointResponse { + SandboxEndpoint endpoint = 1; +} + +message WatchSandboxesRequest {} + +message WatchSandboxesSandboxEvent { + openshell.sandbox.v1.Sandbox sandbox = 1; +} + +message WatchSandboxesDeletedEvent { + string sandbox_id = 1; +} + +message WatchSandboxesPlatformEvent { + string sandbox_id = 1; + openshell.sandbox.v1.PlatformEvent event = 2; +} + +message WatchSandboxesEvent { + oneof payload { + WatchSandboxesSandboxEvent sandbox = 1; + WatchSandboxesDeletedEvent deleted = 2; + WatchSandboxesPlatformEvent platform_event = 3; + } +} diff --git a/proto/datamodel.proto b/proto/datamodel.proto index 2232a1228..f84d1e352 100644 --- a/proto/datamodel.proto +++ b/proto/datamodel.proto @@ -5,76 +5,6 @@ syntax = "proto3"; package openshell.datamodel.v1; -import "google/protobuf/struct.proto"; -import "sandbox.proto"; - -// Sandbox model stored by OpenShell. 
-message Sandbox {
-  string id = 1;
-  string name = 2;
-  string namespace = 3;
-  SandboxSpec spec = 4;
-  SandboxStatus status = 5;
-  SandboxPhase phase = 6;
-  // Milliseconds since Unix epoch when the sandbox was created.
-  int64 created_at_ms = 7;
-  // Currently active policy version (updated when sandbox reports loaded).
-  uint32 current_policy_version = 8;
-}
-
-// OpenShell-level sandbox spec.
-message SandboxSpec {
-  string log_level = 1;
-  map<string, string> environment = 5;
-  SandboxTemplate template = 6;
-  // Required sandbox policy configuration.
-  openshell.sandbox.v1.SandboxPolicy policy = 7;
-  // Provider names to attach to this sandbox.
-  repeated string providers = 8;
-  // Request NVIDIA GPU resources for this sandbox.
-  bool gpu = 9;
-}
-
-// Sandbox template mapped onto Kubernetes pod template inputs.
-message SandboxTemplate {
-  string image = 1;
-  string runtime_class_name = 2;
-  string agent_socket = 3;
-  map<string, string> labels = 4;
-  map<string, string> annotations = 5;
-  map<string, string> environment = 6;
-  google.protobuf.Struct resources = 7;
-  google.protobuf.Struct volume_claim_templates = 9;
-}
-
-// Sandbox status captured from Kubernetes.
-message SandboxStatus {
-  string sandbox_name = 1;
-  string agent_pod = 2;
-  string agent_fd = 3;
-  string sandbox_fd = 4;
-  repeated SandboxCondition conditions = 5;
-}
-
-// Sandbox condition mirrors Kubernetes conditions.
-message SandboxCondition {
-  string type = 1;
-  string status = 2;
-  string reason = 3;
-  string message = 4;
-  string last_transition_time = 5;
-}
-
-// High-level sandbox lifecycle phase.
-enum SandboxPhase {
-  SANDBOX_PHASE_UNSPECIFIED = 0;
-  SANDBOX_PHASE_PROVISIONING = 1;
-  SANDBOX_PHASE_READY = 2;
-  SANDBOX_PHASE_ERROR = 3;
-  SANDBOX_PHASE_DELETING = 4;
-  SANDBOX_PHASE_UNKNOWN = 5;
-}
-
 // Provider model stored by OpenShell.
message Provider { string id = 1; diff --git a/proto/openshell.proto b/proto/openshell.proto index 04f705020..43d903b6d 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -142,7 +142,7 @@ message HealthResponse { // Create sandbox request. message CreateSandboxRequest { - openshell.datamodel.v1.SandboxSpec spec = 1; + openshell.sandbox.v1.SandboxSpec spec = 1; // Optional user-supplied sandbox name. When empty the server generates one. string name = 2; } @@ -167,12 +167,12 @@ message DeleteSandboxRequest { // Sandbox response. message SandboxResponse { - openshell.datamodel.v1.Sandbox sandbox = 1; + openshell.sandbox.v1.Sandbox sandbox = 1; } // List sandboxes response. message ListSandboxesResponse { - repeated openshell.datamodel.v1.Sandbox sandboxes = 1; + repeated openshell.sandbox.v1.Sandbox sandboxes = 1; } // Delete sandbox response. @@ -336,11 +336,11 @@ message WatchSandboxRequest { message SandboxStreamEvent { oneof payload { // Latest sandbox snapshot. - openshell.datamodel.v1.Sandbox sandbox = 1; + openshell.sandbox.v1.Sandbox sandbox = 1; // One server log line/event. SandboxLogLine log = 2; // One platform event. - PlatformEvent event = 3; + openshell.sandbox.v1.PlatformEvent event = 3; // Warning from the server (e.g. missed messages due to lag). SandboxStreamWarning warning = 4; // Draft policy update notification. @@ -362,22 +362,6 @@ message SandboxLogLine { map fields = 7; } -// Platform event correlated to a sandbox. -message PlatformEvent { - // Event timestamp in milliseconds since epoch. - int64 timestamp_ms = 1; - // Event source (e.g. "kubernetes", "docker", "process"). - string source = 2; - // Event type/severity (e.g. "Normal", "Warning"). - string type = 3; - // Short reason code (e.g. "Started", "Pulled", "Failed"). - string reason = 4; - // Human-readable event message. - string message = 5; - // Optional metadata as key-value pairs. 
-  map<string, string> metadata = 6;
-}
-
 message SandboxStreamWarning {
   string message = 1;
 }
diff --git a/proto/sandbox.proto b/proto/sandbox.proto
index 61948a527..f810f1e0a 100644
--- a/proto/sandbox.proto
+++ b/proto/sandbox.proto
@@ -5,6 +5,91 @@ syntax = "proto3";
 
 package openshell.sandbox.v1;
 
+import "google/protobuf/struct.proto";
+
+// Sandbox model stored by OpenShell.
+message Sandbox {
+  string id = 1;
+  string name = 2;
+  string namespace = 3;
+  SandboxSpec spec = 4;
+  SandboxStatus status = 5;
+  SandboxPhase phase = 6;
+  // Milliseconds since Unix epoch when the sandbox was created.
+  int64 created_at_ms = 7;
+  // Currently active policy version (updated when sandbox reports loaded).
+  uint32 current_policy_version = 8;
+}
+
+// OpenShell-level sandbox spec.
+message SandboxSpec {
+  string log_level = 1;
+  map<string, string> environment = 5;
+  SandboxTemplate template = 6;
+  // Required sandbox policy configuration.
+  SandboxPolicy policy = 7;
+  // Provider names to attach to this sandbox.
+  repeated string providers = 8;
+  // Request NVIDIA GPU resources for this sandbox.
+  bool gpu = 9;
+}
+
+// Sandbox template mapped onto compute-driver template inputs.
+message SandboxTemplate {
+  string image = 1;
+  string runtime_class_name = 2;
+  string agent_socket = 3;
+  map<string, string> labels = 4;
+  map<string, string> annotations = 5;
+  map<string, string> environment = 6;
+  google.protobuf.Struct resources = 7;
+  google.protobuf.Struct volume_claim_templates = 9;
+}
+
+// Sandbox status captured from the compute platform.
+message SandboxStatus {
+  string sandbox_name = 1;
+  string agent_pod = 2;
+  string agent_fd = 3;
+  string sandbox_fd = 4;
+  repeated SandboxCondition conditions = 5;
+}
+
+// Sandbox condition mirrors the compute platform condition model.
+message SandboxCondition {
+  string type = 1;
+  string status = 2;
+  string reason = 3;
+  string message = 4;
+  string last_transition_time = 5;
+}
+
+// High-level sandbox lifecycle phase.
+enum SandboxPhase { + SANDBOX_PHASE_UNSPECIFIED = 0; + SANDBOX_PHASE_PROVISIONING = 1; + SANDBOX_PHASE_READY = 2; + SANDBOX_PHASE_ERROR = 3; + SANDBOX_PHASE_DELETING = 4; + SANDBOX_PHASE_UNKNOWN = 5; +} + +// Platform event correlated to a sandbox. +message PlatformEvent { + // Event timestamp in milliseconds since epoch. + int64 timestamp_ms = 1; + // Event source (e.g. "kubernetes", "docker", "process"). + string source = 2; + // Event type/severity (e.g. "Normal", "Warning"). + string type = 3; + // Short reason code (e.g. "Started", "Pulled", "Failed"). + string reason = 4; + // Human-readable event message. + string message = 5; + // Optional metadata as key-value pairs. + map metadata = 6; +} + // Sandbox security policy configuration. message SandboxPolicy { // Policy version. From 890818f7ed3595356c354243c19cf83438b31dad Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Mon, 13 Apr 2026 21:51:14 -0700 Subject: [PATCH 2/7] cleanup --- .../skills/debug-openshell-cluster/SKILL.md | 2 + architecture/gateway.md | 21 +- crates/openshell-core/src/proto/mod.rs | 1 - .../openshell-driver-kubernetes/src/driver.rs | 369 +++++------------ .../openshell-driver-kubernetes/src/grpc.rs | 73 +++- .../openshell-driver-kubernetes/src/main.rs | 2 +- crates/openshell-server/src/compute/mod.rs | 378 ++++++++++++++++-- deploy/docker/Dockerfile.images | 5 + proto/compute_driver.proto | 189 ++++++++- proto/openshell.proto | 138 ++++++- proto/sandbox.proto | 91 +---- tasks/scripts/cluster-deploy-fast.sh | 4 +- 12 files changed, 848 insertions(+), 425 deletions(-) diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 4ef851a7e..f4c5672a2 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -182,6 +182,8 @@ Component images (server, sandbox) can reach kubelet via two paths: **Local/external pull mode** (default local via `mise run cluster`): Local images 
are tagged to the configured local registry base (default `127.0.0.1:5000/openshell/*`), pushed to that registry, and pulled by k3s via `registries.yaml` mirror endpoint (typically `host.docker.internal:5000`). The `cluster` task pushes prebuilt local tags (`openshell/*:dev`, falling back to `localhost:5000/openshell/*:dev` or `127.0.0.1:5000/openshell/*:dev`). +Gateway image builds now stage a partial Rust workspace from `deploy/docker/Dockerfile.images`. If cargo fails with a missing manifest under `/build/crates/...`, verify that every current gateway dependency crate (including `openshell-driver-kubernetes`) is copied into the staged workspace there. + ```bash # Verify image refs currently used by openshell deployment openshell doctor exec -- kubectl -n openshell get statefulset openshell -o jsonpath="{.spec.template.spec.containers[*].image}" diff --git a/architecture/gateway.md b/architecture/gateway.md index 53e547235..b783e7c03 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -80,11 +80,11 @@ Proto definitions consumed by the gateway: | Proto file | Package | Defines | |------------|---------|---------| -| `proto/openshell.proto` | `openshell.v1` | `OpenShell` service, sandbox/provider/SSH/watch messages | -| `proto/compute_driver.proto` | `openshell.compute.v1` | Internal `ComputeDriver` service, endpoint resolution, compute watch stream envelopes | +| `proto/openshell.proto` | `openshell.v1` | `OpenShell` service, public sandbox resource model, provider/SSH/watch messages | +| `proto/compute_driver.proto` | `openshell.compute.v1` | Internal `ComputeDriver` service, driver-native sandbox observations, endpoint resolution, compute watch stream envelopes | | `proto/inference.proto` | `openshell.inference.v1` | `Inference` service: `SetClusterInference`, `GetClusterInference`, `GetInferenceBundle` | | `proto/datamodel.proto` | `openshell.datamodel.v1` | `Provider` | -| `proto/sandbox.proto` | `openshell.sandbox.v1` | Shared sandbox 
lifecycle types (`Sandbox`, `SandboxSpec`, `SandboxStatus`, `SandboxPhase`, `PlatformEvent`) plus policy/settings messages | +| `proto/sandbox.proto` | `openshell.sandbox.v1` | Sandbox supervisor policy, settings, and config messages | ## Startup Sequence @@ -502,25 +502,28 @@ The Helm chart template is at `deploy/helm/openshell/templates/statefulset.yaml` `KubernetesComputeDriver` (`crates/openshell-driver-kubernetes/src/driver.rs`) manages `agents.x-k8s.io/v1alpha1/Sandbox` CRDs behind the gateway's compute interface. -- **Create**: Translates a shared `openshell.sandbox.v1.Sandbox` message into a Kubernetes `DynamicObject` with labels (`openshell.ai/sandbox-id`, `openshell.ai/managed-by: openshell`) and a spec that includes the pod template, environment variables, and gateway-required env vars (`OPENSHELL_SANDBOX_ID`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SSH_LISTEN_ADDR`, etc.). When callers do not provide custom `volumeClaimTemplates`, the driver injects a default `workspace` PVC and mounts it at `/sandbox` so the default sandbox home/workdir survives pod rescheduling. +- **Get**: `GetSandbox` looks up a sandbox CRD by name and returns a driver-native platform observation (`openshell.compute.v1.DriverSandbox`) with raw status and condition data from the object. +- **List**: `ListSandboxes` enumerates sandbox CRDs and returns driver-native platform observations for each, sorted by name for stable results. +- **Create**: Translates an internal `openshell.compute.v1.DriverSandbox` message into a Kubernetes `DynamicObject` with labels (`openshell.ai/sandbox-id`, `openshell.ai/managed-by: openshell`) and a spec that includes the pod template, environment variables, and gateway-required env vars (`OPENSHELL_SANDBOX_ID`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SSH_LISTEN_ADDR`, etc.). When callers do not provide custom `volumeClaimTemplates`, the driver injects a default `workspace` PVC and mounts it at `/sandbox` so the default sandbox home/workdir survives pod rescheduling. 
- **Delete**: Calls the Kubernetes API to delete the CRD by name. Returns `false` if already gone (404). +- **Stop**: `proto/compute_driver.proto` now reserves `StopSandbox` for a non-destructive lifecycle transition. Resume is intentionally not a dedicated compute-driver RPC; the gateway is expected to auto-resume a stopped sandbox when a client connects or executes into it. - **Pod IP resolution**: `agent_pod_ip()` fetches the agent pod and reads `status.podIP`. ### Sandbox Watcher -The Kubernetes driver emits `WatchSandboxes` events through `proto/compute_driver.proto`. `ComputeRuntime` consumes that stream and applies the resulting snapshots to the store. +The Kubernetes driver emits `WatchSandboxes` events through `proto/compute_driver.proto`. `ComputeRuntime` consumes that stream, translates the driver-native snapshots into public `openshell.v1.Sandbox` resources, derives the public phase, and applies the results to the store. -- **Applied**: Extracts the sandbox ID from labels (or falls back to name prefix stripping), reads the CRD status, derives the phase, and upserts the sandbox record in the store. Notifies the watch bus. +- **Applied**: Extracts the sandbox ID from labels (or falls back to name prefix stripping), reads the CRD status, emits a driver-native snapshot, and lets the gateway translate that into the stored public sandbox record. Notifies the watch bus. - **Deleted**: Removes the sandbox record from the store and the index. Notifies the watch bus. - **Restarted**: Re-processes all objects (full resync). 
-### Phase Derivation +### Gateway Phase Derivation -`derive_phase()` maps Kubernetes condition state to `SandboxPhase`: +`ComputeRuntime::derive_phase()` (`crates/openshell-server/src/compute/mod.rs`) maps driver-native compute status to the public `SandboxPhase` exposed by `proto/openshell.proto`: | Condition | Phase | |-----------|-------| -| `deletionTimestamp` is set | `Deleting` | +| Driver status `deleting=true` | `Deleting` | | Ready condition `status=True` | `Ready` | | Ready condition `status=False`, terminal reason | `Error` | | Ready condition `status=False`, transient reason | `Provisioning` | diff --git a/crates/openshell-core/src/proto/mod.rs b/crates/openshell-core/src/proto/mod.rs index 2644cb39a..08b062d2e 100644 --- a/crates/openshell-core/src/proto/mod.rs +++ b/crates/openshell-core/src/proto/mod.rs @@ -79,7 +79,6 @@ pub mod inference { } } -pub use compute::v1::*; pub use datamodel::v1::*; pub use inference::v1::*; pub use openshell::*; diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 85d781045..3bc99e520 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -11,11 +11,13 @@ use kube::core::gvk::GroupVersionKind; use kube::core::{DynamicObject, ObjectMeta}; use kube::runtime::watcher::{self, Event}; use kube::{Client, Error as KubeError}; -use openshell_core::proto::{ - GetCapabilitiesResponse, PlatformEvent, ResolveSandboxEndpointResponse, Sandbox, - SandboxCondition, SandboxEndpoint, SandboxPhase, SandboxSpec, SandboxStatus, SandboxTemplate, +use openshell_core::proto::compute::v1::{ + DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, + DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, + DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, + GetCapabilitiesResponse, ResolveSandboxEndpointResponse, SandboxEndpoint, WatchSandboxesDeletedEvent, 
WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesSandboxEvent, + WatchSandboxesSandboxEvent, sandbox_endpoint, watch_sandboxes_event, }; use std::collections::BTreeMap; use std::net::IpAddr; @@ -186,6 +188,85 @@ impl KubernetesComputeDriver { Ok(()) } + pub async fn get_sandbox(&self, name: &str) -> Result, String> { + info!( + sandbox_name = %name, + namespace = %self.config.namespace, + "Fetching sandbox from Kubernetes" + ); + + let api = self.api(); + match tokio::time::timeout(KUBE_API_TIMEOUT, api.get(name)).await { + Ok(Ok(obj)) => sandbox_from_object(&self.config.namespace, obj).map(Some), + Ok(Err(KubeError::Api(err))) if err.code == 404 => { + debug!(sandbox_name = %name, "Sandbox not found in Kubernetes"); + Ok(None) + } + Ok(Err(err)) => { + warn!( + sandbox_name = %name, + error = %err, + "Failed to fetch sandbox from Kubernetes" + ); + Err(err.to_string()) + } + Err(_elapsed) => { + warn!( + sandbox_name = %name, + timeout_secs = KUBE_API_TIMEOUT.as_secs(), + "Timed out fetching sandbox from Kubernetes" + ); + Err(format!( + "timed out after {}s waiting for Kubernetes API", + KUBE_API_TIMEOUT.as_secs() + )) + } + } + } + + pub async fn list_sandboxes(&self) -> Result, String> { + info!( + namespace = %self.config.namespace, + "Listing sandboxes from Kubernetes" + ); + + let api = self.api(); + match tokio::time::timeout(KUBE_API_TIMEOUT, api.list(&ListParams::default())).await { + Ok(Ok(list)) => { + let mut sandboxes = list + .items + .into_iter() + .map(|obj| sandbox_from_object(&self.config.namespace, obj)) + .collect::, _>>()?; + sandboxes.sort_by(|left, right| { + left.name + .cmp(&right.name) + .then_with(|| left.id.cmp(&right.id)) + }); + Ok(sandboxes) + } + Ok(Err(err)) => { + warn!( + namespace = %self.config.namespace, + error = %err, + "Failed to list sandboxes from Kubernetes" + ); + Err(err.to_string()) + } + Err(_elapsed) => { + warn!( + namespace = %self.config.namespace, + timeout_secs = KUBE_API_TIMEOUT.as_secs(), + 
"Timed out listing sandboxes from Kubernetes" + ); + Err(format!( + "timed out after {}s waiting for Kubernetes API", + KUBE_API_TIMEOUT.as_secs() + )) + } + } + } + fn ssh_handshake_secret(&self) -> &str { &self.config.ssh_handshake_secret } @@ -323,9 +404,7 @@ impl KubernetesComputeDriver { Ok(Some(ip)) => { return Ok(ResolveSandboxEndpointResponse { endpoint: Some(SandboxEndpoint { - target: Some(openshell_core::proto::sandbox_endpoint::Target::Ip( - ip.to_string(), - )), + target: Some(sandbox_endpoint::Target::Ip(ip.to_string())), port: u32::from(self.config.ssh_port), }), }); @@ -345,12 +424,10 @@ impl KubernetesComputeDriver { Ok(ResolveSandboxEndpointResponse { endpoint: Some(SandboxEndpoint { - target: Some(openshell_core::proto::sandbox_endpoint::Target::Host( - format!( - "{}.{}.svc.cluster.local", - sandbox.name, self.config.namespace - ), - )), + target: Some(sandbox_endpoint::Target::Host(format!( + "{}.{}.svc.cluster.local", + sandbox.name, self.config.namespace + ))), port: u32::from(self.config.ssh_port), }), }) @@ -376,7 +453,7 @@ impl KubernetesComputeDriver { Ok(sandbox) => { update_indexes(&mut sandbox_name_to_id, &mut agent_pod_to_id, &sandbox); let event = WatchSandboxesEvent { - payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::Sandbox( + payload: Some(watch_sandboxes_event::Payload::Sandbox( WatchSandboxesSandboxEvent { sandbox: Some(sandbox) } )), }; @@ -396,7 +473,7 @@ impl KubernetesComputeDriver { Ok(sandbox_id) => { remove_indexes(&mut sandbox_name_to_id, &mut agent_pod_to_id, &sandbox_id); let event = WatchSandboxesEvent { - payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::Deleted( + payload: Some(watch_sandboxes_event::Payload::Deleted( WatchSandboxesDeletedEvent { sandbox_id } )), }; @@ -417,7 +494,7 @@ impl KubernetesComputeDriver { Ok(sandbox) => { update_indexes(&mut sandbox_name_to_id, &mut agent_pod_to_id, &sandbox); let event = WatchSandboxesEvent { - payload: 
Some(openshell_core::proto::watch_sandboxes_event::Payload::Sandbox( + payload: Some(watch_sandboxes_event::Payload::Sandbox( WatchSandboxesSandboxEvent { sandbox: Some(sandbox) } )), }; @@ -452,7 +529,7 @@ impl KubernetesComputeDriver { &obj, ) { let event = WatchSandboxesEvent { - payload: Some(openshell_core::proto::watch_sandboxes_event::Payload::PlatformEvent( + payload: Some(watch_sandboxes_event::Payload::PlatformEvent( WatchSandboxesPlatformEvent { sandbox_id, event: Some(event) } )), }; @@ -517,9 +594,7 @@ fn sandbox_from_object(namespace: &str, obj: DynamicObject) -> Result Result Option { .unwrap_or_default() .to_string(), conditions, + deleting: obj.metadata.deletion_timestamp.is_some(), }) } @@ -1302,265 +1377,11 @@ fn condition_from_value(value: &serde_json::Value) -> Option { }) } -#[cfg_attr(not(test), allow(dead_code))] -fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); - if !gpu_requested { - return; - } - - if let Some(status) = status { - for condition in &mut status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && condition.reason.eq_ignore_ascii_case("Unschedulable") - { - condition.message = "GPU sandbox could not be scheduled on the active gateway. Another GPU sandbox may already be using the available GPU, or the gateway may not currently be able to satisfy GPU placement. 
Please refer to documentation and use `openshell doctor` commands to inspect GPU support and gateway configuration.".to_string(); - } - } - } -} - -fn derive_phase(status: &Option, deleting: bool) -> SandboxPhase { - if deleting { - return SandboxPhase::Deleting; - } - - if let Some(status) = status { - for condition in &status.conditions { - if condition.r#type == "Ready" { - return if condition.status.eq_ignore_ascii_case("true") { - SandboxPhase::Ready - } else if condition.status.eq_ignore_ascii_case("false") { - if is_terminal_failure_condition(condition) { - SandboxPhase::Error - } else { - SandboxPhase::Provisioning - } - } else { - SandboxPhase::Provisioning - }; - } - } - return SandboxPhase::Provisioning; - } - - SandboxPhase::Unknown -} - -fn is_terminal_failure_condition(condition: &SandboxCondition) -> bool { - let reason = condition.reason.to_ascii_lowercase(); - - // These are transient conditions from the sandbox controller that indicate - // the sandbox is still being provisioned and may become ready: - // - // - ReconcilerError: Controller-level transient error, will be retried - // - DependenciesNotReady: Pod/Service not ready yet, normal during provisioning - // - // Any other Ready=False condition is considered terminal (e.g., the controller - // determined a permanent failure like ImagePullBackOff, Unschedulable, etc.) - let transient_reasons = ["reconcilererror", "dependenciesnotready"]; - - !transient_reasons.contains(&reason.as_str()) -} - #[cfg(test)] mod tests { use super::*; use prost_types::{Struct, Value, value::Kind}; - fn make_condition(reason: &str, message: &str) -> SandboxCondition { - SandboxCondition { - r#type: "Ready".to_string(), - status: "False".to_string(), - reason: reason.to_string(), - message: message.to_string(), - last_transition_time: String::new(), - } - } - - #[test] - fn terminal_failure_treats_unknown_reasons_as_terminal() { - // Any Ready=False condition with an unknown reason is terminal. 
- // We trust the sandbox controller's assessment. - let terminal_cases = [ - ("Failed", "Something went wrong"), - ("CrashLoopBackOff", "Container keeps crashing"), - ("ImagePullBackOff", "Failed to pull image"), - ("ErrImagePull", "Error pulling image"), - ("Unschedulable", "No nodes match"), - ("SomeOtherReason", "Any other reason is terminal"), - ]; - - for (reason, message) in terminal_cases { - let condition = make_condition(reason, message); - assert!( - is_terminal_failure_condition(&condition), - "Expected terminal failure for reason={reason}, message={message}" - ); - } - } - - #[test] - fn terminal_failure_ignores_transient_reasons() { - // These reasons are transient - the sandbox may still become ready: - // - ReconcilerError: controller will retry - // - DependenciesNotReady: pod/service still being created - let transient_cases = [ - ( - "ReconcilerError", - "Error seen: failed to update pod: Operation cannot be fulfilled", - ), - ("reconcilererror", "lowercase also works"), - ("RECONCILERERROR", "uppercase also works"), - ( - "DependenciesNotReady", - "Pod exists with phase: Pending; Service Exists", - ), - ("dependenciesnotready", "lowercase also works"), - ]; - - for (reason, message) in transient_cases { - let condition = make_condition(reason, message); - assert!( - !is_terminal_failure_condition(&condition), - "Expected transient (non-terminal) for reason={reason}, message={message}" - ); - } - } - - #[test] - fn derive_phase_returns_provisioning_for_transient_conditions() { - // Transient conditions (ReconcilerError, DependenciesNotReady) should - // result in Provisioning phase, not Error. 
- let transient_conditions = [ - ("ReconcilerError", "Error seen: failed to update pod"), - ( - "DependenciesNotReady", - "Pod exists with phase: Pending; Service Exists", - ), - ]; - - for (reason, message) in transient_conditions { - let status = Some(SandboxStatus { - sandbox_name: "test".to_string(), - agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), - conditions: vec![SandboxCondition { - r#type: "Ready".to_string(), - status: "False".to_string(), - reason: reason.to_string(), - message: message.to_string(), - last_transition_time: String::new(), - }], - }); - - assert_eq!( - derive_phase(&status, false), - SandboxPhase::Provisioning, - "Expected Provisioning for transient reason={reason}" - ); - } - } - - #[test] - fn derive_phase_returns_error_for_terminal_ready_false() { - let status = Some(SandboxStatus { - sandbox_name: "test".to_string(), - agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), - conditions: vec![SandboxCondition { - r#type: "Ready".to_string(), - status: "False".to_string(), - reason: "ImagePullBackOff".to_string(), - message: "Failed to pull image".to_string(), - last_transition_time: String::new(), - }], - }); - - assert_eq!(derive_phase(&status, false), SandboxPhase::Error); - } - - #[test] - fn rewrite_user_facing_conditions_rewrites_gpu_unschedulable_message() { - let mut status = Some(SandboxStatus { - sandbox_name: "test".to_string(), - agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), - conditions: vec![SandboxCondition { - r#type: "Ready".to_string(), - status: "False".to_string(), - reason: "Unschedulable".to_string(), - message: "0/1 nodes are available: 1 Insufficient nvidia.com/gpu.".to_string(), - last_transition_time: String::new(), - }], - }); - - rewrite_user_facing_conditions( - &mut status, - Some(&SandboxSpec { - gpu: true, - ..Default::default() - }), - ); - - let message = 
&status.unwrap().conditions[0].message; - assert_eq!( - message, - "GPU sandbox could not be scheduled on the active gateway. Another GPU sandbox may already be using the available GPU, or the gateway may not currently be able to satisfy GPU placement. Please refer to documentation and use `openshell doctor` commands to inspect GPU support and gateway configuration." - ); - } - - #[test] - fn rewrite_user_facing_conditions_leaves_non_gpu_unschedulable_message_unchanged() { - let original = "0/1 nodes are available: 1 Insufficient cpu."; - let mut status = Some(SandboxStatus { - sandbox_name: "test".to_string(), - agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), - conditions: vec![SandboxCondition { - r#type: "Ready".to_string(), - status: "False".to_string(), - reason: "Unschedulable".to_string(), - message: original.to_string(), - last_transition_time: String::new(), - }], - }); - - rewrite_user_facing_conditions( - &mut status, - Some(&SandboxSpec { - gpu: false, - ..Default::default() - }), - ); - - assert_eq!(status.unwrap().conditions[0].message, original); - } - - #[test] - fn derive_phase_returns_ready_for_ready_true() { - let status = Some(SandboxStatus { - sandbox_name: "test".to_string(), - agent_pod: "test-pod".to_string(), - agent_fd: String::new(), - sandbox_fd: String::new(), - conditions: vec![SandboxCondition { - r#type: "Ready".to_string(), - status: "True".to_string(), - reason: "DependenciesReady".to_string(), - message: "Pod is Ready; Service Exists".to_string(), - last_transition_time: String::new(), - }], - }); - - assert_eq!(derive_phase(&status, false), SandboxPhase::Ready); - } - #[test] fn apply_required_env_always_injects_ssh_handshake_secret() { let mut env = Vec::new(); diff --git a/crates/openshell-driver-kubernetes/src/grpc.rs b/crates/openshell-driver-kubernetes/src/grpc.rs index 15589a9e3..a2457a218 100644 --- a/crates/openshell-driver-kubernetes/src/grpc.rs +++ 
b/crates/openshell-driver-kubernetes/src/grpc.rs @@ -2,12 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 use futures::{Stream, StreamExt}; -use openshell_core::proto::compute_driver_server::ComputeDriver; -use openshell_core::proto::{ - ComputeCreateSandboxRequest, ComputeCreateSandboxResponse, ComputeDeleteSandboxRequest, - ComputeDeleteSandboxResponse, GetCapabilitiesRequest, GetCapabilitiesResponse, - ResolveSandboxEndpointRequest, ResolveSandboxEndpointResponse, ValidateSandboxCreateRequest, - ValidateSandboxCreateResponse, WatchSandboxesEvent, WatchSandboxesRequest, +use openshell_core::proto::compute::v1::{ + CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, + GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, + ListSandboxesRequest, ListSandboxesResponse, ResolveSandboxEndpointRequest, + ResolveSandboxEndpointResponse, StopSandboxRequest, StopSandboxResponse, + ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesEvent, + WatchSandboxesRequest, compute_driver_server::ComputeDriver, }; use std::pin::Pin; use tonic::{Request, Response, Status}; @@ -51,10 +52,49 @@ impl ComputeDriver for ComputeDriverService { Ok(Response::new(ValidateSandboxCreateResponse {})) } + async fn get_sandbox( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + if request.sandbox_name.is_empty() { + return Err(Status::invalid_argument("sandbox_name is required")); + } + + let sandbox = self + .driver + .get_sandbox(&request.sandbox_name) + .await + .map_err(Status::internal)? 
+ .ok_or_else(|| Status::not_found("sandbox not found"))?; + + if !request.sandbox_id.is_empty() && request.sandbox_id != sandbox.id { + return Err(Status::failed_precondition( + "sandbox_id did not match the fetched sandbox", + )); + } + + Ok(Response::new(GetSandboxResponse { + sandbox: Some(sandbox), + })) + } + + async fn list_sandboxes( + &self, + _request: Request, + ) -> Result, Status> { + let sandboxes = self + .driver + .list_sandboxes() + .await + .map_err(Status::internal)?; + Ok(Response::new(ListSandboxesResponse { sandboxes })) + } + async fn create_sandbox( &self, - request: Request, - ) -> Result, Status> { + request: Request, + ) -> Result, Status> { let sandbox = request .into_inner() .sandbox @@ -63,20 +103,29 @@ impl ComputeDriver for ComputeDriverService { .create_sandbox(&sandbox) .await .map_err(|err| Status::internal(err.to_string()))?; - Ok(Response::new(ComputeCreateSandboxResponse {})) + Ok(Response::new(CreateSandboxResponse {})) + } + + async fn stop_sandbox( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented( + "stop sandbox is not implemented by the kubernetes compute driver", + )) } async fn delete_sandbox( &self, - request: Request, - ) -> Result, Status> { + request: Request, + ) -> Result, Status> { let request = request.into_inner(); let deleted = self .driver .delete_sandbox(&request.sandbox_name) .await .map_err(Status::internal)?; - Ok(Response::new(ComputeDeleteSandboxResponse { deleted })) + Ok(Response::new(DeleteSandboxResponse { deleted })) } async fn resolve_sandbox_endpoint( diff --git a/crates/openshell-driver-kubernetes/src/main.rs b/crates/openshell-driver-kubernetes/src/main.rs index 32160a6dd..76c567f59 100644 --- a/crates/openshell-driver-kubernetes/src/main.rs +++ b/crates/openshell-driver-kubernetes/src/main.rs @@ -8,7 +8,7 @@ use tracing::info; use tracing_subscriber::EnvFilter; use openshell_core::VERSION; -use openshell_core::proto::compute_driver_server::ComputeDriverServer; 
+use openshell_core::proto::compute::v1::compute_driver_server::ComputeDriverServer; use openshell_driver_kubernetes::{ ComputeDriverService, KubernetesComputeConfig, KubernetesComputeDriver, }; diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index da92862a5..465aca5e4 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -9,9 +9,14 @@ use crate::sandbox_index::SandboxIndex; use crate::sandbox_watch::SandboxWatchBus; use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; +use openshell_core::proto::compute::v1::{ + DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, + DriverSandboxTemplate, ResolveSandboxEndpointResponse, WatchSandboxesEvent, sandbox_endpoint, + watch_sandboxes_event, +}; use openshell_core::proto::{ - ResolveSandboxEndpointResponse, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, - SandboxStatus, SshSession, WatchSandboxesEvent, + PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, + SandboxTemplate, SshSession, }; use openshell_driver_kubernetes::{ KubernetesComputeConfig, KubernetesComputeDriver, KubernetesDriverError, @@ -56,12 +61,12 @@ pub enum ResolvedEndpoint { #[tonic::async_trait] pub trait ComputeBackend: fmt::Debug + Send + Sync { fn default_image(&self) -> &str; - async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status>; - async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), ComputeError>; + async fn validate_sandbox_create(&self, sandbox: &DriverSandbox) -> Result<(), Status>; + async fn create_sandbox(&self, sandbox: &DriverSandbox) -> Result<(), ComputeError>; async fn delete_sandbox(&self, sandbox_name: &str) -> Result; async fn resolve_sandbox_endpoint( &self, - sandbox: &Sandbox, + sandbox: &DriverSandbox, ) -> Result; async fn watch_sandboxes(&self) -> Result; } @@ -84,11 +89,11 @@ impl 
ComputeBackend for InProcessKubernetesBackend { self.driver.default_image() } - async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { + async fn validate_sandbox_create(&self, sandbox: &DriverSandbox) -> Result<(), Status> { self.driver.validate_sandbox_create(sandbox).await } - async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), ComputeError> { + async fn create_sandbox(&self, sandbox: &DriverSandbox) -> Result<(), ComputeError> { self.driver .create_sandbox(sandbox) .await @@ -104,7 +109,7 @@ impl ComputeBackend for InProcessKubernetesBackend { async fn resolve_sandbox_endpoint( &self, - sandbox: &Sandbox, + sandbox: &DriverSandbox, ) -> Result { let response = self .driver @@ -165,7 +170,8 @@ impl ComputeRuntime { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { - self.backend.validate_sandbox_create(sandbox).await + let driver_sandbox = driver_sandbox_from_public(sandbox); + self.backend.validate_sandbox_create(&driver_sandbox).await } pub async fn create_sandbox(&self, sandbox: Sandbox) -> Result { @@ -187,7 +193,8 @@ impl ComputeRuntime { .await .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; - match self.backend.create_sandbox(&sandbox).await { + let driver_sandbox = driver_sandbox_from_public(&sandbox); + match self.backend.create_sandbox(&driver_sandbox).await { Ok(()) => { self.sandbox_watch_bus.notify(&sandbox.id); Ok(sandbox) @@ -278,8 +285,9 @@ impl ComputeRuntime { &self, sandbox: &Sandbox, ) -> Result { + let driver_sandbox = driver_sandbox_from_public(sandbox); self.backend - .resolve_sandbox_endpoint(sandbox) + .resolve_sandbox_endpoint(&driver_sandbox) .await .map_err(|err| match err { ComputeError::Precondition(message) => Status::failed_precondition(message), @@ -329,24 +337,24 @@ impl ComputeRuntime { } async fn apply_watch_event(&self, event: WatchSandboxesEvent) -> Result<(), String> { - use 
openshell_core::proto::watch_sandboxes_event::Payload; - match event.payload { - Some(Payload::Sandbox(sandbox)) => { + Some(watch_sandboxes_event::Payload::Sandbox(sandbox)) => { if let Some(sandbox) = sandbox.sandbox { self.apply_sandbox_update(sandbox).await?; } } - Some(Payload::Deleted(deleted)) => { + Some(watch_sandboxes_event::Payload::Deleted(deleted)) => { self.apply_deleted(&deleted.sandbox_id).await?; } - Some(Payload::PlatformEvent(platform_event)) => { + Some(watch_sandboxes_event::Payload::PlatformEvent(platform_event)) => { if let Some(event) = platform_event.event { self.tracing_log_bus.platform_event_bus.publish( &platform_event.sandbox_id, openshell_core::proto::SandboxStreamEvent { payload: Some( - openshell_core::proto::sandbox_stream_event::Payload::Event(event), + openshell_core::proto::sandbox_stream_event::Payload::Event( + public_platform_event_from_driver(&event), + ), ), }, ); @@ -357,20 +365,20 @@ impl ComputeRuntime { Ok(()) } - async fn apply_sandbox_update(&self, incoming: Sandbox) -> Result<(), String> { + async fn apply_sandbox_update(&self, incoming: DriverSandbox) -> Result<(), String> { let existing = self .store .get_message::(&incoming.id) .await .map_err(|e| e.to_string())?; - let mut status = incoming.status.clone(); + let mut status = incoming.status.as_ref().map(public_status_from_driver); rewrite_user_facing_conditions( &mut status, existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), ); - let phase = SandboxPhase::try_from(incoming.phase).unwrap_or(SandboxPhase::Unknown); + let mut phase = derive_phase(incoming.status.as_ref()); let mut sandbox = existing.unwrap_or_else(|| Sandbox { id: incoming.id.clone(), name: incoming.name.clone(), @@ -382,6 +390,9 @@ impl ComputeRuntime { }); let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + if old_phase == SandboxPhase::Deleting && phase != SandboxPhase::Error { + phase = SandboxPhase::Deleting; + } if old_phase != phase { info!( 
sandbox_id = %incoming.id, @@ -398,7 +409,7 @@ impl ComputeRuntime { for condition in &status.conditions { if condition.r#type == "Ready" && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_condition(condition) + && is_terminal_failure_reason(&condition.reason) { warn!( sandbox_id = %incoming.id, @@ -444,6 +455,69 @@ impl ComputeRuntime { } } +fn driver_sandbox_from_public(sandbox: &Sandbox) -> DriverSandbox { + DriverSandbox { + id: sandbox.id.clone(), + name: sandbox.name.clone(), + namespace: sandbox.namespace.clone(), + spec: sandbox.spec.as_ref().map(driver_sandbox_spec_from_public), + status: sandbox + .status + .as_ref() + .map(|status| driver_status_from_public(status, sandbox.phase)), + } +} + +fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { + DriverSandboxSpec { + log_level: spec.log_level.clone(), + environment: spec.environment.clone(), + template: spec + .template + .as_ref() + .map(driver_sandbox_template_from_public), + gpu: spec.gpu, + } +} + +fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { + DriverSandboxTemplate { + image: template.image.clone(), + runtime_class_name: template.runtime_class_name.clone(), + agent_socket: template.agent_socket.clone(), + labels: template.labels.clone(), + annotations: template.annotations.clone(), + environment: template.environment.clone(), + resources: template.resources.clone(), + volume_claim_templates: template.volume_claim_templates.clone(), + } +} + +fn driver_status_from_public(status: &SandboxStatus, phase: i32) -> DriverSandboxStatus { + DriverSandboxStatus { + sandbox_name: status.sandbox_name.clone(), + agent_pod: status.agent_pod.clone(), + agent_fd: status.agent_fd.clone(), + sandbox_fd: status.sandbox_fd.clone(), + conditions: status + .conditions + .iter() + .map(driver_condition_from_public) + .collect(), + deleting: SandboxPhase::try_from(phase) == Ok(SandboxPhase::Deleting), + } +} + +fn 
driver_condition_from_public(condition: &SandboxCondition) -> DriverCondition { + DriverCondition { + r#type: condition.r#type.clone(), + status: condition.status.clone(), + reason: condition.reason.clone(), + message: condition.message.clone(), + last_transition_time: condition.last_transition_time.clone(), + } +} + impl ObjectType for Sandbox { fn object_type() -> &'static str { "sandbox" @@ -473,11 +547,11 @@ fn resolved_endpoint_from_response( .map_err(|_| ComputeError::Message("compute driver returned invalid port".to_string()))?; match endpoint.target.as_ref() { - Some(openshell_core::proto::sandbox_endpoint::Target::Ip(ip)) => ip + Some(sandbox_endpoint::Target::Ip(ip)) => ip .parse() .map(|ip| ResolvedEndpoint::Ip(ip, port)) .map_err(|e| ComputeError::Message(format!("invalid endpoint IP: {e}"))), - Some(openshell_core::proto::sandbox_endpoint::Target::Host(host)) => { + Some(sandbox_endpoint::Target::Host(host)) => { Ok(ResolvedEndpoint::Host(host.clone(), port)) } None => Err(ComputeError::Message( @@ -486,6 +560,68 @@ fn resolved_endpoint_from_response( } } +fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { + SandboxStatus { + sandbox_name: status.sandbox_name.clone(), + agent_pod: status.agent_pod.clone(), + agent_fd: status.agent_fd.clone(), + sandbox_fd: status.sandbox_fd.clone(), + conditions: status + .conditions + .iter() + .map(public_condition_from_driver) + .collect(), + } +} + +fn public_condition_from_driver(condition: &DriverCondition) -> SandboxCondition { + SandboxCondition { + r#type: condition.r#type.clone(), + status: condition.status.clone(), + reason: condition.reason.clone(), + message: condition.message.clone(), + last_transition_time: condition.last_transition_time.clone(), + } +} + +fn public_platform_event_from_driver(event: &DriverPlatformEvent) -> PlatformEvent { + PlatformEvent { + timestamp_ms: event.timestamp_ms, + source: event.source.clone(), + r#type: event.r#type.clone(), + reason: 
event.reason.clone(), + message: event.message.clone(), + metadata: event.metadata.clone(), + } +} + +fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { + if let Some(status) = status { + if status.deleting { + return SandboxPhase::Deleting; + } + + for condition in &status.conditions { + if condition.r#type == "Ready" { + return if condition.status.eq_ignore_ascii_case("true") { + SandboxPhase::Ready + } else if condition.status.eq_ignore_ascii_case("false") { + if is_terminal_failure_reason(&condition.reason) { + SandboxPhase::Error + } else { + SandboxPhase::Provisioning + } + } else { + SandboxPhase::Provisioning + }; + } + } + return SandboxPhase::Provisioning; + } + + SandboxPhase::Unknown +} + fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); if !gpu_requested { @@ -504,8 +640,200 @@ fn rewrite_user_facing_conditions(status: &mut Option, spec: Opti } } -fn is_terminal_failure_condition(condition: &SandboxCondition) -> bool { - let reason = condition.reason.to_ascii_lowercase(); +fn is_terminal_failure_reason(reason: &str) -> bool { + let reason = reason.to_ascii_lowercase(); let transient_reasons = ["reconcilererror", "dependenciesnotready"]; !transient_reasons.contains(&reason.as_str()) } + +#[cfg(test)] +mod tests { + use super::*; + + fn make_driver_condition(reason: &str, message: &str) -> DriverCondition { + DriverCondition { + r#type: "Ready".to_string(), + status: "False".to_string(), + reason: reason.to_string(), + message: message.to_string(), + last_transition_time: String::new(), + } + } + + fn make_driver_status(condition: DriverCondition) -> DriverSandboxStatus { + DriverSandboxStatus { + sandbox_name: "test".to_string(), + agent_pod: "test-pod".to_string(), + agent_fd: String::new(), + sandbox_fd: String::new(), + conditions: vec![condition], + deleting: false, + } + } + + #[test] + fn 
terminal_failure_treats_unknown_reasons_as_terminal() { + let terminal_cases = [ + ("Failed", "Something went wrong"), + ("CrashLoopBackOff", "Container keeps crashing"), + ("ImagePullBackOff", "Failed to pull image"), + ("ErrImagePull", "Error pulling image"), + ("Unschedulable", "No nodes match"), + ("SomeOtherReason", "Any other reason is terminal"), + ]; + + for (reason, message) in terminal_cases { + assert!( + is_terminal_failure_reason(reason), + "Expected terminal failure for reason={reason}, message={message}" + ); + } + } + + #[test] + fn terminal_failure_ignores_transient_reasons() { + let transient_cases = [ + ( + "ReconcilerError", + "Error seen: failed to update pod: Operation cannot be fulfilled", + ), + ("reconcilererror", "lowercase also works"), + ("RECONCILERERROR", "uppercase also works"), + ( + "DependenciesNotReady", + "Pod exists with phase: Pending; Service Exists", + ), + ("dependenciesnotready", "lowercase also works"), + ]; + + for (reason, message) in transient_cases { + assert!( + !is_terminal_failure_reason(reason), + "Expected transient (non-terminal) for reason={reason}, message={message}" + ); + } + } + + #[test] + fn derive_phase_returns_unknown_without_status() { + assert_eq!(derive_phase(None), SandboxPhase::Unknown); + } + + #[test] + fn derive_phase_returns_deleting_when_driver_marks_deleting() { + let status = DriverSandboxStatus { + deleting: true, + ..make_driver_status(make_driver_condition( + "DependenciesNotReady", + "Pod still pending", + )) + }; + + assert_eq!(derive_phase(Some(&status)), SandboxPhase::Deleting); + } + + #[test] + fn derive_phase_returns_provisioning_for_transient_conditions() { + let transient_conditions = [ + ("ReconcilerError", "Error seen: failed to update pod"), + ( + "DependenciesNotReady", + "Pod exists with phase: Pending; Service Exists", + ), + ]; + + for (reason, message) in transient_conditions { + let status = make_driver_status(make_driver_condition(reason, message)); + assert_eq!( + 
derive_phase(Some(&status)), + SandboxPhase::Provisioning, + "Expected Provisioning for transient reason={reason}" + ); + } + } + + #[test] + fn derive_phase_returns_error_for_terminal_ready_false() { + let status = make_driver_status(make_driver_condition( + "ImagePullBackOff", + "Failed to pull image", + )); + + assert_eq!(derive_phase(Some(&status)), SandboxPhase::Error); + } + + #[test] + fn derive_phase_returns_ready_for_ready_true() { + let status = DriverSandboxStatus { + conditions: vec![DriverCondition { + r#type: "Ready".to_string(), + status: "True".to_string(), + reason: "DependenciesReady".to_string(), + message: "Pod is Ready; Service Exists".to_string(), + last_transition_time: String::new(), + }], + ..make_driver_status(make_driver_condition("", "")) + }; + + assert_eq!(derive_phase(Some(&status)), SandboxPhase::Ready); + } + + #[test] + fn rewrite_user_facing_conditions_rewrites_gpu_unschedulable_message() { + let mut status = Some(SandboxStatus { + sandbox_name: "test".to_string(), + agent_pod: "test-pod".to_string(), + agent_fd: String::new(), + sandbox_fd: String::new(), + conditions: vec![SandboxCondition { + r#type: "Ready".to_string(), + status: "False".to_string(), + reason: "Unschedulable".to_string(), + message: "0/1 nodes are available: 1 Insufficient nvidia.com/gpu.".to_string(), + last_transition_time: String::new(), + }], + }); + + rewrite_user_facing_conditions( + &mut status, + Some(&SandboxSpec { + gpu: true, + ..Default::default() + }), + ); + + let message = &status.unwrap().conditions[0].message; + assert_eq!( + message, + "GPU sandbox could not be scheduled on the active gateway. Another GPU sandbox may already be using the available GPU, or the gateway may not currently be able to satisfy GPU placement. Please refer to documentation and use `openshell doctor` commands to inspect GPU support and gateway configuration." 
+ ); + } + + #[test] + fn rewrite_user_facing_conditions_leaves_non_gpu_unschedulable_message_unchanged() { + let original = "0/1 nodes are available: 1 Insufficient cpu."; + let mut status = Some(SandboxStatus { + sandbox_name: "test".to_string(), + agent_pod: "test-pod".to_string(), + agent_fd: String::new(), + sandbox_fd: String::new(), + conditions: vec![SandboxCondition { + r#type: "Ready".to_string(), + status: "False".to_string(), + reason: "Unschedulable".to_string(), + message: original.to_string(), + last_transition_time: String::new(), + }], + }); + + rewrite_user_facing_conditions( + &mut status, + Some(&SandboxSpec { + gpu: false, + ..Default::default() + }), + ); + + assert_eq!(status.unwrap().conditions[0].message, original); + } +} diff --git a/deploy/docker/Dockerfile.images b/deploy/docker/Dockerfile.images index 060c9b738..b7e854677 100644 --- a/deploy/docker/Dockerfile.images +++ b/deploy/docker/Dockerfile.images @@ -47,6 +47,7 @@ COPY Cargo.toml Cargo.lock ./ COPY crates/openshell-bootstrap/Cargo.toml crates/openshell-bootstrap/Cargo.toml COPY crates/openshell-cli/Cargo.toml crates/openshell-cli/Cargo.toml COPY crates/openshell-core/Cargo.toml crates/openshell-core/Cargo.toml +COPY crates/openshell-driver-kubernetes/Cargo.toml crates/openshell-driver-kubernetes/Cargo.toml COPY crates/openshell-ocsf/Cargo.toml crates/openshell-ocsf/Cargo.toml COPY crates/openshell-policy/Cargo.toml crates/openshell-policy/Cargo.toml COPY crates/openshell-providers/Cargo.toml crates/openshell-providers/Cargo.toml @@ -63,6 +64,7 @@ RUN mkdir -p \ crates/openshell-bootstrap/src \ crates/openshell-cli/src \ crates/openshell-core/src \ + crates/openshell-driver-kubernetes/src \ crates/openshell-ocsf/src \ crates/openshell-policy/src \ crates/openshell-providers/src \ @@ -75,6 +77,8 @@ RUN mkdir -p \ touch crates/openshell-bootstrap/src/lib.rs && \ printf 'fn main() {}\n' > crates/openshell-cli/src/main.rs && \ touch crates/openshell-core/src/lib.rs && \ + touch 
crates/openshell-driver-kubernetes/src/lib.rs && \ + printf 'fn main() {}\n' > crates/openshell-driver-kubernetes/src/main.rs && \ touch crates/openshell-ocsf/src/lib.rs && \ touch crates/openshell-policy/src/lib.rs && \ touch crates/openshell-providers/src/lib.rs && \ @@ -109,6 +113,7 @@ FROM rust-deps AS gateway-workspace ARG OPENSHELL_CARGO_VERSION COPY crates/openshell-core/ crates/openshell-core/ +COPY crates/openshell-driver-kubernetes/ crates/openshell-driver-kubernetes/ COPY crates/openshell-policy/ crates/openshell-policy/ COPY crates/openshell-providers/ crates/openshell-providers/ COPY crates/openshell-router/ crates/openshell-router/ diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 9005ca23b..8a6f32e9f 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -5,9 +5,16 @@ syntax = "proto3"; package openshell.compute.v1; -import "sandbox.proto"; +import "google/protobuf/struct.proto"; // Internal compute-driver contract used by the gateway. +// +// Conventions: +// - This file owns driver-native request, response, and observation types. +// - Compute drivers must not import or return the public `openshell.v1.Sandbox` +// resource model. +// - The gateway translates between these internal driver-native messages and +// the public OpenShell API resource model. service ComputeDriver { // Report driver capabilities and defaults. rpc GetCapabilities(GetCapabilitiesRequest) returns (GetCapabilitiesResponse); @@ -16,11 +23,20 @@ service ComputeDriver { rpc ValidateSandboxCreate(ValidateSandboxCreateRequest) returns (ValidateSandboxCreateResponse); + // Fetch the platform-observed sandbox state for one sandbox. + rpc GetSandbox(GetSandboxRequest) returns (GetSandboxResponse); + + // List platform-observed sandbox state for all sandboxes. + rpc ListSandboxes(ListSandboxesRequest) returns (ListSandboxesResponse); + // Provision platform resources for a sandbox. 
- rpc CreateSandbox(ComputeCreateSandboxRequest) returns (ComputeCreateSandboxResponse); + rpc CreateSandbox(CreateSandboxRequest) returns (CreateSandboxResponse); + + // Stop platform resources for a sandbox without deleting its record. + rpc StopSandbox(StopSandboxRequest) returns (StopSandboxResponse); // Tear down platform resources for a sandbox. - rpc DeleteSandbox(ComputeDeleteSandboxRequest) returns (ComputeDeleteSandboxResponse); + rpc DeleteSandbox(DeleteSandboxRequest) returns (DeleteSandboxResponse); // Resolve the current endpoint for sandbox exec/SSH transport. rpc ResolveSandboxEndpoint(ResolveSandboxEndpointRequest) @@ -33,68 +49,217 @@ service ComputeDriver { message GetCapabilitiesRequest {} message GetCapabilitiesResponse { + // Human-readable driver name. string driver_name = 1; + // Driver implementation version string. string driver_version = 2; + // Default sandbox image recommended by the driver. string default_image = 3; + // True when the driver can provision GPU-backed sandboxes. bool supports_gpu = 4; } +// Driver-owned sandbox model used for create requests and platform observations. +// +// This intentionally omits gateway-owned lifecycle fields such as the public +// `openshell.v1.SandboxPhase` and persisted metadata. The gateway derives and +// stores those fields after translating driver observations. +message DriverSandbox { + // Stable sandbox ID assigned by the gateway. + string id = 1; + // Compute-runtime sandbox name. + string name = 2; + // Compute-platform namespace or equivalent tenancy boundary. + string namespace = 3; + // Provisioning input supplied by the gateway. Drivers may omit this in + // observed snapshots returned by Get/List/Watch. + DriverSandboxSpec spec = 4; + // Raw platform-observed status. + DriverSandboxStatus status = 5; +} + +// Driver-owned provisioning inputs required to create a sandbox. +message DriverSandboxSpec { + // Log level exposed to processes running inside the sandbox. 
+ string log_level = 1; + // Environment variables injected into the sandbox runtime. + map environment = 5; + // Runtime template consumed by the driver during provisioning. + DriverSandboxTemplate template = 6; + // Request NVIDIA GPU resources for this sandbox. + bool gpu = 9; +} + +// Driver-owned runtime template consumed by the compute platform. +message DriverSandboxTemplate { + // Fully-qualified OCI image reference used to boot the sandbox. + string image = 1; + // Optional runtime class name requested from the compute platform. + string runtime_class_name = 2; + // Optional agent socket path exposed to the workload. + string agent_socket = 3; + // Labels applied to compute-platform resources. + map labels = 4; + // Annotations applied to compute-platform resources. + map annotations = 5; + // Additional environment variables injected by the template. + map environment = 6; + // Platform-specific compute resource requirements and limits. + google.protobuf.Struct resources = 7; + // Optional platform-specific volume claim templates. + google.protobuf.Struct volume_claim_templates = 9; +} + +// Raw status observed directly from the compute platform. +// +// The gateway derives the public `openshell.v1.SandboxPhase` from these +// conditions plus `deleting`. +message DriverSandboxStatus { + // Compute-platform sandbox object name. + string sandbox_name = 1; + // Name of the agent pod or equivalent runtime instance. + string agent_pod = 2; + // File descriptor or endpoint for reaching the agent service, when available. + string agent_fd = 3; + // File descriptor or endpoint for reaching the sandbox service, when available. + string sandbox_fd = 4; + // Raw readiness and lifecycle conditions reported by the platform. + repeated DriverCondition conditions = 5; + // True when the compute platform has begun deleting this sandbox. + bool deleting = 6; +} + +// Raw compute-platform condition. 
+message DriverCondition { + // Condition class reported by the compute platform. + string type = 1; + // Condition status value such as `True`, `False`, or `Unknown`. + string status = 2; + // Short machine-readable reason associated with the condition. + string reason = 3; + // Human-readable condition message. + string message = 4; + // Timestamp reported by the platform for the last transition. + string last_transition_time = 5; +} + +// Raw compute-platform event correlated to a sandbox. +message DriverPlatformEvent { + // Event timestamp in milliseconds since epoch. + int64 timestamp_ms = 1; + // Event source (for example `kubernetes`). + string source = 2; + // Event type or severity (for example `Normal` or `Warning`). + string type = 3; + // Short machine-readable reason code. + string reason = 4; + // Human-readable event message. + string message = 5; + // Optional platform-specific metadata attached to the event. + map metadata = 6; +} + message ValidateSandboxCreateRequest { - openshell.sandbox.v1.Sandbox sandbox = 1; + // Proposed sandbox configuration to validate before provisioning. + DriverSandbox sandbox = 1; } message ValidateSandboxCreateResponse {} -message ComputeCreateSandboxRequest { - openshell.sandbox.v1.Sandbox sandbox = 1; +message GetSandboxRequest { + // Stable sandbox ID stored by the gateway. + string sandbox_id = 1; + // Compute-runtime name used by the driver. + string sandbox_name = 2; +} + +message GetSandboxResponse { + // Platform-observed sandbox snapshot returned by the driver. + DriverSandbox sandbox = 1; +} + +message ListSandboxesRequest {} + +message ListSandboxesResponse { + // Platform-observed sandbox snapshots returned by the driver. + repeated DriverSandbox sandboxes = 1; +} + +message CreateSandboxRequest { + // Sandbox configuration to provision on the compute platform. + DriverSandbox sandbox = 1; +} + +message CreateSandboxResponse {} + +message StopSandboxRequest { + // Stable sandbox ID stored by the gateway. 
+ string sandbox_id = 1; + // Compute-runtime name used by the driver. + string sandbox_name = 2; } -message ComputeCreateSandboxResponse {} +message StopSandboxResponse {} -message ComputeDeleteSandboxRequest { +message DeleteSandboxRequest { + // Stable sandbox ID stored by the gateway. string sandbox_id = 1; + // Compute-runtime name used by the driver. string sandbox_name = 2; } -message ComputeDeleteSandboxResponse { +message DeleteSandboxResponse { + // True when a platform resource was deleted by this request. bool deleted = 1; } message ResolveSandboxEndpointRequest { - openshell.sandbox.v1.Sandbox sandbox = 1; + // Sandbox to resolve for exec or SSH connectivity. + DriverSandbox sandbox = 1; } message SandboxEndpoint { oneof target { + // Direct IP address for the sandbox endpoint. string ip = 1; + // DNS host name for the sandbox endpoint. string host = 2; } + // TCP port for the sandbox endpoint. uint32 port = 3; } message ResolveSandboxEndpointResponse { + // Current endpoint the gateway should use to reach the sandbox. SandboxEndpoint endpoint = 1; } message WatchSandboxesRequest {} message WatchSandboxesSandboxEvent { - openshell.sandbox.v1.Sandbox sandbox = 1; + // Updated driver-native snapshot for one sandbox. + DriverSandbox sandbox = 1; } message WatchSandboxesDeletedEvent { + // Sandbox ID removed from the compute platform. string sandbox_id = 1; } message WatchSandboxesPlatformEvent { + // Sandbox ID correlated to the platform event. string sandbox_id = 1; - openshell.sandbox.v1.PlatformEvent event = 2; + // Raw platform event emitted for the sandbox. + DriverPlatformEvent event = 2; } message WatchSandboxesEvent { oneof payload { + // Updated or newly observed sandbox snapshot. WatchSandboxesSandboxEvent sandbox = 1; + // Sandbox deletion observation. WatchSandboxesDeletedEvent deleted = 2; + // Raw platform event correlated to a sandbox. 
WatchSandboxesPlatformEvent platform_event = 3; } } diff --git a/proto/openshell.proto b/proto/openshell.proto index 43d903b6d..0ee1e8904 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -6,9 +6,17 @@ syntax = "proto3"; package openshell.v1; import "datamodel.proto"; +import "google/protobuf/struct.proto"; import "sandbox.proto"; // OpenShell service provides sandbox, provider, and runtime management capabilities. +// +// Conventions: +// - This file owns the public API resource model exposed to OpenShell clients. +// - `Sandbox`, `SandboxSpec`, `SandboxStatus`, and `SandboxPhase` are gateway-owned +// public types. Internal compute drivers must not import or return them directly. +// - The gateway translates internal compute-driver observations into these public +// resource messages before persisting or returning them to clients. service OpenShell { // Check the health of the service. rpc Health(HealthRequest) returns (HealthResponse); @@ -140,9 +148,129 @@ message HealthResponse { string version = 2; } +// Public sandbox resource exposed by the OpenShell API. +// +// This is the canonical gateway-owned view of a sandbox. It merges user intent +// (`spec`) with gateway-managed metadata and status derived from internal +// compute-driver observations. +message Sandbox { + // Stable sandbox ID generated by the gateway. + string id = 1; + // User-visible sandbox name. + string name = 2; + // Namespace used by the backing compute platform. + string namespace = 3; + // Desired sandbox configuration submitted through the API. + SandboxSpec spec = 4; + // Latest user-facing observed status derived by the gateway. + SandboxStatus status = 5; + // Gateway-derived lifecycle summary. + SandboxPhase phase = 6; + // Milliseconds since Unix epoch when the sandbox was created. + int64 created_at_ms = 7; + // Currently active policy version (updated when sandbox reports loaded). 
+ uint32 current_policy_version = 8; +} + +// Desired sandbox configuration provided through the public API. +message SandboxSpec { + // Log level exposed to processes running inside the sandbox. + string log_level = 1; + // Environment variables injected into the sandbox runtime. + map environment = 5; + // Container or VM template used to provision the sandbox. + SandboxTemplate template = 6; + // Required sandbox policy configuration. + openshell.sandbox.v1.SandboxPolicy policy = 7; + // Provider names to attach to this sandbox. + repeated string providers = 8; + // Request NVIDIA GPU resources for this sandbox. + bool gpu = 9; +} + +// Public sandbox template mapped onto compute-driver template inputs. +message SandboxTemplate { + // Fully-qualified OCI image reference used to boot the sandbox. + string image = 1; + // Optional runtime class name requested from the compute platform. + string runtime_class_name = 2; + // Optional agent socket path exposed to the workload. + string agent_socket = 3; + // Labels applied to compute-platform resources for this sandbox. + map labels = 4; + // Annotations applied to compute-platform resources for this sandbox. + map annotations = 5; + // Additional environment variables injected by the template. + map environment = 6; + // Platform-specific compute resource requirements and limits. + google.protobuf.Struct resources = 7; + // Optional platform-specific volume claim templates. + google.protobuf.Struct volume_claim_templates = 9; +} + +// User-facing sandbox status derived by the gateway from compute-driver observations. +// +// Lifecycle summary is exposed separately as `Sandbox.phase`. Public status does +// not embed driver-only flags such as `deleting`. +message SandboxStatus { + // Compute-platform sandbox object name. + string sandbox_name = 1; + // Name of the agent pod or equivalent runtime instance. + string agent_pod = 2; + // File descriptor or endpoint for reaching the agent service, when available. 
+ string agent_fd = 3; + // File descriptor or endpoint for reaching the sandbox service, when available. + string sandbox_fd = 4; + // Latest user-facing readiness and lifecycle conditions. + repeated SandboxCondition conditions = 5; +} + +// User-facing sandbox condition derived from driver-native conditions. +message SandboxCondition { + // Condition class, typically mirroring the underlying platform condition type. + string type = 1; + // Condition status value such as `True`, `False`, or `Unknown`. + string status = 2; + // Short machine-readable reason associated with the condition. + string reason = 3; + // Human-readable condition message. + string message = 4; + // Timestamp reported by the underlying platform for the last transition. + string last_transition_time = 5; +} + +// High-level sandbox lifecycle phase derived by the gateway. +// +// Clients should rely on this normalized lifecycle summary for readiness and +// deletion decisions instead of interpreting raw conditions. +enum SandboxPhase { + SANDBOX_PHASE_UNSPECIFIED = 0; + SANDBOX_PHASE_PROVISIONING = 1; + SANDBOX_PHASE_READY = 2; + SANDBOX_PHASE_ERROR = 3; + SANDBOX_PHASE_DELETING = 4; + SANDBOX_PHASE_UNKNOWN = 5; +} + +// Public platform event exposed on the sandbox watch stream. +message PlatformEvent { + // Event timestamp in milliseconds since epoch. + int64 timestamp_ms = 1; + // Event source (e.g. "kubernetes", "docker", "process"). + string source = 2; + // Event type/severity (e.g. "Normal", "Warning"). + string type = 3; + // Short reason code (e.g. "Started", "Pulled", "Failed"). + string reason = 4; + // Human-readable event message. + string message = 5; + // Optional metadata as key-value pairs. + map metadata = 6; +} + // Create sandbox request. message CreateSandboxRequest { - openshell.sandbox.v1.SandboxSpec spec = 1; + SandboxSpec spec = 1; // Optional user-supplied sandbox name. When empty the server generates one. 
string name = 2; } @@ -167,12 +295,12 @@ message DeleteSandboxRequest { // Sandbox response. message SandboxResponse { - openshell.sandbox.v1.Sandbox sandbox = 1; + Sandbox sandbox = 1; } // List sandboxes response. message ListSandboxesResponse { - repeated openshell.sandbox.v1.Sandbox sandboxes = 1; + repeated Sandbox sandboxes = 1; } // Delete sandbox response. @@ -336,11 +464,11 @@ message WatchSandboxRequest { message SandboxStreamEvent { oneof payload { // Latest sandbox snapshot. - openshell.sandbox.v1.Sandbox sandbox = 1; + Sandbox sandbox = 1; // One server log line/event. SandboxLogLine log = 2; // One platform event. - openshell.sandbox.v1.PlatformEvent event = 3; + PlatformEvent event = 3; // Warning from the server (e.g. missed messages due to lag). SandboxStreamWarning warning = 4; // Draft policy update notification. diff --git a/proto/sandbox.proto b/proto/sandbox.proto index f810f1e0a..c350d8d85 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -5,90 +5,13 @@ syntax = "proto3"; package openshell.sandbox.v1; -import "google/protobuf/struct.proto"; - -// Sandbox model stored by OpenShell. -message Sandbox { - string id = 1; - string name = 2; - string namespace = 3; - SandboxSpec spec = 4; - SandboxStatus status = 5; - SandboxPhase phase = 6; - // Milliseconds since Unix epoch when the sandbox was created. - int64 created_at_ms = 7; - // Currently active policy version (updated when sandbox reports loaded). - uint32 current_policy_version = 8; -} - -// OpenShell-level sandbox spec. -message SandboxSpec { - string log_level = 1; - map environment = 5; - SandboxTemplate template = 6; - // Required sandbox policy configuration. - SandboxPolicy policy = 7; - // Provider names to attach to this sandbox. - repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; -} - -// Sandbox template mapped onto compute-driver template inputs. 
-message SandboxTemplate { - string image = 1; - string runtime_class_name = 2; - string agent_socket = 3; - map labels = 4; - map annotations = 5; - map environment = 6; - google.protobuf.Struct resources = 7; - google.protobuf.Struct volume_claim_templates = 9; -} - -// Sandbox status captured from the compute platform. -message SandboxStatus { - string sandbox_name = 1; - string agent_pod = 2; - string agent_fd = 3; - string sandbox_fd = 4; - repeated SandboxCondition conditions = 5; -} - -// Sandbox condition mirrors the compute platform condition model. -message SandboxCondition { - string type = 1; - string status = 2; - string reason = 3; - string message = 4; - string last_transition_time = 5; -} - -// High-level sandbox lifecycle phase. -enum SandboxPhase { - SANDBOX_PHASE_UNSPECIFIED = 0; - SANDBOX_PHASE_PROVISIONING = 1; - SANDBOX_PHASE_READY = 2; - SANDBOX_PHASE_ERROR = 3; - SANDBOX_PHASE_DELETING = 4; - SANDBOX_PHASE_UNKNOWN = 5; -} - -// Platform event correlated to a sandbox. -message PlatformEvent { - // Event timestamp in milliseconds since epoch. - int64 timestamp_ms = 1; - // Event source (e.g. "kubernetes", "docker", "process"). - string source = 2; - // Event type/severity (e.g. "Normal", "Warning"). - string type = 3; - // Short reason code (e.g. "Started", "Pulled", "Failed"). - string reason = 4; - // Human-readable event message. - string message = 5; - // Optional metadata as key-value pairs. - map metadata = 6; -} +// Sandbox-supervisor configuration and policy messages. +// +// Conventions: +// - This file owns messages exchanged between the gateway and the sandbox +// supervisor/runtime. +// - Public sandbox resource types live in `openshell.proto`. +// - Internal compute-driver sandbox observation types live in `compute_driver.proto`. // Sandbox security policy configuration. 
message SandboxPolicy { diff --git a/tasks/scripts/cluster-deploy-fast.sh b/tasks/scripts/cluster-deploy-fast.sh index 307e76233..c38259288 100755 --- a/tasks/scripts/cluster-deploy-fast.sh +++ b/tasks/scripts/cluster-deploy-fast.sh @@ -152,7 +152,7 @@ matches_gateway() { deploy/docker/Dockerfile.images|tasks/scripts/docker-build-image.sh) return 0 ;; - crates/openshell-core/*|crates/openshell-policy/*|crates/openshell-providers/*) + crates/openshell-core/*|crates/openshell-driver-kubernetes/*|crates/openshell-policy/*|crates/openshell-providers/*) return 0 ;; crates/openshell-router/*|crates/openshell-server/*) @@ -209,7 +209,7 @@ compute_fingerprint() { local committed_trees="" case "${component}" in gateway) - committed_trees=$(git ls-tree HEAD Cargo.toml Cargo.lock proto/ deploy/docker/cross-build.sh deploy/docker/Dockerfile.images tasks/scripts/docker-build-image.sh crates/openshell-core/ crates/openshell-policy/ crates/openshell-providers/ crates/openshell-router/ crates/openshell-server/ 2>/dev/null || true) + committed_trees=$(git ls-tree HEAD Cargo.toml Cargo.lock proto/ deploy/docker/cross-build.sh deploy/docker/Dockerfile.images tasks/scripts/docker-build-image.sh crates/openshell-core/ crates/openshell-driver-kubernetes/ crates/openshell-policy/ crates/openshell-providers/ crates/openshell-router/ crates/openshell-server/ 2>/dev/null || true) ;; supervisor) committed_trees=$(git ls-tree HEAD Cargo.toml Cargo.lock proto/ deploy/docker/cross-build.sh deploy/docker/Dockerfile.images tasks/scripts/docker-build-image.sh crates/openshell-core/ crates/openshell-policy/ crates/openshell-router/ crates/openshell-sandbox/ 2>/dev/null || true) From 4d875d9b016da3d22ab64a6f5fc24bb0f599353a Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Mon, 13 Apr 2026 23:44:38 -0700 Subject: [PATCH 3/7] wip --- architecture/gateway.md | 6 +- .../openshell-driver-kubernetes/src/driver.rs | 156 ++++--- .../openshell-driver-kubernetes/src/grpc.rs | 27 +- 
crates/openshell-server/src/compute/mod.rs | 403 +++++++++++++++++- proto/compute_driver.proto | 57 ++- 5 files changed, 562 insertions(+), 87 deletions(-) diff --git a/architecture/gateway.md b/architecture/gateway.md index b783e7c03..cc43374a4 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -149,7 +149,7 @@ pub struct ServerState { ``` - **`store`** -- persistence backend (SQLite or Postgres) for all object types. -- **`compute`** -- gateway-owned compute orchestration. Persists sandbox lifecycle transitions, validates create requests through the compute backend, resolves exec/SSH endpoints, and consumes the backend watch stream. +- **`compute`** -- gateway-owned compute orchestration. Persists sandbox lifecycle transitions, validates create requests through the compute backend, resolves exec/SSH endpoints, consumes the backend watch stream, and periodically reconciles orphaned `Provisioning` records that no longer have a backing compute resource. - **`sandbox_index`** -- in-memory bidirectional index mapping sandbox names and agent pod names to sandbox IDs. Updated from compute-driver sandbox snapshots. - **`sandbox_watch_bus`** -- `broadcast`-based notification bus keyed by sandbox ID. Producers call `notify(&id)` when the persisted sandbox record changes; consumers in `WatchSandbox` streams receive `()` signals and re-read the record. - **`tracing_log_bus`** -- captures `tracing` events that include a `sandbox_id` field and republishes them as `SandboxLogLine` messages. Maintains a per-sandbox tail buffer (default 200 entries). Also contains a nested `PlatformEventBus` for compute-driver platform events. @@ -381,7 +381,7 @@ All buses use `tokio::sync::broadcast` channels keyed by sandbox ID. Buffer size Broadcast lag is translated to `Status::resource_exhausted` via `broadcast_to_status()`. 
-**Cleanup:** Each bus exposes a `remove(sandbox_id)` method that drops the broadcast sender (closing active receivers with `RecvError::Closed`) and frees internal map entries. Cleanup is wired into both the `handle_deleted` reconciler (Kubernetes watcher) and the `delete_sandbox` gRPC handler to prevent unbounded memory growth from accumulated entries for deleted sandboxes. +**Cleanup:** Each bus exposes a `remove(sandbox_id)` method that drops the broadcast sender (closing active receivers with `RecvError::Closed`) and frees internal map entries. Cleanup is wired into the compute watch reconciler, the periodic orphan sweep for stale `Provisioning` records, and the `delete_sandbox` gRPC handler to prevent unbounded memory growth from accumulated entries for deleted sandboxes. **Validation:** `WatchSandbox` validates that the sandbox exists before subscribing to any bus, preventing entries from being created for non-existent IDs. `PushSandboxLogs` validates sandbox existence once on the first batch of the stream. @@ -393,7 +393,7 @@ The `ExecSandbox` RPC (`crates/openshell-server/src/grpc.rs`) executes a command 1. Validate request: `sandbox_id`, `command`, and environment key format (`^[A-Za-z_][A-Za-z0-9_]*$`). 2. Verify sandbox exists and is in `Ready` phase. -3. Resolve target: prefer agent pod IP (via `sandbox_client.agent_pod_ip()`), fall back to Kubernetes service DNS (`..svc.cluster.local`). +3. Resolve target: prefer agent pod IP, fall back to Kubernetes service DNS (`..svc.cluster.local`). If the sandbox is not connectable yet (for example the pod exists but has no IP), the gateway returns `FAILED_PRECONDITION` instead of surfacing the condition as an internal server fault. 4. Build the remote command string: sort environment variables, shell-escape all values, prepend `cd &&` if `workdir` is set. 5. 
**Start a single-use SSH proxy**: binds an ephemeral local TCP port, accepts one connection, performs the NSSH1 handshake with the sandbox, and bidirectionally copies data. 6. **Connect via `russh`**: establishes an SSH connection through the local proxy, authenticates with `none` auth as user `sandbox`, opens a session channel, and executes the command. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 3bc99e520..ef81ab07e 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -393,14 +393,28 @@ impl KubernetesComputeDriver { } } } + + pub async fn sandbox_exists(&self, name: &str) -> Result { + let api = self.api(); + match tokio::time::timeout(KUBE_API_TIMEOUT, api.get(name)).await { + Ok(Ok(_)) => Ok(true), + Ok(Err(KubeError::Api(err))) if err.code == 404 => Ok(false), + Ok(Err(err)) => Err(err.to_string()), + Err(_elapsed) => Err(format!( + "timed out after {}s waiting for Kubernetes API", + KUBE_API_TIMEOUT.as_secs() + )), + } + } + pub async fn resolve_sandbox_endpoint( &self, sandbox: &Sandbox, - ) -> Result { + ) -> Result { if let Some(status) = sandbox.status.as_ref() - && !status.agent_pod.is_empty() + && !status.instance_id.is_empty() { - match self.agent_pod_ip(&status.agent_pod).await { + match self.agent_pod_ip(&status.instance_id).await { Ok(Some(ip)) => { return Ok(ResolveSandboxEndpointResponse { endpoint: Some(SandboxEndpoint { @@ -410,16 +424,22 @@ impl KubernetesComputeDriver { }); } Ok(None) => { - return Err("sandbox agent pod IP is not available".to_string()); + return Err(KubernetesDriverError::Precondition( + "sandbox agent pod IP is not available".to_string(), + )); } Err(err) => { - return Err(format!("failed to resolve agent pod IP: {err}")); + return Err(KubernetesDriverError::Message(format!( + "failed to resolve agent pod IP: {err}" + ))); } } } if sandbox.name.is_empty() { - return Err("sandbox has 
no name".to_string()); + return Err(KubernetesDriverError::Precondition( + "sandbox has no name".to_string(), + )); } Ok(ResolveSandboxEndpointResponse { @@ -615,9 +635,9 @@ fn update_indexes( sandbox_name_to_id.insert(sandbox.name.clone(), sandbox.id.clone()); } if let Some(status) = sandbox.status.as_ref() - && !status.agent_pod.is_empty() + && !status.instance_id.is_empty() { - agent_pod_to_id.insert(status.agent_pod.clone(), sandbox.id.clone()); + agent_pod_to_id.insert(status.instance_id.clone(), sandbox.id.clone()); } } @@ -920,7 +940,7 @@ fn sandbox_to_k8s_spec( // transforms are applied inside sandbox_template_to_k8s. let user_has_vct = spec .and_then(|s| s.template.as_ref()) - .and_then(|t| struct_to_json(&t.volume_claim_templates)) + .and_then(|t| platform_config_struct(t, "volume_claim_templates")) .is_some(); let inject_workspace = !user_has_vct; @@ -954,13 +974,14 @@ fn sandbox_to_k8s_spec( inject_workspace, ), ); - if !template.agent_socket.is_empty() { + if !template.agent_socket_path.is_empty() { root.insert( "agentSocket".to_string(), - serde_json::json!(template.agent_socket), + serde_json::json!(template.agent_socket_path), ); } - if let Some(volume_templates) = struct_to_json(&template.volume_claim_templates) { + if let Some(volume_templates) = platform_config_struct(template, "volume_claim_templates") + { root.insert("volumeClaimTemplates".to_string(), volume_templates); } } @@ -1029,18 +1050,15 @@ fn sandbox_template_to_k8s( if !template.labels.is_empty() { metadata.insert("labels".to_string(), serde_json::json!(template.labels)); } - if !template.annotations.is_empty() { - metadata.insert( - "annotations".to_string(), - serde_json::json!(template.annotations), - ); + if let Some(annotations) = platform_config_struct(template, "annotations") { + metadata.insert("annotations".to_string(), annotations); } let mut spec = serde_json::Map::new(); - if !template.runtime_class_name.is_empty() { + if let Some(runtime_class) = 
platform_config_string(template, "runtime_class_name") { spec.insert( "runtimeClassName".to_string(), - serde_json::json!(template.runtime_class_name), + serde_json::json!(runtime_class), ); } @@ -1158,8 +1176,29 @@ fn sandbox_template_to_k8s( } fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { - let mut resources = - struct_to_json(&template.resources).unwrap_or_else(|| serde_json::json!({})); + // Start from the raw resources passthrough in platform_config (preserves + // custom resource types like GPU limits that users set via the public API + // Struct), then overlay the typed DriverResourceRequirements on top. + let mut resources = platform_config_struct(template, "resources_raw") + .unwrap_or_else(|| serde_json::json!({})); + + // Overlay typed CPU/memory from DriverResourceRequirements. + if let Some(ref req) = template.resources { + let obj = resources.as_object_mut().unwrap(); + let mut apply = |section: &str, key: &str, value: &str| { + if !value.is_empty() { + let sec = obj + .entry(section) + .or_insert_with(|| serde_json::json!({})); + sec[key] = serde_json::json!(value); + } + }; + apply("requests", "cpu", &req.cpu_request); + apply("requests", "memory", &req.memory_request); + apply("limits", "cpu", &req.cpu_limit); + apply("limits", "memory", &req.memory_limit); + } + if gpu { apply_gpu_limit(&mut resources); } @@ -1283,13 +1322,29 @@ fn upsert_env(env: &mut Vec, name: &str, value: &str) { env.push(serde_json::json!({"name": name, "value": value})); } -fn struct_to_json(input: &Option) -> Option { - let input = input.as_ref()?; - let mut map = serde_json::Map::new(); - for (key, value) in &input.fields { - map.insert(key.clone(), proto_value_to_json(value)); +/// Extract a string value from the template's `platform_config` Struct. 
+fn platform_config_string(template: &SandboxTemplate, key: &str) -> Option { + let config = template.platform_config.as_ref()?; + let value = config.fields.get(key)?; + match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(s)) if !s.is_empty() => Some(s.clone()), + _ => None, + } +} + +/// Extract a nested Struct value from the template's `platform_config`, +/// converting it to `serde_json::Value`. +fn platform_config_struct(template: &SandboxTemplate, key: &str) -> Option { + let config = template.platform_config.as_ref()?; + let value = config.fields.get(key)?; + let json = proto_value_to_json(value); + // Return None for null/empty objects so callers can distinguish + // "field absent" from "field present but empty". + match &json { + serde_json::Value::Null => None, + serde_json::Value::Object(m) if m.is_empty() => None, + _ => Some(json), } - Some(serde_json::Value::Object(map)) } fn proto_value_to_json(value: &prost_types::Value) -> serde_json::Value { @@ -1334,7 +1389,7 @@ fn status_from_object(obj: &DynamicObject) -> Option { .and_then(|val| val.as_str()) .unwrap_or_default() .to_string(), - agent_pod: status_obj + instance_id: status_obj .get("agentPod") .and_then(|val| val.as_str()) .unwrap_or_default() @@ -1552,12 +1607,6 @@ mod tests { ); } - fn string_value(value: &str) -> Value { - Value { - kind: Some(Kind::StringValue(value.to_string())), - } - } - #[test] fn gpu_sandbox_adds_runtime_class_and_gpu_limit() { let pod_template = sandbox_template_to_k8s( @@ -1590,7 +1639,16 @@ mod tests { #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { - runtime_class_name: "kata-containers".to_string(), + platform_config: Some(Struct { + fields: [( + "runtime_class_name".to_string(), + Value { + kind: Some(Kind::StringValue("kata-containers".to_string())), + }, + )] + .into_iter() + .collect(), + }), ..SandboxTemplate::default() }; @@ -1620,7 +1678,16 @@ mod tests { #[test] fn 
non_gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { - runtime_class_name: "kata-containers".to_string(), + platform_config: Some(Struct { + fields: [( + "runtime_class_name".to_string(), + Value { + kind: Some(Kind::StringValue("kata-containers".to_string())), + }, + )] + .into_iter() + .collect(), + }), ..SandboxTemplate::default() }; @@ -1649,20 +1716,11 @@ mod tests { #[test] fn gpu_sandbox_preserves_existing_resource_limits() { + use openshell_core::proto::compute::v1::DriverResourceRequirements; let template = SandboxTemplate { - resources: Some(Struct { - fields: [( - "limits".to_string(), - Value { - kind: Some(Kind::StructValue(Struct { - fields: [("cpu".to_string(), string_value("2"))] - .into_iter() - .collect(), - })), - }, - )] - .into_iter() - .collect(), + resources: Some(DriverResourceRequirements { + cpu_limit: "2".to_string(), + ..Default::default() }), ..SandboxTemplate::default() }; diff --git a/crates/openshell-driver-kubernetes/src/grpc.rs b/crates/openshell-driver-kubernetes/src/grpc.rs index a2457a218..67e0795ec 100644 --- a/crates/openshell-driver-kubernetes/src/grpc.rs +++ b/crates/openshell-driver-kubernetes/src/grpc.rs @@ -13,7 +13,7 @@ use openshell_core::proto::compute::v1::{ use std::pin::Pin; use tonic::{Request, Response, Status}; -use crate::KubernetesComputeDriver; +use crate::{KubernetesComputeDriver, KubernetesDriverError}; #[derive(Debug, Clone)] pub struct ComputeDriverService { @@ -140,7 +140,7 @@ impl ComputeDriver for ComputeDriverService { .resolve_sandbox_endpoint(&sandbox) .await .map(Response::new) - .map_err(Status::internal) + .map_err(status_from_driver_error) } type WatchSandboxesStream = @@ -159,3 +159,26 @@ impl ComputeDriver for ComputeDriverService { Ok(Response::new(Box::pin(stream))) } } + +fn status_from_driver_error(err: KubernetesDriverError) -> Status { + match err { + KubernetesDriverError::AlreadyExists => Status::already_exists("sandbox already exists"), + 
KubernetesDriverError::Precondition(message) => Status::failed_precondition(message), + KubernetesDriverError::Message(message) => Status::internal(message), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn precondition_driver_errors_map_to_failed_precondition_status() { + let status = status_from_driver_error(KubernetesDriverError::Precondition( + "sandbox agent pod IP is not available".to_string(), + )); + + assert_eq!(status.code(), tonic::Code::FailedPrecondition); + assert_eq!(status.message(), "sandbox agent pod IP is not available"); + } +} diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 465aca5e4..f69120166 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -10,8 +10,9 @@ use crate::sandbox_watch::SandboxWatchBus; use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; use openshell_core::proto::compute::v1::{ - DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, - DriverSandboxTemplate, ResolveSandboxEndpointResponse, WatchSandboxesEvent, sandbox_endpoint, + DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, + DriverSandboxSpec, DriverSandboxStatus, DriverSandboxTemplate, + ResolveSandboxEndpointResponse, WatchSandboxesEvent, sandbox_endpoint, watch_sandboxes_event, }; use openshell_core::proto::{ @@ -28,11 +29,18 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tonic::Status; -use tracing::{info, warn}; +use tracing::{debug, info, warn}; type ComputeWatchStream = Pin> + Send>>; +/// Interval between store-vs-backend reconciliation sweeps. +const RECONCILE_INTERVAL: Duration = Duration::from_secs(60); + +/// How long a sandbox can remain provisioning in the store without a +/// corresponding backend resource before it is considered orphaned. 
+const ORPHAN_GRACE_PERIOD: Duration = Duration::from_secs(300); + #[derive(Debug, thiserror::Error)] pub enum ComputeError { #[error("sandbox already exists")] @@ -53,6 +61,7 @@ impl From for ComputeError { } } +#[derive(Debug)] pub enum ResolvedEndpoint { Ip(IpAddr, u16), Host(String, u16), @@ -64,6 +73,7 @@ pub trait ComputeBackend: fmt::Debug + Send + Sync { async fn validate_sandbox_create(&self, sandbox: &DriverSandbox) -> Result<(), Status>; async fn create_sandbox(&self, sandbox: &DriverSandbox) -> Result<(), ComputeError>; async fn delete_sandbox(&self, sandbox_name: &str) -> Result; + async fn sandbox_exists(&self, sandbox_name: &str) -> Result; async fn resolve_sandbox_endpoint( &self, sandbox: &DriverSandbox, @@ -107,6 +117,13 @@ impl ComputeBackend for InProcessKubernetesBackend { .map_err(ComputeError::Message) } + async fn sandbox_exists(&self, sandbox_name: &str) -> Result { + self.driver + .sandbox_exists(sandbox_name) + .await + .map_err(ComputeError::Message) + } + async fn resolve_sandbox_endpoint( &self, sandbox: &DriverSandbox, @@ -115,7 +132,7 @@ impl ComputeBackend for InProcessKubernetesBackend { .driver .resolve_sandbox_endpoint(sandbox) .await - .map_err(ComputeError::Message)?; + .map_err(ComputeError::from)?; resolved_endpoint_from_response(&response) } @@ -297,8 +314,12 @@ impl ComputeRuntime { pub fn spawn_watchers(&self) { let runtime = Arc::new(self.clone()); + let watch_runtime = runtime.clone(); + tokio::spawn(async move { + watch_runtime.watch_loop().await; + }); tokio::spawn(async move { - runtime.watch_loop().await; + runtime.reconcile_loop().await; }); } @@ -336,6 +357,82 @@ impl ComputeRuntime { } } + async fn reconcile_loop(self: Arc) { + // Let startup settle before pruning store records. 
+ tokio::time::sleep(RECONCILE_INTERVAL).await; + + loop { + if let Err(err) = self.reconcile_orphaned_sandboxes(ORPHAN_GRACE_PERIOD).await { + warn!(error = %err, "Store reconciliation sweep failed"); + } + tokio::time::sleep(RECONCILE_INTERVAL).await; + } + } + + async fn reconcile_orphaned_sandboxes(&self, grace_period: Duration) -> Result<(), String> { + let records = self + .store + .list(Sandbox::object_type(), 500, 0) + .await + .map_err(|e| e.to_string())?; + + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + .try_into() + .unwrap_or(i64::MAX); + let grace_ms = grace_period.as_millis().try_into().unwrap_or(i64::MAX); + + for record in records { + let sandbox = match Sandbox::decode(record.payload.as_slice()) { + Ok(sandbox) => sandbox, + Err(err) => { + warn!(error = %err, "Failed to decode sandbox record during reconciliation"); + continue; + } + }; + + if sandbox.phase != SandboxPhase::Provisioning as i32 { + continue; + } + + let age_ms = now_ms.saturating_sub(record.created_at_ms); + if age_ms < grace_ms { + continue; + } + + match self.backend.sandbox_exists(&sandbox.name).await { + Ok(true) => {} + Ok(false) => { + info!( + sandbox_id = %sandbox.id, + sandbox_name = %sandbox.name, + age_secs = age_ms / 1000, + "Removing orphaned sandbox from store (no corresponding backend resource)" + ); + if let Err(err) = self.store.delete(Sandbox::object_type(), &sandbox.id).await { + warn!(sandbox_id = %sandbox.id, error = %err, "Failed to remove orphaned sandbox"); + continue; + } + self.sandbox_index.remove_sandbox(&sandbox.id); + self.sandbox_watch_bus.notify(&sandbox.id); + self.cleanup_sandbox_state(&sandbox.id); + } + Err(err) => { + debug!( + sandbox_id = %sandbox.id, + sandbox_name = %sandbox.name, + error = %err, + "Skipping orphan check due to backend error" + ); + } + } + } + + Ok(()) + } + async fn apply_watch_event(&self, event: WatchSandboxesEvent) -> Result<(), String> { 
match event.payload { Some(watch_sandboxes_event::Payload::Sandbox(sandbox)) => { @@ -378,7 +475,7 @@ impl ComputeRuntime { existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), ); - let mut phase = derive_phase(incoming.status.as_ref()); + let phase = derive_phase(incoming.status.as_ref()); let mut sandbox = existing.unwrap_or_else(|| Sandbox { id: incoming.id.clone(), name: incoming.name.clone(), @@ -390,9 +487,6 @@ impl ComputeRuntime { }); let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - if old_phase == SandboxPhase::Deleting && phase != SandboxPhase::Error { - phase = SandboxPhase::Deleting; - } if old_phase != phase { info!( sandbox_id = %incoming.id, @@ -483,20 +577,131 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { DriverSandboxTemplate { image: template.image.clone(), - runtime_class_name: template.runtime_class_name.clone(), - agent_socket: template.agent_socket.clone(), + agent_socket_path: template.agent_socket.clone(), labels: template.labels.clone(), - annotations: template.annotations.clone(), environment: template.environment.clone(), - resources: template.resources.clone(), - volume_claim_templates: template.volume_claim_templates.clone(), + resources: extract_typed_resources(&template.resources), + platform_config: build_platform_config(template), + } +} + +/// Extract typed CPU/memory quantities from the public `resources` Struct. +/// +/// The public API exposes resources as an untyped `google.protobuf.Struct` +/// with the Kubernetes limits/requests shape. We pull out the well-known +/// keys into the typed `DriverResourceRequirements` message. 
+fn extract_typed_resources( + resources: &Option, +) -> Option { + let s = resources.as_ref()?; + + fn get_quantity(s: &prost_types::Struct, section: &str, key: &str) -> String { + s.fields + .get(section) + .and_then(|v| match v.kind.as_ref() { + Some(prost_types::value::Kind::StructValue(inner)) => inner.fields.get(key), + _ => None, + }) + .and_then(|v| match v.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(val)) => Some(val.clone()), + _ => None, + }) + .unwrap_or_default() + } + + let req = DriverResourceRequirements { + cpu_request: get_quantity(s, "requests", "cpu"), + cpu_limit: get_quantity(s, "limits", "cpu"), + memory_request: get_quantity(s, "requests", "memory"), + memory_limit: get_quantity(s, "limits", "memory"), + }; + + // Return None when all fields are empty so drivers can distinguish + // "no resource requirements" from "zero requirements". + if req.cpu_request.is_empty() + && req.cpu_limit.is_empty() + && req.memory_request.is_empty() + && req.memory_limit.is_empty() + { + None + } else { + Some(req) + } +} + +/// Build the opaque `platform_config` Struct from platform-specific public +/// template fields (runtime_class_name, annotations, volume_claim_templates) +/// plus any resource fields beyond CPU/memory. 
+fn build_platform_config(template: &SandboxTemplate) -> Option { + use prost_types::{Struct, Value, value::Kind}; + + let mut fields = std::collections::BTreeMap::new(); + + if !template.runtime_class_name.is_empty() { + fields.insert( + "runtime_class_name".to_string(), + Value { + kind: Some(Kind::StringValue(template.runtime_class_name.clone())), + }, + ); + } + + if !template.annotations.is_empty() { + let annotation_fields = template + .annotations + .iter() + .map(|(k, v)| { + ( + k.clone(), + Value { + kind: Some(Kind::StringValue(v.clone())), + }, + ) + }) + .collect(); + fields.insert( + "annotations".to_string(), + Value { + kind: Some(Kind::StructValue(Struct { + fields: annotation_fields, + })), + }, + ); + } + + // Pass through the raw volume_claim_templates Struct as a nested value. + if let Some(ref vct) = template.volume_claim_templates { + fields.insert( + "volume_claim_templates".to_string(), + Value { + kind: Some(Kind::StructValue(vct.clone())), + }, + ); + } + + // Pass through any non-cpu/memory resource fields from the original + // resources Struct so the driver can handle GPU limits, custom resources, + // etc. that don't map to the typed DriverResourceRequirements. 
+ if let Some(ref res) = template.resources { + fields.insert( + "resources_raw".to_string(), + Value { + kind: Some(Kind::StructValue(res.clone())), + }, + ); + } + + if fields.is_empty() { + None + } else { + Some(Struct { fields }) } } fn driver_status_from_public(status: &SandboxStatus, phase: i32) -> DriverSandboxStatus { DriverSandboxStatus { sandbox_name: status.sandbox_name.clone(), - agent_pod: status.agent_pod.clone(), + instance_id: status.agent_pod.clone(), agent_fd: status.agent_fd.clone(), sandbox_fd: status.sandbox_fd.clone(), conditions: status @@ -563,7 +768,7 @@ fn resolved_endpoint_from_response( fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { SandboxStatus { sandbox_name: status.sandbox_name.clone(), - agent_pod: status.agent_pod.clone(), + agent_pod: status.instance_id.clone(), agent_fd: status.agent_fd.clone(), sandbox_fd: status.sandbox_fd.clone(), conditions: status @@ -649,6 +854,76 @@ fn is_terminal_failure_reason(reason: &str) -> bool { #[cfg(test)] mod tests { use super::*; + use futures::stream; + use std::sync::Arc; + + #[derive(Debug, Default)] + struct TestBackend { + sandbox_exists: bool, + resolve_precondition: Option, + } + + #[tonic::async_trait] + impl ComputeBackend for TestBackend { + fn default_image(&self) -> &'static str { + "openshell/sandbox:test" + } + + async fn validate_sandbox_create(&self, _sandbox: &DriverSandbox) -> Result<(), Status> { + Ok(()) + } + + async fn create_sandbox(&self, _sandbox: &DriverSandbox) -> Result<(), ComputeError> { + Ok(()) + } + + async fn delete_sandbox(&self, _sandbox_name: &str) -> Result { + Ok(true) + } + + async fn sandbox_exists(&self, _sandbox_name: &str) -> Result { + Ok(self.sandbox_exists) + } + + async fn resolve_sandbox_endpoint( + &self, + _sandbox: &DriverSandbox, + ) -> Result { + if let Some(message) = &self.resolve_precondition { + return Err(ComputeError::Precondition(message.clone())); + } + + Ok(ResolvedEndpoint::Host( + 
"sandbox.default.svc.cluster.local".to_string(), + 2222, + )) + } + + async fn watch_sandboxes(&self) -> Result { + Ok(Box::pin(stream::empty())) + } + } + + async fn test_runtime(backend: Arc) -> ComputeRuntime { + let store = Arc::new(Store::connect("sqlite::memory:").await.unwrap()); + ComputeRuntime { + backend, + store, + sandbox_index: SandboxIndex::new(), + sandbox_watch_bus: SandboxWatchBus::new(), + tracing_log_bus: TracingLogBus::new(), + } + } + + fn sandbox_record(id: &str, name: &str, phase: SandboxPhase) -> Sandbox { + Sandbox { + id: id.to_string(), + name: name.to_string(), + namespace: "default".to_string(), + phase: phase as i32, + ..Default::default() + } + } fn make_driver_condition(reason: &str, message: &str) -> DriverCondition { DriverCondition { @@ -663,7 +938,7 @@ mod tests { fn make_driver_status(condition: DriverCondition) -> DriverSandboxStatus { DriverSandboxStatus { sandbox_name: "test".to_string(), - agent_pod: "test-pod".to_string(), + instance_id: "test-pod".to_string(), agent_fd: String::new(), sandbox_fd: String::new(), conditions: vec![condition], @@ -836,4 +1111,98 @@ mod tests { assert_eq!(status.unwrap().conditions[0].message, original); } + + #[tokio::test] + async fn apply_sandbox_update_allows_delete_failures_to_recover() { + let runtime = test_runtime(Arc::new(TestBackend::default())).await; + let sandbox = sandbox_record("sb-1", "sandbox-a", SandboxPhase::Deleting); + runtime.store.put_message(&sandbox).await.unwrap(); + + runtime + .apply_sandbox_update(DriverSandbox { + id: "sb-1".to_string(), + name: "sandbox-a".to_string(), + namespace: "default".to_string(), + spec: None, + status: Some(DriverSandboxStatus { + sandbox_name: "sandbox-a".to_string(), + instance_id: "agent-pod".to_string(), + agent_fd: String::new(), + sandbox_fd: String::new(), + conditions: vec![DriverCondition { + r#type: "Ready".to_string(), + status: "True".to_string(), + reason: "DependenciesReady".to_string(), + message: "Pod is 
Ready".to_string(), + last_transition_time: String::new(), + }], + deleting: false, + }), + }) + .await + .unwrap(); + + let stored = runtime + .store + .get_message::("sb-1") + .await + .unwrap() + .unwrap(); + assert_eq!( + SandboxPhase::try_from(stored.phase).unwrap(), + SandboxPhase::Ready + ); + } + + #[tokio::test] + async fn resolve_sandbox_endpoint_preserves_precondition_errors() { + let runtime = test_runtime(Arc::new(TestBackend { + sandbox_exists: true, + resolve_precondition: Some("sandbox agent pod IP is not available".to_string()), + })) + .await; + + let err = runtime + .resolve_sandbox_endpoint(&sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready)) + .await + .expect_err("endpoint resolution should preserve failed-precondition errors"); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert_eq!(err.message(), "sandbox agent pod IP is not available"); + } + + #[tokio::test] + async fn reconcile_orphaned_sandboxes_removes_stale_provisioning_records() { + let runtime = test_runtime(Arc::new(TestBackend::default())).await; + let sandbox = sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning); + runtime.store.put_message(&sandbox).await.unwrap(); + runtime.sandbox_index.update_from_sandbox(&sandbox); + + let mut watch_rx = runtime.sandbox_watch_bus.subscribe(&sandbox.id); + + runtime + .reconcile_orphaned_sandboxes(Duration::ZERO) + .await + .unwrap(); + + assert!( + runtime + .store + .get_message::(&sandbox.id) + .await + .unwrap() + .is_none() + ); + assert!( + runtime + .sandbox_index + .sandbox_id_for_sandbox_name(&sandbox.name) + .is_none() + ); + let _ = watch_rx.try_recv(); + assert!(matches!( + watch_rx.try_recv(), + Err(tokio::sync::broadcast::error::TryRecvError::Closed) + )); + } } diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 8a6f32e9f..53b0ac27d 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -91,23 +91,44 @@ message DriverSandboxSpec { } // Driver-owned 
runtime template consumed by the compute platform. +// +// This message describes the sandbox workload in backend-neutral terms. +// Platform-specific knobs (Kubernetes runtimeClassName, annotations, +// volumeClaimTemplates, etc.) belong in `platform_config`. message DriverSandboxTemplate { // Fully-qualified OCI image reference used to boot the sandbox. string image = 1; - // Optional runtime class name requested from the compute platform. - string runtime_class_name = 2; - // Optional agent socket path exposed to the workload. - string agent_socket = 3; - // Labels applied to compute-platform resources. + // Socket path inside the sandbox where the agent service listens. + string agent_socket_path = 3; + // Metadata labels applied to compute-platform resources. + // Drivers map these to the platform's native tagging mechanism + // (Kubernetes labels, cloud instance tags, etc.). map labels = 4; - // Annotations applied to compute-platform resources. - map annotations = 5; - // Additional environment variables injected by the template. + // Additional environment variables injected into the sandbox runtime. map environment = 6; - // Platform-specific compute resource requirements and limits. - google.protobuf.Struct resources = 7; - // Optional platform-specific volume claim templates. - google.protobuf.Struct volume_claim_templates = 9; + // Typed compute-resource requirements for the sandbox workload. + DriverResourceRequirements resources = 10; + // Opaque, platform-specific configuration passed through to the driver. + // The gateway does not inspect this; each driver defines its own schema. + // For the Kubernetes driver this carries fields such as runtimeClassName, + // annotations, and volumeClaimTemplates. + google.protobuf.Struct platform_config = 11; +} + +// Typed compute-resource requirements. +// +// Values use Kubernetes-style quantity strings (e.g. "500m", "2", "4Gi") +// because they are a well-known, widely-adopted notation. 
Drivers for +// non-Kubernetes platforms must parse these strings into their native units. +message DriverResourceRequirements { + // Minimum CPU cores requested (e.g. "500m", "2"). + string cpu_request = 1; + // Maximum CPU cores allowed (e.g. "500m", "4"). + string cpu_limit = 2; + // Minimum memory requested (e.g. "256Mi", "4Gi"). + string memory_request = 3; + // Maximum memory allowed (e.g. "512Mi", "8Gi"). + string memory_limit = 4; } // Raw status observed directly from the compute platform. @@ -117,11 +138,15 @@ message DriverSandboxTemplate { message DriverSandboxStatus { // Compute-platform sandbox object name. string sandbox_name = 1; - // Name of the agent pod or equivalent runtime instance. - string agent_pod = 2; - // File descriptor or endpoint for reaching the agent service, when available. + // Platform-assigned instance identifier for the compute unit running the + // sandbox agent (e.g. Kubernetes pod name, VM instance ID, hostname). + // The gateway uses this to correlate incoming connections back to a sandbox. + string instance_id = 2; + // File descriptor or address for reaching the agent service inside the + // sandbox, when available. string agent_fd = 3; - // File descriptor or endpoint for reaching the sandbox service, when available. + // File descriptor or address for reaching the sandbox supervisor service, + // when available. string sandbox_fd = 4; // Raw readiness and lifecycle conditions reported by the platform. 
repeated DriverCondition conditions = 5; From 6e0726ba77514ab0f4a5d35bfb44fa7a6c0f2675 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Tue, 14 Apr 2026 10:10:00 -0700 Subject: [PATCH 4/7] style(rust): apply branch check formatting fixes --- crates/openshell-driver-kubernetes/src/driver.rs | 11 +++++------ crates/openshell-server/src/compute/mod.rs | 5 ++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index ef81ab07e..440703af5 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -980,7 +980,8 @@ fn sandbox_to_k8s_spec( serde_json::json!(template.agent_socket_path), ); } - if let Some(volume_templates) = platform_config_struct(template, "volume_claim_templates") + if let Some(volume_templates) = + platform_config_struct(template, "volume_claim_templates") { root.insert("volumeClaimTemplates".to_string(), volume_templates); } @@ -1179,17 +1180,15 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Date: Tue, 14 Apr 2026 11:20:30 -0700 Subject: [PATCH 5/7] python protos --- python/openshell/_proto/__init__.py | 11 +++++++++++ python/openshell/sandbox.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/openshell/_proto/__init__.py b/python/openshell/_proto/__init__.py index e69de29bb..3ace22421 100644 --- a/python/openshell/_proto/__init__.py +++ b/python/openshell/_proto/__init__.py @@ -0,0 +1,11 @@ +from . import datamodel_pb2, openshell_pb2 + +# Sandbox messages and phase enums moved into openshell.proto. Keep aliases on +# datamodel_pb2 so existing Python callers and E2E tests continue to work. 
+for _name in ("Sandbox", "SandboxSpec", "SandboxTemplate"): + if not hasattr(datamodel_pb2, _name): + setattr(datamodel_pb2, _name, getattr(openshell_pb2, _name)) + +for _name in dir(openshell_pb2): + if _name.startswith("SANDBOX_PHASE_") and not hasattr(datamodel_pb2, _name): + setattr(datamodel_pb2, _name, getattr(openshell_pb2, _name)) diff --git a/python/openshell/sandbox.py b/python/openshell/sandbox.py index 19bdcdf65..b52f42997 100644 --- a/python/openshell/sandbox.py +++ b/python/openshell/sandbox.py @@ -187,7 +187,7 @@ def health(self) -> openshell_pb2.HealthResponse: def create( self, *, - spec: datamodel_pb2.SandboxSpec | None = None, + spec: openshell_pb2.SandboxSpec | None = None, ) -> SandboxRef: request_spec = spec if spec is not None else _default_spec() response = self._stub.CreateSandbox( @@ -201,7 +201,7 @@ def create( def create_session( self, *, - spec: datamodel_pb2.SandboxSpec | None = None, + spec: openshell_pb2.SandboxSpec | None = None, ) -> SandboxSession: return SandboxSession(self, self.create(spec=spec)) @@ -253,9 +253,9 @@ def wait_ready( deadline = time.time() + timeout_seconds while time.time() < deadline: sandbox = self.get(sandbox_name) - if sandbox.phase == datamodel_pb2.SANDBOX_PHASE_READY: + if sandbox.phase == openshell_pb2.SANDBOX_PHASE_READY: return sandbox - if sandbox.phase == datamodel_pb2.SANDBOX_PHASE_ERROR: + if sandbox.phase == openshell_pb2.SANDBOX_PHASE_ERROR: raise SandboxError(f"sandbox {sandbox_name} entered error phase") time.sleep(1) raise SandboxError(f"sandbox {sandbox_name} was not ready within timeout") @@ -435,7 +435,7 @@ def __init__( cluster: str | None = None, sandbox: str | SandboxRef | None = None, delete_on_exit: bool = True, - spec: datamodel_pb2.SandboxSpec | None = None, + spec: openshell_pb2.SandboxSpec | None = None, timeout: float = 30.0, ready_timeout_seconds: float = 120.0, ) -> None: @@ -576,7 +576,7 @@ def _serialize_python_callable( return base64.b64encode(payload).decode("ascii") -def 
_sandbox_ref(sandbox: datamodel_pb2.Sandbox) -> SandboxRef: +def _sandbox_ref(sandbox: openshell_pb2.Sandbox) -> SandboxRef: return SandboxRef( id=sandbox.id, name=sandbox.name, @@ -585,13 +585,13 @@ def _sandbox_ref(sandbox: datamodel_pb2.Sandbox) -> SandboxRef: ) -def _default_spec() -> datamodel_pb2.SandboxSpec: +def _default_spec() -> openshell_pb2.SandboxSpec: # Omit the policy field so the sandbox container discovers its policy # from /etc/openshell/policy.yaml (baked into the image at build time). # This avoids duplicating policy defaults between the SDK and the # container image and ensures sandboxes get the full dev-sandbox-policy # (including network_policies) out of the box. - return datamodel_pb2.SandboxSpec() + return openshell_pb2.SandboxSpec() def _xdg_config_home() -> pathlib.Path: From fe451a812f722d946d2414dbd375be4b0d3af9a6 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Tue, 14 Apr 2026 11:44:20 -0700 Subject: [PATCH 6/7] fix(python): remove stale datamodel proto import --- python/openshell/sandbox.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/openshell/sandbox.py b/python/openshell/sandbox.py index b52f42997..eba79f0fd 100644 --- a/python/openshell/sandbox.py +++ b/python/openshell/sandbox.py @@ -16,7 +16,6 @@ import grpc from ._proto import ( - datamodel_pb2, inference_pb2, inference_pb2_grpc, openshell_pb2, From 179abfcacbb4c34dc97f8b29144078cadc71d77a Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Tue, 14 Apr 2026 12:50:42 -0700 Subject: [PATCH 7/7] add --drivers flag --- crates/openshell-core/src/config.rs | 94 ++++++++++++++++++++++ crates/openshell-core/src/image.rs | 12 +++ crates/openshell-core/src/lib.rs | 2 +- crates/openshell-server/src/lib.rs | 117 ++++++++++++++++++++++++---- crates/openshell-server/src/main.rs | 21 +++++ 5 files changed, 228 insertions(+), 18 deletions(-) diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index 750aa98b0..279752b4f 100644 --- 
a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -4,8 +4,48 @@ //! Configuration management for OpenShell components. use serde::{Deserialize, Serialize}; +use std::fmt; use std::net::SocketAddr; use std::path::PathBuf; +use std::str::FromStr; + +/// Compute backends the gateway can orchestrate sandboxes through. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ComputeDriverKind { + Kubernetes, + Podman, +} + +impl ComputeDriverKind { + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::Kubernetes => "kubernetes", + Self::Podman => "podman", + } + } +} + +impl fmt::Display for ComputeDriverKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl FromStr for ComputeDriverKind { + type Err = String; + + fn from_str(value: &str) -> Result { + match value.trim().to_ascii_lowercase().as_str() { + "kubernetes" => Ok(Self::Kubernetes), + "podman" => Ok(Self::Podman), + other => Err(format!( + "unsupported compute driver '{other}'. expected one of: kubernetes, podman" + )), + } + } +} /// Server configuration. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -24,6 +64,14 @@ pub struct Config { /// Database URL for persistence. pub database_url: String, + /// Compute drivers configured for the gateway. + /// + /// The config shape allows multiple drivers so the gateway can evolve + /// toward multi-backend routing. Current releases require exactly one + /// configured driver. + #[serde(default = "default_compute_drivers")] + pub compute_drivers: Vec, + /// Kubernetes namespace for sandboxes. 
#[serde(default = "default_sandbox_namespace")] pub sandbox_namespace: String, @@ -120,6 +168,7 @@ impl Config { log_level: default_log_level(), tls, database_url: String::new(), + compute_drivers: default_compute_drivers(), sandbox_namespace: default_sandbox_namespace(), sandbox_image: String::new(), sandbox_image_pull_policy: String::new(), @@ -157,6 +206,16 @@ impl Config { self } + /// Create a new configuration with the configured compute drivers. + #[must_use] + pub fn with_compute_drivers<I>(mut self, drivers: I) -> Self + where + I: IntoIterator<Item = ComputeDriverKind>, + { + self.compute_drivers = drivers.into_iter().collect(); + self + } + /// Create a new configuration with a sandbox namespace. #[must_use] pub fn with_sandbox_namespace(mut self, namespace: impl Into<String>) -> Self { @@ -261,6 +320,10 @@ fn default_sandbox_namespace() -> String { "default".to_string() } +fn default_compute_drivers() -> Vec<ComputeDriverKind> { + vec![ComputeDriverKind::Kubernetes] +} + fn default_ssh_gateway_host() -> String { "127.0.0.1".to_string() } @@ -284,3 +347,34 @@ const fn default_ssh_handshake_skew_secs() -> u64 { const fn default_ssh_session_ttl_secs() -> u64 { 86400 // 24 hours } + +#[cfg(test)] +mod tests { + use super::{ComputeDriverKind, Config}; + + #[test] + fn compute_driver_kind_parses_supported_values() { + assert_eq!( + "kubernetes".parse::<ComputeDriverKind>().unwrap(), + ComputeDriverKind::Kubernetes + ); + assert_eq!( + "podman".parse::<ComputeDriverKind>().unwrap(), + ComputeDriverKind::Podman + ); + } + + #[test] + fn compute_driver_kind_rejects_unknown_values() { + let err = "docker".parse::<ComputeDriverKind>().unwrap_err(); + assert!(err.contains("unsupported compute driver 'docker'")); + } + + #[test] + fn config_defaults_to_kubernetes_driver() { + assert_eq!( + Config::new(None).compute_drivers, + vec![ComputeDriverKind::Kubernetes] + ); + } +} diff --git a/crates/openshell-core/src/image.rs b/crates/openshell-core/src/image.rs index ff10baaac..6a628e2a9 100644 --- a/crates/openshell-core/src/image.rs +++ b/crates/openshell-core/src/image.rs @@ 
-42,9 +42,16 @@ pub fn resolve_community_image(value: &str) -> String { #[allow(unsafe_code)] mod tests { use super::*; + use std::sync::{Mutex, OnceLock}; + + fn env_lock() -> &'static Mutex<()> { + static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + ENV_LOCK.get_or_init(|| Mutex::new(())) + } #[test] fn bare_name_expands_to_community_registry() { + let _guard = env_lock().lock().unwrap(); let result = resolve_community_image("base"); assert_eq!( result, @@ -54,6 +61,7 @@ mod tests { #[test] fn bare_name_with_env_override() { + let _guard = env_lock().lock().unwrap(); // Use a temp env override. Safety: test-only, and these env-var tests // are not run concurrently with other tests reading the same var. let key = "OPENSHELL_COMMUNITY_REGISTRY"; @@ -71,24 +79,28 @@ #[test] fn full_reference_with_slash_passes_through() { + let _guard = env_lock().lock().unwrap(); let input = "ghcr.io/myorg/myimage:v1"; assert_eq!(resolve_community_image(input), input); } #[test] fn reference_with_colon_passes_through() { + let _guard = env_lock().lock().unwrap(); let input = "myimage:latest"; assert_eq!(resolve_community_image(input), input); } #[test] fn reference_with_dot_passes_through() { + let _guard = env_lock().lock().unwrap(); let input = "registry.example.com"; assert_eq!(resolve_community_image(input), input); } #[test] fn trailing_slash_in_env_is_trimmed() { + let _guard = env_lock().lock().unwrap(); let key = "OPENSHELL_COMMUNITY_REGISTRY"; let prev = std::env::var(key).ok(); // SAFETY: single-threaded test context; no other thread reads this var. 
diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index 30fb205ff..c0b08f1a5 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -19,7 +19,7 @@ pub mod paths; pub mod proto; pub mod settings; -pub use config::{Config, TlsConfig}; +pub use config::{ComputeDriverKind, Config, TlsConfig}; pub use error::{Error, Result}; /// Build version string derived from git metadata. diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index c9ff7704c..a8d820b4d 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -23,7 +23,7 @@ mod tls; pub mod tracing_bus; mod ws_tunnel; -use openshell_core::{Config, Error, Result}; +use openshell_core::{ComputeDriverKind, Config, Error, Result}; use std::collections::HashMap; use std::io::ErrorKind; use std::sync::{Arc, Mutex}; @@ -129,26 +129,14 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul let sandbox_index = SandboxIndex::new(); let sandbox_watch_bus = SandboxWatchBus::new(); - let compute = ComputeRuntime::new_kubernetes( - KubernetesComputeConfig { - namespace: config.sandbox_namespace.clone(), - default_image: config.sandbox_image.clone(), - image_pull_policy: config.sandbox_image_pull_policy.clone(), - grpc_endpoint: config.grpc_endpoint.clone(), - ssh_listen_addr: format!("0.0.0.0:{}", config.sandbox_ssh_port), - ssh_port: config.sandbox_ssh_port, - ssh_handshake_secret: config.ssh_handshake_secret.clone(), - ssh_handshake_skew_secs: config.ssh_handshake_skew_secs, - client_tls_secret_name: config.client_tls_secret_name.clone(), - host_gateway_ip: config.host_gateway_ip.clone(), - }, + let compute = build_compute_runtime( + &config, store.clone(), sandbox_index.clone(), sandbox_watch_bus.clone(), tracing_log_bus.clone(), ) - .await - .map_err(|e| Error::execution(format!("failed to create compute runtime: {e}")))?; + .await?; let state = Arc::new(ServerState::new( 
config.clone(), store.clone(), @@ -224,9 +212,67 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul } } +async fn build_compute_runtime( + config: &Config, + store: Arc<Store>, + sandbox_index: SandboxIndex, + sandbox_watch_bus: SandboxWatchBus, + tracing_log_bus: TracingLogBus, +) -> Result<ComputeRuntime> { + let driver = configured_compute_driver(config)?; + info!(driver = %driver, "Using compute driver"); + + match driver { + ComputeDriverKind::Kubernetes => ComputeRuntime::new_kubernetes( + KubernetesComputeConfig { + namespace: config.sandbox_namespace.clone(), + default_image: config.sandbox_image.clone(), + image_pull_policy: config.sandbox_image_pull_policy.clone(), + grpc_endpoint: config.grpc_endpoint.clone(), + ssh_listen_addr: format!("0.0.0.0:{}", config.sandbox_ssh_port), + ssh_port: config.sandbox_ssh_port, + ssh_handshake_secret: config.ssh_handshake_secret.clone(), + ssh_handshake_skew_secs: config.ssh_handshake_skew_secs, + client_tls_secret_name: config.client_tls_secret_name.clone(), + host_gateway_ip: config.host_gateway_ip.clone(), + }, + store, + sandbox_index, + sandbox_watch_bus, + tracing_log_bus, + ) + .await + .map_err(|e| Error::execution(format!("failed to create compute runtime: {e}"))), + ComputeDriverKind::Podman => Err(Error::config( + "compute driver 'podman' is not implemented yet", + )), + } +} + +fn configured_compute_driver(config: &Config) -> Result<ComputeDriverKind> { + match config.compute_drivers.as_slice() { + [] => Err(Error::config( + "at least one compute driver must be configured", + )), + [driver @ ComputeDriverKind::Kubernetes] => Ok(*driver), + [ComputeDriverKind::Podman] => Err(Error::config( + "compute driver 'podman' is not implemented yet", + )), + drivers => Err(Error::config(format!( + "multiple compute drivers are not supported yet; configured drivers: {}", + drivers + .iter() + .map(ToString::to_string) + .collect::<Vec<_>>() + .join(",") + ))), + } +} + #[cfg(test)] mod tests { - use 
super::is_benign_tls_handshake_failure; + use super::{configured_compute_driver, is_benign_tls_handshake_failure}; + use openshell_core::{ComputeDriverKind, Config}; use std::io::{Error, ErrorKind}; #[test] @@ -248,4 +294,41 @@ mod tests { assert!(!is_benign_tls_handshake_failure(&error)); } } + + #[test] + fn configured_compute_driver_defaults_to_kubernetes() { + assert_eq!( + configured_compute_driver(&Config::new(None)).unwrap(), + ComputeDriverKind::Kubernetes + ); + } + + #[test] + fn configured_compute_driver_requires_at_least_one_entry() { + let config = Config::new(None).with_compute_drivers([]); + let err = configured_compute_driver(&config).unwrap_err(); + assert!(err.to_string().contains("at least one compute driver")); + } + + #[test] + fn configured_compute_driver_rejects_multiple_entries() { + let config = Config::new(None) + .with_compute_drivers([ComputeDriverKind::Kubernetes, ComputeDriverKind::Podman]); + let err = configured_compute_driver(&config).unwrap_err(); + assert!( + err.to_string() + .contains("multiple compute drivers are not supported yet") + ); + assert!(err.to_string().contains("kubernetes,podman")); + } + + #[test] + fn configured_compute_driver_rejects_unimplemented_driver() { + let config = Config::new(None).with_compute_drivers([ComputeDriverKind::Podman]); + let err = configured_compute_driver(&config).unwrap_err(); + assert!( + err.to_string() + .contains("compute driver 'podman' is not implemented yet") + ); + } } diff --git a/crates/openshell-server/src/main.rs b/crates/openshell-server/src/main.rs index 5178693a5..ed6c73825 100644 --- a/crates/openshell-server/src/main.rs +++ b/crates/openshell-server/src/main.rs @@ -5,6 +5,7 @@ use clap::Parser; use miette::{IntoDiagnostic, Result}; +use openshell_core::ComputeDriverKind; use std::net::SocketAddr; use std::path::PathBuf; use tracing::info; @@ -42,6 +43,21 @@ struct Args { #[arg(long, env = "OPENSHELL_DB_URL", required = true)] db_url: String, + /// Compute drivers 
configured for this gateway. + /// + /// Accepts a comma-delimited list such as `kubernetes` or + /// `kubernetes,podman`. The configuration format is future-proofed for + /// multiple drivers, but the gateway currently requires exactly one. + #[arg( + long, + alias = "driver", + env = "OPENSHELL_DRIVERS", + value_delimiter = ',', + default_value = "kubernetes", + value_parser = parse_compute_driver + )] + drivers: Vec<ComputeDriverKind>, + /// Kubernetes namespace for sandboxes. #[arg(long, env = "OPENSHELL_SANDBOX_NAMESPACE", default_value = "default")] sandbox_namespace: String, @@ -157,6 +173,7 @@ async fn main() -> Result<()> { config = config .with_database_url(args.db_url) + .with_compute_drivers(args.drivers) .with_sandbox_namespace(args.sandbox_namespace) .with_ssh_gateway_host(args.ssh_gateway_host) .with_ssh_gateway_port(args.ssh_gateway_port) @@ -198,3 +215,7 @@ async fn main() -> Result<()> { run_server(config, tracing_log_bus).await.into_diagnostic() } + +fn parse_compute_driver(value: &str) -> std::result::Result<ComputeDriverKind, String> { + value.parse() +}