diff --git a/.claude/agents/principal-engineer-reviewer.md b/.claude/agents/principal-engineer-reviewer.md index fe231f806..ae7e49ea2 100644 --- a/.claude/agents/principal-engineer-reviewer.md +++ b/.claude/agents/principal-engineer-reviewer.md @@ -25,8 +25,12 @@ OpenShell project. Your reviews balance three priorities equally: 3. **Security** — What are the threat surfaces? Are trust boundaries respected? Is input validated at system boundaries? Are secrets, credentials, and - tokens handled correctly? Think about the OWASP top 10, supply chain risks, - and privilege escalation. + tokens handled correctly? Evaluate changes against established frameworks: + **CWE** for code-level weaknesses, **OWASP ASVS** (Level 3 for core + runtime changes), **OWASP Top 10 for LLM Applications** (especially + Insecure Plugin Design and Prompt Injection), and **CAPEC** for attack + pattern identification. Consider supply chain risks and privilege + escalation paths. ## Project context @@ -95,6 +99,53 @@ Structure your review clearly: Omit empty sections. Keep it concise — density over length. +## Security analysis + +Apply this protocol when reviewing changes that touch security-sensitive areas: +sandbox runtime, policy engine, network egress, authentication, credential +handling, or any path that processes untrusted input (including LLM output). + +1. **Threat modeling** — Map the data flow for the change. Where does untrusted + input (from an LLM, user, or network) enter? Where does it exit (to a + shell, filesystem, network, or database)? Identify trust boundaries that + the change crosses. + +2. **Weakness mapping** — Tag every security concern with its **CWE ID**. This + makes findings actionable and trackable. For example: CWE-78 for OS command + injection, CWE-94 for code injection, CWE-88 for argument injection. + +3. 
**Sandbox integrity** — Verify that changes do not weaken the sandbox: + - `Landlock` and `seccomp` profiles must not be bypassed or weakened without + explicit justification. + - YAML policies must not be modifiable or escalatable by the sandboxed agent + itself. + - Default-deny posture must be preserved. + +4. **Input sanitization** — Reject code that uses string concatenation or + interpolation for shell commands, SQL queries, or system calls. Demand + parameterized execution or strict allow-list validation. + +5. **Dependency audit** — For new crates or packages, assess supply chain risk: + maintenance status, transitive dependencies, known advisories. + +### Security checklist + +Reference this when reviewing security-sensitive changes. Not every item +applies to every PR — use judgment. + +- **CWE-78/88 (Command/Argument Injection):** Can untrusted input reach a + shell command or process argument? +- **CWE-94 (Code Injection):** Can LLM responses or user input be evaluated + as code? +- **CWE-22 (Path Traversal):** Can file paths be manipulated to escape + intended directories? +- **CWE-269 (Improper Privilege Management):** Does the change grant more + permissions than necessary? +- **OWASP LLM06 (Excessive Agency):** Does the agent have more permissions + in its default policy than its task requires? +- **Supply chain:** Do new dependencies introduce known vulnerabilities or + unmaintained transitive dependencies? + ## Principles - Don't nitpick style unless it harms readability. 
Trust `rustfmt` and the diff --git a/.github/workflows/docs-preview-pr.yml b/.github/workflows/docs-preview-pr.yml index ccade7411..6c0672ba2 100644 --- a/.github/workflows/docs-preview-pr.yml +++ b/.github/workflows/docs-preview-pr.yml @@ -49,6 +49,7 @@ jobs: find _build -name .buildinfo -exec rm {} \; - name: Deploy preview + if: github.event.pull_request.head.repo.full_name == github.repository uses: rossjrw/pr-preview-action@v1 with: source-dir: ./_build/docs/ diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index f14ccb880..a89f4508f 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -19,9 +19,25 @@ permissions: jobs: e2e: - name: E2E + name: "E2E (${{ matrix.suite }})" runs-on: ${{ inputs.runner }} timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + include: + - suite: python + cluster: e2e-python + port: "8080" + cmd: "mise run --no-prepare --skip-deps e2e:python" + - suite: rust + cluster: e2e-rust + port: "8081" + cmd: "mise run --no-prepare --skip-deps e2e:rust" + - suite: gateway-resume + cluster: e2e-resume + port: "8082" + cmd: "cargo test --manifest-path e2e/rust/Cargo.toml --features e2e --test gateway_resume" container: image: ghcr.io/nvidia/openshell/ci:latest credentials: @@ -38,6 +54,7 @@ jobs: OPENSHELL_REGISTRY_NAMESPACE: nvidia/openshell OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} + OPENSHELL_GATEWAY: ${{ matrix.cluster }} steps: - uses: actions/checkout@v4 @@ -48,21 +65,26 @@ jobs: run: docker pull ghcr.io/nvidia/openshell/cluster:${{ inputs.image-tag }} - name: Install Python dependencies and generate protobuf stubs + if: matrix.suite == 'python' run: uv sync --frozen && mise run --no-prepare python:proto - - name: Bootstrap and deploy cluster + - name: Build Rust CLI + if: matrix.suite != 'python' + run: cargo build -p openshell-cli --features openshell-core/dev-settings + + - name: Install SSH client + if: 
matrix.suite != 'python' + run: apt-get update && apt-get install -y --no-install-recommends openssh-client && rm -rf /var/lib/apt/lists/* + + - name: Bootstrap cluster env: GATEWAY_HOST: host.docker.internal - GATEWAY_PORT: "8080" + GATEWAY_PORT: ${{ matrix.port }} + CLUSTER_NAME: ${{ matrix.cluster }} SKIP_IMAGE_PUSH: "1" SKIP_CLUSTER_IMAGE_BUILD: "1" OPENSHELL_CLUSTER_IMAGE: ghcr.io/nvidia/openshell/cluster:${{ inputs.image-tag }} run: mise run --no-prepare --skip-deps cluster - - name: Install SSH client for Rust CLI e2e tests - run: apt-get update && apt-get install -y --no-install-recommends openssh-client && rm -rf /var/lib/apt/lists/* - - - name: Run E2E tests - run: | - mise run --no-prepare --skip-deps e2e:python - mise run --no-prepare --skip-deps e2e:rust + - name: Run tests + run: ${{ matrix.cmd }} diff --git a/.github/workflows/issue-triage.yml b/.github/workflows/issue-triage.yml index 50bdd31e1..ec87af503 100644 --- a/.github/workflows/issue-triage.yml +++ b/.github/workflows/issue-triage.yml @@ -23,7 +23,7 @@ jobs: // The template placeholder starts with "Example:" — if that's still // there or the section is empty, the reporter didn't fill it in. const diagnosticMatch = body.match( - /## Agent Diagnostic\s*\n([\s\S]*?)(?=\n## |\n$)/ + /### Agent Diagnostic\s*\n([\s\S]*?)(?=\n### |\n$)/ ); const hasSubstantiveDiagnostic = diagnosticMatch diff --git a/.opencode/agents/principal-engineer-reviewer.md b/.opencode/agents/principal-engineer-reviewer.md index 452548d2c..68c3a86d0 100644 --- a/.opencode/agents/principal-engineer-reviewer.md +++ b/.opencode/agents/principal-engineer-reviewer.md @@ -25,8 +25,12 @@ OpenShell project. Your reviews balance three priorities equally: 3. **Security** — What are the threat surfaces? Are trust boundaries respected? Is input validated at system boundaries? Are secrets, credentials, and - tokens handled correctly? Think about the OWASP top 10, supply chain risks, - and privilege escalation. + tokens handled correctly? 
Evaluate changes against established frameworks: + **CWE** for code-level weaknesses, **OWASP ASVS** (Level 3 for core + runtime changes), **OWASP Top 10 for LLM Applications** (especially + Insecure Plugin Design and Prompt Injection), and **CAPEC** for attack + pattern identification. Consider supply chain risks and privilege + escalation paths. ## Project context @@ -95,6 +99,53 @@ Structure your review clearly: Omit empty sections. Keep it concise — density over length. +## Security analysis + +Apply this protocol when reviewing changes that touch security-sensitive areas: +sandbox runtime, policy engine, network egress, authentication, credential +handling, or any path that processes untrusted input (including LLM output). + +1. **Threat modeling** — Map the data flow for the change. Where does untrusted + input (from an LLM, user, or network) enter? Where does it exit (to a + shell, filesystem, network, or database)? Identify trust boundaries that + the change crosses. + +2. **Weakness mapping** — Tag every security concern with its **CWE ID**. This + makes findings actionable and trackable. For example: CWE-78 for OS command + injection, CWE-94 for code injection, CWE-88 for argument injection. + +3. **Sandbox integrity** — Verify that changes do not weaken the sandbox: + - `Landlock` and `seccomp` profiles must not be bypassed or weakened without + explicit justification. + - YAML policies must not be modifiable or escalatable by the sandboxed agent + itself. + - Default-deny posture must be preserved. + +4. **Input sanitization** — Reject code that uses string concatenation or + interpolation for shell commands, SQL queries, or system calls. Demand + parameterized execution or strict allow-list validation. + +5. **Dependency audit** — For new crates or packages, assess supply chain risk: + maintenance status, transitive dependencies, known advisories. + +### Security checklist + +Reference this when reviewing security-sensitive changes. 
Not every item +applies to every PR — use judgment. + +- **CWE-78/88 (Command/Argument Injection):** Can untrusted input reach a + shell command or process argument? +- **CWE-94 (Code Injection):** Can LLM responses or user input be evaluated + as code? +- **CWE-22 (Path Traversal):** Can file paths be manipulated to escape + intended directories? +- **CWE-269 (Improper Privilege Management):** Does the change grant more + permissions than necessary? +- **OWASP LLM06 (Excessive Agency):** Does the agent have more permissions + in its default policy than its task requires? +- **Supply chain:** Do new dependencies introduce known vulnerabilities or + unmaintained transitive dependencies? + ## Principles - Don't nitpick style unless it harms readability. Trust `rustfmt` and the diff --git a/AGENTS.md b/AGENTS.md index 79dc29d1b..979965941 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,6 +35,7 @@ These pipelines connect skills into end-to-end workflows. Individual skill files | `crates/openshell-policy/` | Policy engine | Filesystem, network, process, and inference constraints | | `crates/openshell-router/` | Privacy router | Privacy-aware LLM routing | | `crates/openshell-bootstrap/` | Cluster bootstrap | K3s cluster setup, image loading, mTLS PKI | +| `crates/openshell-ocsf/` | OCSF logging | OCSF v1.7.0 event types, builders, shorthand/JSONL formatters, tracing layers | | `crates/openshell-core/` | Shared core | Common types, configuration, error handling | | `crates/openshell-providers/` | Provider management | Credential provider backends | | `crates/openshell-tui/` | Terminal UI | Ratatui-based dashboard for monitoring | @@ -66,6 +67,85 @@ These pipelines connect skills into end-to-end workflows. Individual skill files - Store plan documents in `architecture/plans`. This is git ignored so its for easier access for humans. When asked to create Spikes or issues, you can skip to GitHub issues. 
Only use the plans dir when you aren't writing data somewhere else specific. - When asked to write a plan, write it there without asking for the location. +## Sandbox Logging (OCSF) + +When adding or modifying log emissions in `openshell-sandbox`, determine whether the event should use OCSF structured logging or plain `tracing`. + +### When to use OCSF + +Use an OCSF builder + `ocsf_emit!()` for events that represent **observable sandbox behavior** visible to operators, security teams, or agents monitoring the sandbox: + +- Network decisions (allow, deny, bypass detection) +- HTTP/L7 enforcement decisions +- SSH authentication (accepted, denied, nonce replay) +- Process lifecycle (start, exit, timeout, signal failure) +- Security findings (unsafe policy, unavailable controls, replay attacks) +- Configuration changes (policy load/reload, TLS setup, inference routes, settings) +- Application lifecycle (supervisor start, SSH server ready) + +### When to use plain tracing + +Use `info!()`, `debug!()`, `warn!()` for **internal operational plumbing** that doesn't represent a security decision or observable state change: + +- gRPC connection attempts and retries +- "About to do X" events where the result is logged separately +- Internal SSH channel state (unknown channel, PTY resize) +- Zombie process reaping, denial flush telemetry +- DEBUG/TRACE level diagnostics + +### Choosing the OCSF event class + +| Event type | Builder | When to use | +|---|---|---| +| TCP connections, proxy tunnels, bypass | `NetworkActivityBuilder` | L4 network decisions, proxy operational events | +| HTTP requests, L7 enforcement | `HttpActivityBuilder` | Per-request method/path decisions | +| SSH sessions | `SshActivityBuilder` | Authentication, channel operations | +| Process start/stop | `ProcessActivityBuilder` | Entrypoint lifecycle, signal failures | +| Security alerts | `DetectionFindingBuilder` | Nonce replay, bypass detection, unsafe policy. Dual-emit with the domain event. 
| +| Policy/config changes | `ConfigStateChangeBuilder` | Policy load, Landlock apply, TLS setup, inference routes, settings | +| Supervisor lifecycle | `AppLifecycleBuilder` | Sandbox start, SSH server ready/failed | + +### Severity guidelines + +| Severity | When | +|---|---| +| `Informational` | Allowed connections, successful operations, config loaded | +| `Low` | DNS failures, non-fatal operational warnings, LOG rule failures | +| `Medium` | Denied connections, policy violations, deprecated config | +| `High` | Security findings (nonce replay, Landlock unavailable) | +| `Critical` | Process timeout kills | + +### Example: adding a new network event + +```rust +use openshell_ocsf::{ + ocsf_emit, NetworkActivityBuilder, ActivityId, ActionId, + DispositionId, Endpoint, Process, SeverityId, StatusId, +}; + +let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host, port)) + .actor_process(Process::new(&binary, pid)) + .firewall_rule(&policy_name, &engine_type) + .message(format!("CONNECT denied {host}:{port}")) + .build(); +ocsf_emit!(event); +``` + +### Key points + +- `crate::ocsf_ctx()` returns the process-wide `SandboxContext`. It is always available (falls back to defaults in tests). +- `ocsf_emit!()` is non-blocking and cannot panic. It stores the event in a thread-local and emits via `tracing::info!()`. +- The shorthand layer and JSONL layer extract the event from the thread-local. The shorthand format is derived automatically from the builder fields. +- For security findings, **dual-emit**: one domain event (e.g., `SshActivityBuilder`) AND one `DetectionFindingBuilder` for the same incident. +- Never log secrets, credentials, or query parameters in OCSF messages. The OCSF JSONL file may be shipped to external systems. 
+- The `message` field should be a concise, grep-friendly summary. Details go in builder fields (dst_endpoint, firewall_rule, etc.). + ## Sandbox Infra Changes - If you change sandbox infrastructure, ensure `mise run sandbox` succeeds. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c25d30b92..1ebf71df2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -186,9 +186,14 @@ These are the primary `mise` tasks for day-to-day development: | `tasks/` | `mise` task definitions and build scripts | | `deploy/` | Dockerfiles, Helm chart, Kubernetes manifests | | `architecture/` | Architecture docs and plans | +| `rfc/` | Request for Comments proposals | | `docs/` | User-facing documentation (Sphinx/MyST) | | `.agents/` | Agent skills and persona definitions | +## RFCs + +For cross-cutting architectural decisions, API contract changes, or process proposals that need broad consensus, use the RFC process. RFCs live in `rfc/` — copy the template, fill it in, and open a PR for discussion. See [rfc/README.md](rfc/README.md) for the full lifecycle and guidelines on when to write an RFC versus a spike issue or architecture doc. + ## Documentation If your change affects user-facing behavior (new flags, changed defaults, new features, bug fixes that contradict existing docs), update the relevant pages under `docs/` in the same PR. 
diff --git a/Cargo.lock b/Cargo.lock index 9d8247e5d..ac8d1d830 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,9 +232,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.16.1" +version = "1.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" dependencies = [ "aws-lc-sys", "untrusted 0.7.1", @@ -243,9 +243,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.38.0" +version = "0.39.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +checksum = "83a25cf98105baa966497416dbd42565ce3a8cf8dbfd59803ec9ad46f3126399" dependencies = [ "cc", "cmake", @@ -2497,6 +2497,16 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libyml" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3302702afa434ffa30847a83305f0a69d6abd74293b6554c18ec85c7ef30c980" +dependencies = [ + "anyhow", + "version_check", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -2805,6 +2815,7 @@ checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" name = "openshell-bootstrap" version = "0.0.0" dependencies = [ + "async-stream", "base64 0.22.1", "bollard", "bytes", @@ -2901,7 +2912,7 @@ dependencies = [ "miette", "openshell-core", "serde", - "serde_yaml", + "serde_yml", ] [[package]] @@ -2921,7 +2932,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "serde_yaml", + "serde_yml", "tempfile", "thiserror 2.0.18", "tokio", @@ -2935,8 +2946,10 @@ name = "openshell-sandbox" version = "0.0.0" dependencies = [ "anyhow", + "base64 0.22.1", "bytes", "clap", + "futures", "hex", "hmac", "ipnet", @@ -2945,6 +2958,7 @@ dependencies = [ "miette", "nix", "openshell-core", + "openshell-ocsf", "openshell-policy", "openshell-router", 
"rand_core 0.6.4", @@ -2955,7 +2969,7 @@ dependencies = [ "rustls-pemfile", "seccompiler", "serde_json", - "serde_yaml", + "serde_yml", "sha2 0.10.9", "temp-env", "tempfile", @@ -2963,6 +2977,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-stream", + "tokio-tungstenite 0.26.2", "tonic", "tracing", "tracing-appender", @@ -4344,6 +4359,21 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "serde_yml" +version = "0.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59e2dd588bf1597a252c3b920e0143eb99b0f76e4e082f4c92ce34fbc9e71ddd" +dependencies = [ + "indexmap 2.13.0", + "itoa", + "libyml", + "memchr", + "ryu", + "serde", + "version_check", +] + [[package]] name = "serdect" version = "0.4.2" @@ -4898,9 +4928,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" dependencies = [ "filetime", "libc", diff --git a/Cargo.toml b/Cargo.toml index 4fecf1940..08b699d47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,7 @@ nix = { version = "0.29", features = ["signal", "process", "user", "fs", "term"] # Serialization serde = { version = "1", features = ["derive"] } serde_json = "1" -serde_yaml = "0.9" +serde_yml = "0.0.12" # HTTP client reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } diff --git a/README.md b/README.md index 44f0bb13c..f2981f7af 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# OpenShell +# NVIDIA OpenShell [![License](https://img.shields.io/badge/License-Apache_2.0-blue)](https://github.com/NVIDIA/OpenShell/blob/main/LICENSE) [![PyPI](https://img.shields.io/badge/PyPI-openshell-orange?logo=pypi)](https://pypi.org/project/openshell/) @@ -128,7 +128,7 @@ OpenShell can pass host GPUs into sandboxes for local 
inference, fine-tuning, or openshell sandbox create --gpu --from [gpu-enabled-sandbox] -- claude ``` -The CLI auto-bootstraps a GPU-enabled gateway on first use. GPU intent is also inferred automatically for community images with `gpu` in the name. +The CLI auto-bootstraps a GPU-enabled gateway on first use, auto-selecting CDI when available and otherwise falling back to Docker's NVIDIA GPU request path (`--gpus all`). GPU intent is also inferred automatically for community images with `gpu` in the name. **Requirements:** NVIDIA drivers and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) must be installed on the host. The sandbox image itself must include the appropriate GPU drivers and libraries for your workload — the default `base` image does not. See the [BYOC example](https://github.com/NVIDIA/OpenShell/tree/main/examples/bring-your-own-container) for building a custom sandbox image with GPU support. @@ -229,6 +229,10 @@ All implementation work is human-gated — agents propose plans, humans approve, OpenShell is built agent-first — your agent is your first collaborator. Before opening issues or submitting code, point your agent at the repo and let it use the skills in `.agents/skills/` to investigate, diagnose, and prototype. See [CONTRIBUTING.md](CONTRIBUTING.md) for the full agent skills table, contribution workflow, and development setup. +## Notice and Disclaimer + +This software automatically retrieves, accesses or interacts with external materials. Those retrieved materials are not distributed with this software and are governed solely by separate terms, conditions and licenses. You are solely responsible for finding, reviewing and complying with all applicable terms, conditions, and licenses, and for verifying the security, integrity and suitability of any retrieved materials for your specific use case. This software is provided "AS IS", without warranty of any kind. 
The author makes no representations or warranties regarding any retrieved materials, and assumes no liability for any losses, damages, liabilities or legal consequences from your use or inability to use this software or any retrieved materials. Use this software and the retrieved materials at your own risk. + ## License This project is licensed under the [Apache License 2.0](https://github.com/NVIDIA/OpenShell/blob/main/LICENSE). diff --git a/architecture/gateway-single-node.md b/architecture/gateway-single-node.md index 57aebd3a5..6389c728e 100644 --- a/architecture/gateway-single-node.md +++ b/architecture/gateway-single-node.md @@ -185,7 +185,7 @@ For the target daemon (local or remote): After the container starts: -1. **Clean stale nodes**: `clean_stale_nodes()` finds `NotReady` nodes via `kubectl get nodes` and deletes them. This is needed when a container is recreated but reuses the persistent volume -- k3s registers a new node (using the container ID as hostname) while old node entries persist in etcd. Non-fatal on error; returns the count of removed nodes. +1. **Clean stale nodes**: `clean_stale_nodes()` finds nodes whose name does not match the deterministic k3s `--node-name` and deletes them. That node name is derived from the gateway name but normalized to a Kubernetes-safe lowercase form so existing gateway names that contain `_`, `.`, or uppercase characters still produce a valid node identity. This cleanup is needed when a container is recreated but reuses the persistent volume -- old node entries can persist in etcd. Non-fatal on error; returns the count of removed nodes. 2. **Push local images** (optional, local deploy only): If `OPENSHELL_PUSH_IMAGES` is set, the comma-separated image refs are exported from the local Docker daemon as a single tar, uploaded into the container via `docker put_archive`, and imported into containerd via `ctr images import` in the `k8s.io` namespace. 
After import, `kubectl rollout restart deployment/openshell openshell` is run, followed by `kubectl rollout status --timeout=180s` to wait for completion. See `crates/openshell-bootstrap/src/push.rs`. 3. **Wait for gateway health**: `wait_for_gateway_ready()` polls the Docker HEALTHCHECK status up to 180 times, 2 seconds apart (6 min total). A background task streams container logs during this wait. Failure modes: - Container exits during polling: error includes recent log lines. @@ -260,7 +260,7 @@ On Docker custom networks, `/etc/resolv.conf` contains `127.0.0.11` (Docker's in 2. Getting the container's `eth0` IP as a routable address. 3. Adding DNAT rules in PREROUTING to forward DNS from pod namespaces through to Docker's DNS. 4. Writing a custom resolv.conf pointing to the container IP. -5. Passing `--resolv-conf=/etc/rancher/k3s/resolv.conf` to k3s. +5. Passing `--kubelet-arg=resolv-conf=/etc/rancher/k3s/resolv.conf` to k3s. Falls back to `8.8.8.8` / `8.8.4.4` if iptables detection fails. @@ -296,11 +296,13 @@ When environment variables are set, the entrypoint modifies the HelmChart manife GPU support is part of the single-node gateway bootstrap path rather than a separate architecture. -- `openshell gateway start --gpu` threads a boolean deploy option through `crates/openshell-cli`, `crates/openshell-bootstrap`, and `crates/openshell-bootstrap/src/docker.rs`. -- When enabled, the cluster container is created with Docker `DeviceRequests`, which is the API equivalent of `docker run --gpus all`. +- `openshell gateway start --gpu` threads GPU device options through `crates/openshell-cli`, `crates/openshell-bootstrap`, and `crates/openshell-bootstrap/src/docker.rs`. +- When enabled, the cluster container is created with Docker `DeviceRequests`. 
The injection mechanism is selected based on whether CDI is enabled on the daemon (`SystemInfo.CDISpecDirs` via `GET /info`): + - **CDI enabled** (daemon reports non-empty `CDISpecDirs`): CDI device injection — `driver="cdi"` with `nvidia.com/gpu=all`. Specs are expected to be pre-generated on the host (e.g. automatically by the `nvidia-cdi-refresh.service` or manually via `nvidia-ctk generate`). + - **CDI not enabled**: `--gpus all` device request — `driver="nvidia"`, `count=-1`, which relies on the NVIDIA Container Runtime hook. - `deploy/docker/Dockerfile.images` installs NVIDIA Container Toolkit packages in a dedicated Ubuntu stage and copies the runtime binaries, config, and `libnvidia-container` shared libraries into the final Ubuntu-based cluster image. - `deploy/docker/cluster-entrypoint.sh` checks `GPU_ENABLED=true` and copies GPU-only manifests from `/opt/openshell/gpu-manifests/` into k3s's manifests directory. -- `deploy/kube/gpu-manifests/nvidia-device-plugin-helmchart.yaml` installs the NVIDIA device plugin chart, currently pinned to `0.18.2`. NFD and GFD are disabled; the device plugin's default `nodeAffinity` (which requires `feature.node.kubernetes.io/pci-10de.present=true` or `nvidia.com/gpu.present=true` from NFD/GFD) is overridden to empty so the DaemonSet schedules on the single-node cluster without requiring those labels. +- `deploy/kube/gpu-manifests/nvidia-device-plugin-helmchart.yaml` installs the NVIDIA device plugin chart, currently pinned to `0.18.2`. NFD and GFD are disabled; the device plugin's default `nodeAffinity` (which requires `feature.node.kubernetes.io/pci-10de.present=true` or `nvidia.com/gpu.present=true` from NFD/GFD) is overridden to empty so the DaemonSet schedules on the single-node cluster without requiring those labels. The chart is configured with `deviceListStrategy: cdi-cri` so the device plugin injects devices via direct CDI device requests in the CRI. 
- k3s auto-detects `nvidia-container-runtime` on `PATH`, registers the `nvidia` containerd runtime, and creates the `nvidia` `RuntimeClass` automatically. - The OpenShell Helm chart grants the gateway service account cluster-scoped read access to `node.k8s.io/runtimeclasses` and core `nodes` so GPU sandbox admission can verify both the `nvidia` `RuntimeClass` and allocatable GPU capacity before creating a sandbox. @@ -308,13 +310,19 @@ The runtime chain is: ```text Host GPU drivers & NVIDIA Container Toolkit - └─ Docker: --gpus all (DeviceRequests in bollard API) + └─ Docker: DeviceRequests (CDI when enabled, --gpus all otherwise) └─ k3s/containerd: nvidia-container-runtime on PATH -> auto-detected └─ k8s: nvidia-device-plugin DaemonSet advertises nvidia.com/gpu - └─ Pods: request nvidia.com/gpu in resource limits + └─ Pods: request nvidia.com/gpu in resource limits (CDI injection — no runtimeClassName needed) ``` -The expected smoke test is a plain pod requesting `nvidia.com/gpu: 1` with `runtimeClassName: nvidia` and running `nvidia-smi`. +### `--gpu` flag + +The `--gpu` flag on `gateway start` enables GPU passthrough. OpenShell auto-selects CDI when enabled on the daemon and falls back to Docker's NVIDIA GPU request path (`--gpus all`) otherwise. + +Device injection uses CDI (`deviceListStrategy: cdi-cri`): the device plugin injects devices via direct CDI device requests in the CRI. Sandbox pods only need `nvidia.com/gpu: 1` in their resource limits, and GPU pods do not set `runtimeClassName`. + +The expected smoke test is a plain pod requesting `nvidia.com/gpu: 1` without `runtimeClassName` and running `nvidia-smi`. ## Remote Image Transfer @@ -381,7 +389,7 @@ When `openshell sandbox create` cannot connect to a gateway (connection refused, 1. `should_attempt_bootstrap()` in `crates/openshell-cli/src/bootstrap.rs` checks the error type. It returns `true` for connectivity errors and missing default TLS materials, but `false` for TLS handshake/auth errors. 2. 
If running in a terminal, the user is prompted to confirm. 3. `run_bootstrap()` deploys a gateway named `"openshell"`, sets it as active, and returns fresh `TlsOptions` pointing to the newly-written mTLS certs. -4. When `sandbox create` requests GPU explicitly (`--gpu`) or infers it from an image whose final name component contains `gpu` (such as `nvidia-gpu`), the bootstrap path enables gateway GPU support before retrying sandbox creation. +4. When `sandbox create` requests GPU explicitly (`--gpu`) or infers it from an image whose final name component contains `gpu` (such as `nvidia-gpu`), the bootstrap path enables gateway GPU support before retrying sandbox creation, using the same CDI-or-fallback selection as `gateway start --gpu`. ## Container Environment Variables diff --git a/architecture/gateway.md b/architecture/gateway.md index 39f97c8c1..72574410d 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -501,7 +501,7 @@ The Helm chart template is at `deploy/helm/openshell/templates/statefulset.yaml` `SandboxClient` (`crates/openshell-server/src/sandbox/mod.rs`) manages `agents.x-k8s.io/v1alpha1/Sandbox` CRDs. -- **Create**: Translates a `Sandbox` proto into a Kubernetes `DynamicObject` with labels (`openshell.ai/sandbox-id`, `openshell.ai/managed-by: openshell`) and a spec that includes the pod template, environment variables, and gateway-required env vars (`OPENSHELL_SANDBOX_ID`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SSH_LISTEN_ADDR`, etc.). +- **Create**: Translates a `Sandbox` proto into a Kubernetes `DynamicObject` with labels (`openshell.ai/sandbox-id`, `openshell.ai/managed-by: openshell`) and a spec that includes the pod template, environment variables, and gateway-required env vars (`OPENSHELL_SANDBOX_ID`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SSH_LISTEN_ADDR`, etc.). 
When callers do not provide custom `volumeClaimTemplates`, the server injects a default `workspace` PVC and mounts it at `/sandbox` so the default sandbox home/workdir survives pod rescheduling. - **Delete**: Calls the Kubernetes API to delete the CRD by name. Returns `false` if already gone (404). - **Pod IP resolution**: `agent_pod_ip()` fetches the agent pod and reads `status.podIP`. diff --git a/architecture/inference-routing.md b/architecture/inference-routing.md index 0d3a95afb..9d45d7cd9 100644 --- a/architecture/inference-routing.md +++ b/architecture/inference-routing.md @@ -92,10 +92,10 @@ File: `proto/inference.proto` Key messages: -- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + optional `no_verify` override, with verification enabled by default -- `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `version` +- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + `timeout_secs` + optional `no_verify` override, with verification enabled by default +- `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `timeout_secs` + `version` - `GetInferenceBundleResponse` -- `repeated ResolvedRoute routes` + `revision` + `generated_at_ms` -- `ResolvedRoute` -- `name`, `base_url`, `protocols`, `api_key`, `model_id`, `provider_type` +- `ResolvedRoute` -- `name`, `base_url`, `protocols`, `api_key`, `model_id`, `provider_type`, `timeout_secs` ## Data Plane (Sandbox) @@ -106,7 +106,7 @@ Files: - `crates/openshell-sandbox/src/lib.rs` -- inference context initialization, route refresh - `crates/openshell-sandbox/src/grpc_client.rs` -- `fetch_inference_bundle()` -In cluster mode, the sandbox starts a background refresh loop as soon as the inference context is created. The loop polls the gateway every 5 seconds by default (`OPENSHELL_ROUTE_REFRESH_INTERVAL_SECS` override) and uses the bundle revision hash to skip no-op cache writes. 
+In cluster mode, the sandbox starts a background refresh loop as soon as the inference context is created. The loop polls the gateway every 5 seconds by default (`OPENSHELL_ROUTE_REFRESH_INTERVAL_SECS` override) and uses the bundle revision hash to skip no-op cache writes. The revision hash covers all route fields including `timeout_secs`, so any configuration change (provider, model, or timeout) triggers a cache update on the next poll. ### Interception flow @@ -143,7 +143,7 @@ If no pattern matches, the proxy returns `403 Forbidden` with `{"error": "connec ### Route cache - `InferenceContext` holds a `Router`, the pattern list, and an `Arc>>` route cache. -- In cluster mode, `spawn_route_refresh()` polls `GetInferenceBundle` every 30 seconds (`ROUTE_REFRESH_INTERVAL_SECS`). On failure, stale routes are kept. +- In cluster mode, `spawn_route_refresh()` polls `GetInferenceBundle` every 5 seconds (`OPENSHELL_ROUTE_REFRESH_INTERVAL_SECS`). On failure, stale routes are kept. - In file mode (`--inference-routes`), routes load once at startup from YAML. No refresh task is spawned. - In cluster mode, an empty initial bundle still enables the inference context so the refresh task can pick up later configuration. @@ -209,9 +209,11 @@ File: `crates/openshell-router/src/mock.rs` Routes with `mock://` scheme endpoints return canned responses without making HTTP requests. Mock responses are protocol-aware (OpenAI chat completion, OpenAI completion, Anthropic messages, or generic JSON). Mock routes include an `x-openshell-mock: true` response header. -### HTTP client +### Per-request timeout -The router uses a `reqwest::Client` with a 60-second timeout. Timeouts and connection failures map to `RouterError::UpstreamUnavailable`. +Each `ResolvedRoute` carries a `timeout` field (`Duration`). The `reqwest::Client` has no global timeout; instead, each outgoing request applies `.timeout(route.timeout)` on the request builder. 
When `timeout_secs` is `0` in the proto message, the default of 60 seconds is used (defined as `DEFAULT_ROUTE_TIMEOUT` in `config.rs`). Timeouts and connection failures map to `RouterError::UpstreamUnavailable`. + +Timeout changes propagate dynamically to running sandboxes. The bundle revision hash includes `timeout_secs`, so when the timeout is updated via `openshell inference update --timeout`, the refresh loop detects the revision change and updates the route cache within one polling interval (5 seconds by default). ## Standalone Route File @@ -297,13 +299,16 @@ The system route is stored as a separate `InferenceRoute` record in the gateway Cluster inference commands: -- `openshell inference set --provider --model ` -- configures user-facing cluster inference -- `openshell inference set --system --provider --model ` -- configures system inference +- `openshell inference set --provider --model [--timeout ]` -- configures user-facing cluster inference +- `openshell inference set --system --provider --model [--timeout ]` -- configures system inference +- `openshell inference update [--provider ] [--model ] [--timeout ]` -- updates individual fields without resetting others - `openshell inference get` -- displays both user and system inference configuration - `openshell inference get --system` -- displays only the system inference configuration The `--provider` flag references a provider record name (not a provider type). The provider must already exist in the cluster and have a supported inference type (`openai`, `anthropic`, or `nvidia`). +The `--timeout` flag sets the per-request timeout in seconds for upstream inference calls. When omitted or set to `0`, the default of 60 seconds applies. Timeout changes propagate to running sandboxes within the route refresh interval (5 seconds by default). + Inference writes verify by default. `--no-verify` is the explicit opt-out for endpoints that are not up yet. 
## Provider Discovery

diff --git a/architecture/sandbox-providers.md b/architecture/sandbox-providers.md
index 16b7948bc..fe5d48a97 100644
--- a/architecture/sandbox-providers.md
+++ b/architecture/sandbox-providers.md
@@ -305,18 +305,31 @@ start from `env_clear()`, so the handshake secret is not present there.
 
 ### Proxy-Time Secret Resolution
 
-When a sandboxed tool uses one of these placeholder env vars to populate an outbound HTTP
-header (for example `Authorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY`), the
-sandbox proxy rewrites the placeholder to the real secret value immediately before the
-request is forwarded upstream.
-
-This applies to:
-
-- forward-proxy HTTP requests, and
-- L7-inspected REST requests inside CONNECT tunnels.
+When a sandboxed tool uses one of these placeholder env vars in an outbound HTTP request,
+the sandbox proxy rewrites the placeholder to the real secret value immediately before the
+request is forwarded upstream. Placeholders are resolved in the following locations:
+
+- **HTTP header values** — exact match (`x-api-key: openshell:resolve:env:KEY`), prefixed
+  match (`Authorization: Bearer openshell:resolve:env:KEY`), and Base64-decoded Basic auth
+  tokens (`Authorization: Basic <base64(user:password)>`)
+- **URL query parameters** — for APIs that authenticate via query string
+  (e.g., `?key=openshell:resolve:env:YOUTUBE_API_KEY`)
+- **URL path segments** — for APIs that embed tokens in the URL path
+  (e.g., `/bot{TOKEN}/sendMessage` for Telegram Bot API)
+
+This applies to forward-proxy HTTP requests, L7-inspected REST requests inside CONNECT
+tunnels, and credential-injection-only passthrough relays on TLS-terminated connections.
+
+All rewriting fails closed: if any `openshell:resolve:env:*` placeholder is detected but
+cannot be resolved, the proxy rejects the request with HTTP 500 instead of forwarding the
+raw placeholder upstream.
Resolved secret values are validated for prohibited control +characters (CR, LF, null byte) to prevent header injection (CWE-113). Path segment +credentials are additionally validated to reject traversal sequences, path separators, and +URI delimiters (CWE-22). The real secret value remains in supervisor memory only; it is not re-injected into the -child process environment. +child process environment. See [Credential injection](sandbox.md#credential-injection) for +the full implementation details, encoding rules, and security properties. ### End-to-End Flow diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 1117d0f71..c5e212f85 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -24,7 +24,7 @@ All paths are relative to `crates/openshell-sandbox/src/`. | `sandbox/mod.rs` | Platform abstraction -- dispatches to Linux or no-op | | `sandbox/linux/mod.rs` | Linux composition: Landlock then seccomp | | `sandbox/linux/landlock.rs` | Filesystem isolation via Landlock LSM (ABI V1) | -| `sandbox/linux/seccomp.rs` | Syscall filtering via BPF on `SYS_socket` | +| `sandbox/linux/seccomp.rs` | Syscall filtering via BPF: socket domain blocks, dangerous syscall blocks, conditional flag blocks | | `bypass_monitor.rs` | Background `/dev/kmsg` reader for iptables bypass detection events | | `sandbox/linux/netns.rs` | Network namespace creation, veth pair setup, bypass detection iptables rules, cleanup on drop | | `l7/mod.rs` | L7 types (`L7Protocol`, `TlsMode`, `EnforcementMode`, `L7EndpointConfig`), config parsing, validation, access preset expansion, deprecated `tls` value handling | @@ -33,6 +33,7 @@ All paths are relative to `crates/openshell-sandbox/src/`. 
| `l7/relay.rs` | Protocol-aware bidirectional relay with per-request OPA evaluation, credential-injection-only passthrough relay | | `l7/rest.rs` | HTTP/1.1 request/response parsing, body framing (Content-Length, chunked), deny response generation | | `l7/provider.rs` | `L7Provider` trait and `L7Request`/`BodyLength` types | +| `secrets.rs` | `SecretResolver` credential placeholder system — placeholder generation, multi-location rewriting (headers, query params, path segments, Basic auth), fail-closed scanning, secret validation, percent-encoding | ## Startup and Orchestration @@ -431,26 +432,26 @@ Landlock restricts the child process's filesystem access to an explicit allowlis 1. Build path lists from `filesystem.read_only` and `filesystem.read_write` 2. If `include_workdir` is true, add the working directory to `read_write` 3. If both lists are empty, skip Landlock entirely (no-op) -4. Create a Landlock ruleset targeting ABI V1: +4. Create a Landlock ruleset targeting ABI V2: - Read-only paths receive `AccessFs::from_read(abi)` rights - Read-write paths receive `AccessFs::from_all(abi)` rights -5. Call `ruleset.restrict_self()` -- this applies to the calling process and all descendants +5. For each path, attempt `PathFd::new()`. If it fails: + - `BestEffort`: Log a warning with the error classification (not found, permission denied, symlink loop, etc.) and skip the path. Continue building the ruleset from remaining valid paths. + - `HardRequirement`: Return a fatal error, aborting the sandbox. +6. If all paths failed (zero rules applied), return an error rather than calling `restrict_self()` on an empty ruleset (which would block all filesystem access) +7. 
Call `ruleset.restrict_self()` -- this applies to the calling process and all descendants -Error behavior depends on `LandlockCompatibility`: +Kernel-level error behavior (e.g., Landlock ABI unavailable) depends on `LandlockCompatibility`: - `BestEffort`: Log a warning and continue without filesystem isolation - `HardRequirement`: Return a fatal error, aborting the sandbox +**Baseline path filtering**: System-injected baseline paths (e.g., `/app`) are pre-filtered by `enrich_proto_baseline_paths()` / `enrich_sandbox_baseline_paths()` using `Path::exists()` before they reach Landlock. User-specified paths are not pre-filtered -- they are evaluated at Landlock apply time so misconfigurations surface as warnings or errors. + ### Seccomp syscall filtering **File:** `crates/openshell-sandbox/src/sandbox/linux/seccomp.rs` -Seccomp blocks socket creation for specific address families. The filter targets a single syscall (`SYS_socket`) and inspects argument 0 (the domain). - -**Always blocked** (regardless of network mode): -- `AF_NETLINK`, `AF_PACKET`, `AF_BLUETOOTH`, `AF_VSOCK` - -**Additionally blocked in `Block` mode** (no proxy): -- `AF_INET`, `AF_INET6` +Seccomp provides three layers of syscall restriction: socket domain blocks, unconditional syscall blocks, and conditional syscall blocks. The filter uses a default-allow policy (`SeccompAction::Allow`) with targeted rules that return `Errno(EPERM)`. **Skipped entirely** in `Allow` mode. @@ -458,8 +459,44 @@ Setup: 1. `prctl(PR_SET_NO_NEW_PRIVS, 1)` -- required before seccomp 2. 
`seccompiler::apply_filter()` with default action `Allow` and per-rule action `Errno(EPERM)` +#### Socket domain blocks + +| Domain | Always blocked | Additionally blocked in Block mode | +|--------|:-:|:-:| +| `AF_PACKET` | Yes | | +| `AF_BLUETOOTH` | Yes | | +| `AF_VSOCK` | Yes | | +| `AF_INET` | | Yes | +| `AF_INET6` | | Yes | +| `AF_NETLINK` | | Yes | + In `Proxy` mode, `AF_INET`/`AF_INET6` are allowed because the sandboxed process needs to connect to the proxy over the veth pair. The network namespace ensures it can only reach the proxy's IP (`10.200.0.1`). +#### Unconditional syscall blocks + +These syscalls are blocked entirely (EPERM for any invocation): + +| Syscall | Reason | +|---------|--------| +| `memfd_create` | Fileless binary execution bypasses Landlock filesystem restrictions | +| `ptrace` | Cross-process memory inspection and code injection | +| `bpf` | Kernel BPF program loading | +| `process_vm_readv` | Cross-process memory read | +| `io_uring_setup` | Async I/O subsystem with extensive CVE history | +| `mount` | Filesystem mount could subvert Landlock or overlay writable paths | + +#### Conditional syscall blocks + +These syscalls are only blocked when specific flag patterns are present: + +| Syscall | Condition | Reason | +|---------|-----------|--------| +| `execveat` | `AT_EMPTY_PATH` flag set (arg4) | Fileless execution from an anonymous fd | +| `unshare` | `CLONE_NEWUSER` flag set (arg0) | User namespace creation enables privilege escalation | +| `seccomp` | operation == `SECCOMP_SET_MODE_FILTER` (arg0) | Prevents sandboxed code from replacing the active filter | + +Conditional blocks use `MaskedEq` for flag checks (bit-test) and `Eq` for exact-value matches. This allows normal use of these syscalls while blocking the dangerous flag combinations. 
+ ### Network namespace isolation **File:** `crates/openshell-sandbox/src/sandbox/linux/netns.rs` @@ -818,11 +855,13 @@ When `Router::proxy_with_candidates()` returns an error, `router_error_to_http() | `RouterError` variant | HTTP status | Response body | |----------------------|-------------|---------------| -| `RouteNotFound(hint)` | `400` | `no route configured for route '{hint}'` | -| `NoCompatibleRoute(protocol)` | `400` | `no compatible route for source protocol '{protocol}'` | -| `Unauthorized(msg)` | `401` | `{msg}` | -| `UpstreamUnavailable(msg)` | `503` | `{msg}` | -| `UpstreamProtocol(msg)` / `Internal(msg)` | `502` | `{msg}` | +| `RouteNotFound(_)` | `400` | `no inference route configured` | +| `NoCompatibleRoute(_)` | `400` | `no compatible inference route available` | +| `Unauthorized(_)` | `401` | `unauthorized` | +| `UpstreamUnavailable(_)` | `503` | `inference service unavailable` | +| `UpstreamProtocol(_)` / `Internal(_)` | `502` | `inference service error` | + +Response messages are generic — internal details (upstream URLs, hostnames, TLS errors, route hints) are never exposed to the sandboxed process. Full error context is logged server-side at `warn` level. ### Inference routing context @@ -962,7 +1001,7 @@ flowchart LR | `EnforcementMode` | `Audit`, `Enforce` | What to do on L7 deny (log-only vs block) | | `L7EndpointConfig` | `{ protocol, tls, enforcement }` | Per-endpoint L7 configuration | | `L7Decision` | `{ allowed, reason, matched_rule }` | Result of L7 evaluation | -| `L7RequestInfo` | `{ action, target }` | HTTP method + path for policy evaluation | +| `L7RequestInfo` | `{ action, target, query_params }` | HTTP method, path, and decoded query multimap for policy evaluation | ### Access presets @@ -1021,27 +1060,138 @@ TLS termination is automatic. 
The proxy peeks the first bytes of every CONNECT tunnel.

System CA bundles are searched at well-known paths: `/etc/ssl/certs/ca-certificates.crt` (Debian/Ubuntu), `/etc/pki/tls/certs/ca-bundle.crt` (RHEL), `/etc/ssl/ca-bundle.pem` (openSUSE), `/etc/ssl/cert.pem` (Alpine/macOS).

-### Credential-injection-only relay
+### Credential injection
+
+**Files:** `crates/openshell-sandbox/src/secrets.rs`, `crates/openshell-sandbox/src/l7/relay.rs`, `crates/openshell-sandbox/src/l7/rest.rs`, `crates/openshell-sandbox/src/proxy.rs`
+
+The sandbox proxy resolves `openshell:resolve:env:*` credential placeholders in outbound HTTP requests. The `SecretResolver` holds a supervisor-only map from placeholder strings to real secret values, constructed at startup from the provider environment. Child processes only see placeholder values in their environment; the proxy rewrites them to real secrets immediately before forwarding upstream.
+
+#### `SecretResolver`
+
+```rust
+pub(crate) struct SecretResolver {
+    by_placeholder: HashMap<String, String>,
+}
+```
+
+`SecretResolver::from_provider_env()` splits the provider environment into two maps: a child-visible map with placeholder values (`openshell:resolve:env:ANTHROPIC_API_KEY`) and a supervisor-only resolver map (`{"openshell:resolve:env:ANTHROPIC_API_KEY": "sk-real-key"}`). The placeholder grammar is `openshell:resolve:env:[A-Za-z_][A-Za-z0-9_]*`.
#### Credential placement locations
+
+The resolver rewrites placeholders in the following locations within HTTP requests:
+
+| Location | Example | Encoding | Implementation |
+|----------|---------|----------|----------------|
+| Header value (exact) | `x-api-key: openshell:resolve:env:KEY` | None (raw replacement) | `rewrite_header_value()` |
+| Header value (prefixed) | `Authorization: Bearer openshell:resolve:env:KEY` | None (prefix preserved) | `rewrite_header_value()` |
+| Basic auth token | `Authorization: Basic <base64(user:password)>` | Base64 decode → resolve → re-encode | `rewrite_basic_auth_token()` |
+| URL query parameter | `?key=openshell:resolve:env:KEY` | Percent-decode → resolve → percent-encode (RFC 3986 unreserved) | `rewrite_uri_query_params()` |
+| URL path segment | `/bot{TOKEN}/sendMessage` | Percent-decode → resolve → validate → percent-encode (RFC 3986 pchar) | `rewrite_uri_path()` → `rewrite_path_segment()` |
+
+**Header values**: Direct match replaces the entire value. Prefixed match (e.g., `Bearer `) splits on whitespace, resolves the placeholder portion, and reassembles. Basic auth match detects `Authorization: Basic <base64 token>`, decodes the Base64 content, resolves any placeholders in the decoded `user:password` string, and re-encodes.
+
+**Query parameters**: Each `key=value` pair is checked. Values are percent-decoded before resolution and percent-encoded after (RFC 3986 Section 2.3 unreserved characters preserved: `ALPHA / DIGIT / "-" / "." / "_" / "~"`).
+
+**Path segments**: Handles substring matching for APIs that embed tokens within path segments (e.g., Telegram's `/bot{TOKEN}/sendMessage`). Each segment is percent-decoded, scanned for placeholder boundaries using the env var key grammar (`[A-Za-z_][A-Za-z0-9_]*`), resolved, validated for path safety, and percent-encoded per RFC 3986 Section 3.3 pchar rules (`unreserved / sub-delims / ":" / "@"`).
+ +#### Path credential validation (CWE-22) + +Resolved credential values destined for URL path segments are validated by `validate_credential_for_path()` before insertion. The following values are rejected: + +| Pattern | Rejection reason | +|---------|-----------------| +| `../`, `..\\`, `..` | Path traversal sequence | +| `/`, `\` | Path separator | +| `\0`, `\r`, `\n` | Control character | +| `?`, `#` | URI delimiter | + +Rejection causes the request to fail closed (HTTP 500). + +#### Secret value validation (CWE-113) + +All resolved credential values are validated at the `resolve_placeholder()` level for prohibited control characters: CR (`\r`), LF (`\n`), and null byte (`\0`). This prevents HTTP header injection via malicious credential values. The validation applies to all placement locations automatically — header values, query parameters, and path segments all pass through `resolve_placeholder()`. + +#### Fail-closed behavior + +All placeholder rewriting fails closed. If any `openshell:resolve:env:*` placeholder is detected in the request but cannot be resolved, the proxy rejects the request with HTTP 500 instead of forwarding the raw placeholder to the upstream. The fail-closed mechanism operates at two levels: + +1. **Per-location**: Each rewrite function (`rewrite_uri_query_params`, `rewrite_path_segment`, `rewrite_header_line`) returns an `UnresolvedPlaceholderError` when a placeholder is detected but the resolver has no mapping for it. + +2. **Final scan**: After all rewriting completes, `rewrite_http_header_block()` scans the output for any remaining `openshell:resolve:env:` tokens. It also checks the percent-decoded form of the request line to catch encoded placeholder bypass attempts (e.g., `openshell%3Aresolve%3Aenv%3AUNKNOWN`). 
+ +```rust +pub(crate) struct UnresolvedPlaceholderError { + pub location: &'static str, // "header", "query_param", "path" +} +``` + +#### Rewrite-before-OPA with redaction + +When L7 inspection is active, credential placeholders in the request target (path + query) are resolved BEFORE OPA L7 policy evaluation. This is implemented in `relay_with_inspection()` and `relay_passthrough_with_credentials()` in `l7/relay.rs`: + +1. `rewrite_target_for_eval()` resolves the request target, producing two strings: + - **Resolved**: real secrets inserted — used only for the upstream connection + - **Redacted**: `[CREDENTIAL]` markers in place of secrets — used for OPA input and logs + +2. OPA `evaluate_l7_request()` receives the redacted path in `request.path`, so policy rules never see real credential values. + +3. All log statements (`L7_REQUEST`, `HTTP_REQUEST`) use the redacted target. Real credential values never appear in logs. + +4. The resolved path (with real secrets) goes only to the upstream via `relay_http_request_with_resolver()`. + +```rust +pub(crate) struct RewriteTargetResult { + pub resolved: String, // for upstream forwarding only + pub redacted: String, // for OPA + logs +} +``` + +If credential resolution fails on the request target, the relay returns HTTP 500 and closes the connection. + +#### Credential-injection-only relay **File:** `crates/openshell-sandbox/src/l7/relay.rs` (`relay_passthrough_with_credentials()`) -When TLS is auto-terminated but no L7 policy (`protocol` + `access`/`rules`) is configured on the endpoint, the proxy enters a passthrough mode that still provides value: it parses HTTP requests minimally to rewrite credential placeholders (via `SecretResolver`) and logs each request for observability. This relay: +When TLS is auto-terminated but no L7 policy (`protocol` + `access`/`rules`) is configured on the endpoint, the proxy enters a passthrough mode that still provides credential injection and observability. This relay: 1. 
Reads each HTTP request from the client via `RestProvider::parse_request()` -2. Logs the request method, path, host, and port at `info!()` level (tagged `"HTTP relay (credential injection)"`) -3. Forwards the request to upstream via `relay_http_request_with_resolver()`, which rewrites headers containing `openshell:resolve:env:*` placeholders with actual provider credential values -4. Relays the upstream response back to the client -5. Loops for HTTP keep-alive; exits on client close or non-reusable response +2. Resolves and redacts the request target via `rewrite_target_for_eval()` (for log safety) +3. Logs the request method, redacted path, host, and port at `info!()` level (tagged `HTTP_REQUEST`) +4. Forwards the request to upstream via `relay_http_request_with_resolver()`, which rewrites all credential placeholders in headers, query parameters, path segments, and Basic auth tokens +5. Relays the upstream response back to the client +6. Loops for HTTP keep-alive; exits on client close or non-reusable response This enables credential injection on all HTTPS endpoints automatically, without requiring the policy author to add `protocol: rest` and `access: full` just to get credentials injected. +#### Known limitation: host-binding + +The resolver resolves all placeholders regardless of destination host. If an agent has OPA-allowed access to an attacker-controlled host, it could construct a URL containing a placeholder and exfiltrate the resolved credential value to that host. OPA host restrictions are the defense — only endpoints explicitly allowed by policy receive traffic. Per-credential host binding (restricting which credentials resolve for which destination hosts) is not implemented. + +#### Data flow + +```mermaid +sequenceDiagram + participant A as Agent Process + participant P as Proxy (SecretResolver) + participant O as OPA Engine + participant U as Upstream API + + A->>P: GET /bot/send?key= HTTP/1.1
Authorization: Bearer {placeholder}
+    P->>P: rewrite_target_for_eval(target)<br/>→ resolved: /bot{secret}/send?key={secret}<br/>→ redacted: /bot[CREDENTIAL]/send?key=[CREDENTIAL]
+    P->>O: evaluate_l7_request(redacted path)
+    O-->>P: allow
+    P->>P: rewrite_http_header_block(headers)<br/>→ resolve header placeholders<br/>→ resolve query param placeholders<br/>→ resolve path segment placeholders<br/>→ fail-closed scan
+    P->>U: GET /bot{secret}/send?key={secret} HTTP/1.1<br/>
Authorization: Bearer {secret} + Note over P: Logs use redacted path only +``` + ### REST protocol provider **File:** `crates/openshell-sandbox/src/l7/rest.rs` Implements `L7Provider` for HTTP/1.1: -- **`parse_request()`**: Reads up to 16 KiB of headers, parses the request line (method, path), determines body framing from `Content-Length` or `Transfer-Encoding: chunked` headers. Returns `L7Request` with raw header bytes (may include overflow body bytes). +- **`parse_request()`**: Reads up to 16 KiB of headers, parses the request line (method, path), decodes query parameters into a multimap, determines body framing from `Content-Length` or `Transfer-Encoding: chunked` headers. Returns `L7Request` with raw header bytes (may include overflow body bytes). - **`relay()`**: Forwards request headers and body to upstream (handling Content-Length, chunked, and no-body cases), then reads and relays the full response back to the client. @@ -1054,11 +1204,12 @@ Implements `L7Provider` for HTTP/1.1: `relay_with_inspection()` in `crates/openshell-sandbox/src/l7/relay.rs` is the main relay loop: 1. Parse one HTTP request from client via the provider -2. Build L7 input JSON with `request.method`, `request.path`, plus the CONNECT-level context (host, port, binary, ancestors, cmdline) -3. Evaluate `data.openshell.sandbox.allow_request` and `data.openshell.sandbox.request_deny_reason` -4. Log the L7 decision (tagged `L7_REQUEST`) -5. If allowed (or audit mode): relay request to upstream and response back to client, then loop -6. If denied in enforce mode: send 403 and close the connection +2. Resolve credential placeholders in the request target via `rewrite_target_for_eval()`. OPA receives the redacted path (`[CREDENTIAL]` markers); the resolved path goes only to upstream. If resolution fails, return HTTP 500 and close the connection. +3. 
Build L7 input JSON with `request.method`, the **redacted** `request.path`, `request.query_params`, plus the CONNECT-level context (host, port, binary, ancestors, cmdline) +4. Evaluate `data.openshell.sandbox.allow_request` and `data.openshell.sandbox.request_deny_reason` +5. Log the L7 decision (tagged `L7_REQUEST`) using the redacted target — real credential values never appear in logs +6. If allowed (or audit mode): relay request to upstream via `relay_http_request_with_resolver()` (which rewrites all remaining credential placeholders in headers, query parameters, path segments, and Basic auth tokens) and relay the response back to client, then loop +7. If denied in enforce mode: send 403 (using redacted target in the response body) and close the connection ## Process Identity @@ -1311,6 +1462,10 @@ The sandbox uses `miette` for error reporting and `thiserror` for typed errors. | Log push gRPC stream breaks | Push loop exits, flushes remaining batch | | Proxy accept error | Log + break accept loop | | Benign connection close (EOF, reset, pipe) | Debug level (not visible to user by default) | +| Credential injection: unresolved placeholder detected | HTTP 500, connection closed (fail-closed) | +| Credential injection: resolved value contains CR/LF/null | Placeholder treated as unresolvable, fail-closed | +| Credential injection: path credential contains traversal/separator | HTTP 500, connection closed (fail-closed) | +| Credential injection: percent-encoded placeholder bypass attempt | HTTP 500, connection closed (fail-closed) | | L7 parse error | Close the connection | | SSH server failure | Async task error logged, main process unaffected | | Process timeout | Kill process, return exit code 124 | diff --git a/architecture/security-policy.md b/architecture/security-policy.md index 44898d70b..01eb96f94 100644 --- a/architecture/security-policy.md +++ b/architecture/security-policy.md @@ -320,7 +320,7 @@ Controls which filesystem paths the sandboxed process can 
access. Enforced via L | `read_only` | `string[]` | `[]` | Paths accessible in read-only mode | | `read_write` | `string[]` | `[]` | Paths accessible in read-write mode | -**Enforcement mapping**: Each path becomes a Landlock `PathBeneath` rule. Read-only paths receive `AccessFs::from_read(ABI::V1)` permissions. Read-write paths receive `AccessFs::from_all(ABI::V1)` permissions (read, write, execute, create, delete, rename). All other paths are denied by the Landlock ruleset. +**Enforcement mapping**: Each path becomes a Landlock `PathBeneath` rule. Read-only paths receive `AccessFs::from_read(ABI::V2)` permissions. Read-write paths receive `AccessFs::from_all(ABI::V2)` permissions (read, write, execute, create, delete, rename). All other paths are denied by the Landlock ruleset. **Filesystem preparation**: Before the child process spawns, the supervisor creates any `read_write` directories that do not exist and sets their ownership to `process.run_as_user`:`process.run_as_group` via `chown()`. See `crates/openshell-sandbox/src/lib.rs` -- `prepare_filesystem()`. @@ -358,10 +358,16 @@ Controls Landlock LSM compatibility behavior. **Static field** -- immutable afte | Value | Behavior | | ------------------ | --------------------------------------------------------------------------------------------------------------------------- | -| `best_effort` | If Landlock is unavailable (older kernel, unprivileged container), log a warning and continue without filesystem sandboxing | -| `hard_requirement` | If Landlock is unavailable, abort sandbox startup with an error | +| `best_effort` | If Landlock is unavailable (older kernel, unprivileged container), log a warning and continue without filesystem sandboxing. Individual inaccessible paths (missing, permission denied, symlink loops) are skipped with a warning while remaining rules are still applied. If all paths fail, the sandbox continues without Landlock rather than applying an empty ruleset that would block all access. 
| +| `hard_requirement` | If Landlock is unavailable or any configured path cannot be opened, abort sandbox startup with an error. | -See `crates/openshell-sandbox/src/sandbox/linux/landlock.rs` -- `compat_level()`. +**Per-path error handling**: `PathFd::new()` (which wraps `open(path, O_PATH | O_CLOEXEC)`) can fail for several reasons beyond path non-existence: `EACCES` (permission denied), `ELOOP` (symlink loop), `ENAMETOOLONG`, `ENOTDIR`. Each failure is classified with a human-readable reason in logs. In `best_effort` mode, the path is skipped and ruleset construction continues. In `hard_requirement` mode, the error is fatal. + +**Baseline path filtering**: The enrichment functions (`enrich_proto_baseline_paths`, `enrich_sandbox_baseline_paths`) pre-filter system-injected baseline paths (e.g., `/app`) by checking `Path::exists()` before adding them to the policy. This prevents missing baseline paths from reaching Landlock at all. User-specified paths are not pre-filtered — they are evaluated at Landlock apply time so that misconfigurations surface as warnings (`best_effort`) or errors (`hard_requirement`). + +**Zero-rule safety check**: If all paths in the ruleset fail to open, `apply()` returns an error rather than calling `restrict_self()` on an empty ruleset. An empty Landlock ruleset with `restrict_self()` would block all filesystem access — the inverse of the intended degradation behavior. This error is caught by the outer `BestEffort` handler, which logs a warning and continues without Landlock. + +See `crates/openshell-sandbox/src/sandbox/linux/landlock.rs` -- `compat_level()`, `try_open_path()`, `classify_path_fd_error()`, `classify_io_error()`. 
```yaml landlock: @@ -461,9 +467,14 @@ rules: - allow: method: GET path: "/repos/**" + query: + per_page: "1*" - allow: method: POST path: "/repos/*/issues" + query: + labels: + any: ["bug*", "p1*"] ``` #### `L7Allow` @@ -473,8 +484,9 @@ rules: | `method` | `string` | HTTP method: `GET`, `HEAD`, `POST`, `PUT`, `DELETE`, `PATCH`, `OPTIONS`, or `*` (any). Case-insensitive matching. | | `path` | `string` | URL path glob pattern: `**` matches everything, otherwise `glob.match` with `/` delimiter. | | `command` | `string` | SQL command: `SELECT`, `INSERT`, `UPDATE`, `DELETE`, or `*` (any). Case-insensitive matching. For `protocol: sql` endpoints. | +| `query` | `map` | Optional REST query rules keyed by decoded query param name. Value is either a glob string (for example, `tag: "foo-*"`) or `{ any: ["foo-*", "bar-*"] }`. | -Method and command fields use `*` as wildcard for "any". Path patterns use `**` for "match everything" and standard glob patterns with `/` as a delimiter otherwise. See `sandbox-policy.rego` -- `method_matches()`, `path_matches()`, `command_matches()`. +Method and command fields use `*` as wildcard for "any". Path patterns use `**` for "match everything" and standard glob patterns with `/` as a delimiter otherwise. Query matching is case-sensitive and evaluates decoded values; when duplicate keys are present in the request, every value for that key must match the configured matcher. See `sandbox-policy.rego` -- `method_matches()`, `path_matches()`, `command_matches()`, `query_params_match()`. #### Access Presets @@ -716,7 +728,7 @@ If any condition fails, the proxy returns `403 Forbidden`. 7. Rewrites the request: absolute-form → origin-form (`GET /path HTTP/1.1`), strips hop-by-hop headers, adds `Via: 1.1 openshell-sandbox` and `Connection: close` 8. 
Forwards the rewritten request, then relays bidirectionally using `tokio::io::copy_bidirectional` (supports chunked transfer, SSE streams, and other long-lived responses with no idle timeout) -**V1 simplifications**: Forward proxy v1 injects `Connection: close` (no keep-alive) and does not perform L7 inspection on the forwarded traffic. Every forward proxy connection handles exactly one request-response exchange. +**V1 simplifications**: Forward proxy v1 injects `Connection: close` (no keep-alive). Every forward proxy connection handles exactly one request-response exchange. When an endpoint has L7 rules configured, the forward proxy evaluates the single request's method and path against L7 policy before forwarding. **Implementation**: See `crates/openshell-sandbox/src/proxy.rs` -- `handle_forward_proxy()`, `parse_proxy_uri()`, `rewrite_forward_request()`. @@ -838,6 +850,10 @@ The response includes an `X-OpenShell-Policy` header and `Connection: close`. Se ## Seccomp Filter Details +The seccomp filter uses a default-allow policy (`SeccompAction::Allow`) with targeted rules that return `EPERM`. It provides three layers of protection: socket domain blocks, unconditional syscall blocks, and conditional syscall blocks. See `crates/openshell-sandbox/src/sandbox/linux/seccomp.rs`. + +### Blocked socket domains + Regardless of network mode, certain socket domains are always blocked: | Domain | Constant | Reason | @@ -849,7 +865,30 @@ Regardless of network mode, certain socket domains are always blocked: In proxy mode (which is always active), `AF_INET` (2) and `AF_INET6` (10) are allowed so the sandbox process can reach the proxy. -The seccomp filter uses a default-allow policy (`SeccompAction::Allow`) with specific `socket()` syscall rules that return `EPERM` when the first argument (domain) matches a blocked value. See `crates/openshell-sandbox/src/sandbox/linux/seccomp.rs`. 
+### Blocked syscalls + +These syscalls are blocked unconditionally (EPERM for any invocation): + +| Syscall | NR (x86-64) | Reason | +|---------|-------------|--------| +| `memfd_create` | 319 | Fileless binary execution bypasses Landlock filesystem restrictions | +| `ptrace` | 101 | Cross-process memory inspection and code injection | +| `bpf` | 321 | Kernel BPF program loading | +| `process_vm_readv` | 310 | Cross-process memory read | +| `io_uring_setup` | 425 | Async I/O subsystem with extensive CVE history | +| `mount` | 165 | Filesystem mount could subvert Landlock or overlay writable paths | + +### Conditionally blocked syscalls + +These syscalls are blocked only when specific flag patterns are present in their arguments: + +| Syscall | NR (x86-64) | Condition | Reason | +|---------|-------------|-----------|--------| +| `execveat` | 322 | `AT_EMPTY_PATH` (0x1000) set in flags (arg4) | Fileless execution from an anonymous fd | +| `unshare` | 272 | `CLONE_NEWUSER` (0x10000000) set in flags (arg0) | User namespace creation enables privilege escalation | +| `seccomp` | 317 | operation == `SECCOMP_SET_MODE_FILTER` (1) in arg0 | Prevents sandboxed code from replacing the active filter | + +Flag checks use `MaskedEq` (`(arg & mask) == mask`) to detect the flag bit regardless of other bits. The `seccomp` syscall check uses `Eq` for exact value comparison on the operation argument. 
--- diff --git a/crates/openshell-bootstrap/Cargo.toml b/crates/openshell-bootstrap/Cargo.toml index ab57ad57a..942ffc48b 100644 --- a/crates/openshell-bootstrap/Cargo.toml +++ b/crates/openshell-bootstrap/Cargo.toml @@ -11,6 +11,7 @@ rust-version.workspace = true [dependencies] openshell-core = { path = "../openshell-core" } +async-stream = "0.3" base64 = "0.22" bollard = { version = "0.20", features = ["ssh"] } bytes = { workspace = true } @@ -20,11 +21,11 @@ rcgen = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } tar = "0.4" +tempfile = "3" tokio = { workspace = true } tracing = { workspace = true } [dev-dependencies] -tempfile = "3" [lints] workspace = true diff --git a/crates/openshell-bootstrap/src/build.rs b/crates/openshell-bootstrap/src/build.rs index 9624e01b3..eaa221311 100644 --- a/crates/openshell-bootstrap/src/build.rs +++ b/crates/openshell-bootstrap/src/build.rs @@ -8,7 +8,6 @@ //! to import the image into the gateway's containerd runtime. 
use std::collections::HashMap; -use std::io::Read; use std::path::Path; use bollard::Docker; @@ -176,36 +175,10 @@ fn walk_and_add( if path.is_dir() { walk_and_add(root, &path, ignore_patterns, builder)?; } else { - let mut file = std::fs::File::open(&path) - .into_diagnostic() - .wrap_err_with(|| format!("failed to open file: {}", path.display()))?; - let metadata = file - .metadata() - .into_diagnostic() - .wrap_err_with(|| format!("failed to read metadata: {}", path.display()))?; - - let mut header = tar::Header::new_gnu(); - header.set_size(metadata.len()); - header.set_mode(0o644); - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - header.set_mode(metadata.permissions().mode()); - } - header - .set_path(&relative_normalized) - .into_diagnostic() - .wrap_err_with(|| format!("failed to set tar entry path: {relative_normalized}"))?; - header.set_cksum(); - - #[allow(clippy::cast_possible_truncation)] - let mut contents = Vec::with_capacity(metadata.len() as usize); - file.read_to_end(&mut contents) - .into_diagnostic() - .wrap_err_with(|| format!("failed to read file: {}", path.display()))?; - + // Use append_path_with_name which handles GNU LongName extensions + // for paths exceeding 100 bytes (the POSIX tar name field limit). builder - .append(&header, contents.as_slice()) + .append_path_with_name(&path, &relative_normalized) .into_diagnostic() .wrap_err_with(|| format!("failed to add file to tar: {relative_normalized}"))?; } @@ -433,6 +406,39 @@ mod tests { assert!(entries.iter().any(|e| e.contains("important.log"))); } + #[test] + fn test_long_path_exceeding_100_bytes() { + let dir = tempfile::tempdir().unwrap(); + let dir_path = dir.path(); + + // Build a nested path that exceeds 100 bytes when relative to root. 
+ let deep_dir = dir_path.join( + "a/deeply/nested/directory/path/that/exceeds/one/hundred/bytes/total/from/the/build/context/root", + ); + fs::create_dir_all(&deep_dir).unwrap(); + fs::write(deep_dir.join("file.txt"), "deep content\n").unwrap(); + fs::write(dir_path.join("Dockerfile"), "FROM ubuntu:24.04\n").unwrap(); + + let tar_bytes = create_build_context_tar(dir_path).unwrap(); + let mut archive = tar::Archive::new(tar_bytes.as_slice()); + let entries: Vec<String> = archive + .entries() + .unwrap() + .filter_map(std::result::Result::ok) + .map(|e| e.path().unwrap().to_string_lossy().to_string()) + .collect(); + + let long_entry = entries.iter().find(|e| e.contains("file.txt")); + assert!( + long_entry.is_some(), + "tar should contain deeply nested file; entries: {entries:?}" + ); + assert!( + long_entry.unwrap().len() > 100, + "path should exceed 100 bytes to exercise GNU LongName handling" + ); + } + #[test] fn test_simple_glob_match() { assert!(simple_glob_match("*.txt", "hello.txt")); diff --git a/crates/openshell-bootstrap/src/constants.rs b/crates/openshell-bootstrap/src/constants.rs index ff283b3ea..eee9000d1 100644 --- a/crates/openshell-bootstrap/src/constants.rs +++ b/crates/openshell-bootstrap/src/constants.rs @@ -11,11 +11,68 @@ pub const SERVER_TLS_SECRET_NAME: &str = "openshell-server-tls"; pub const SERVER_CLIENT_CA_SECRET_NAME: &str = "openshell-server-client-ca"; /// K8s secret holding the client TLS certificate, key, and CA cert (shared by CLI and sandboxes). pub const CLIENT_TLS_SECRET_NAME: &str = "openshell-client-tls"; +/// K8s secret holding the SSH handshake HMAC secret (shared by gateway and sandbox pods).
+pub const SSH_HANDSHAKE_SECRET_NAME: &str = "openshell-ssh-handshake"; +const NODE_NAME_PREFIX: &str = "openshell-"; +const NODE_NAME_FALLBACK_SUFFIX: &str = "gateway"; +const KUBERNETES_MAX_NAME_LEN: usize = 253; pub fn container_name(name: &str) -> String { format!("openshell-cluster-{name}") } +/// Deterministic k3s node name derived from the gateway name. +/// +/// k3s defaults to using the container hostname (= Docker container ID) as +/// the node name. When the container is recreated (e.g. after an image +/// upgrade), the container ID changes, creating a new k3s node. The +/// `clean_stale_nodes` function then deletes PVCs whose backing PVs have +/// node affinity for the old node — wiping the server database and any +/// sandbox persistent volumes. +/// +/// By passing a deterministic `--node-name` to k3s, the node identity +/// survives container recreation, and PVCs are never orphaned. +/// +/// Gateway names allow Docker-friendly separators and uppercase characters, +/// but Kubernetes node names must be DNS-safe. Normalize the gateway name into +/// a single lowercase RFC 1123 label so previously accepted names such as +/// `prod_us` or `Prod.US` still deploy successfully. 
+pub fn node_name(name: &str) -> String { + format!("{NODE_NAME_PREFIX}{}", normalize_node_name_suffix(name)) +} + +fn normalize_node_name_suffix(name: &str) -> String { + let mut normalized = String::with_capacity(name.len()); + let mut last_was_separator = false; + + for ch in name.chars() { + if ch.is_ascii_alphanumeric() { + normalized.push(ch.to_ascii_lowercase()); + last_was_separator = false; + } else if !last_was_separator { + normalized.push('-'); + last_was_separator = true; + } + } + + let mut normalized = normalized.trim_matches('-').to_string(); + if normalized.is_empty() { + normalized.push_str(NODE_NAME_FALLBACK_SUFFIX); + } + + let max_suffix_len = KUBERNETES_MAX_NAME_LEN.saturating_sub(NODE_NAME_PREFIX.len()); + if normalized.len() > max_suffix_len { + normalized.truncate(max_suffix_len); + normalized.truncate(normalized.trim_end_matches('-').len()); + } + + if normalized.is_empty() { + normalized.push_str(NODE_NAME_FALLBACK_SUFFIX); + } + + normalized +} + pub fn volume_name(name: &str) -> String { format!("openshell-cluster-{name}") } @@ -23,3 +80,33 @@ pub fn volume_name(name: &str) -> String { pub fn network_name(name: &str) -> String { format!("openshell-cluster-{name}") } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn node_name_normalizes_uppercase_and_underscores() { + assert_eq!(node_name("Prod_US"), "openshell-prod-us"); + } + + #[test] + fn node_name_collapses_and_trims_separator_runs() { + assert_eq!(node_name("._Prod..__-Gateway-."), "openshell-prod-gateway"); + } + + #[test] + fn node_name_falls_back_when_gateway_name_has_no_alphanumerics() { + assert_eq!(node_name("...___---"), "openshell-gateway"); + } + + #[test] + fn node_name_truncates_to_kubernetes_name_limit() { + let gateway_name = "A".repeat(400); + let node_name = node_name(&gateway_name); + + assert!(node_name.len() <= KUBERNETES_MAX_NAME_LEN); + assert!(node_name.starts_with(NODE_NAME_PREFIX)); + assert!(node_name.ends_with('a')); + } +} diff --git 
a/crates/openshell-bootstrap/src/docker.rs b/crates/openshell-bootstrap/src/docker.rs index 9c365bfe8..be086e534 100644 --- a/crates/openshell-bootstrap/src/docker.rs +++ b/crates/openshell-bootstrap/src/docker.rs @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 use crate::RemoteOptions; -use crate::constants::{container_name, network_name, volume_name}; +use crate::constants::{container_name, network_name, node_name, volume_name}; use crate::image::{self, DEFAULT_IMAGE_REPO_BASE, DEFAULT_REGISTRY, parse_image_ref}; use bollard::API_DEFAULT_VERSION; use bollard::Docker; use bollard::errors::Error as BollardError; use bollard::models::{ - ContainerCreateBody, DeviceRequest, HostConfig, HostConfigCgroupnsModeEnum, - NetworkCreateRequest, NetworkDisconnectRequest, PortBinding, VolumeCreateRequest, + ContainerCreateBody, DeviceRequest, EndpointSettings, HostConfig, HostConfigCgroupnsModeEnum, + NetworkConnectRequest, NetworkCreateRequest, NetworkDisconnectRequest, PortBinding, + RestartPolicy, RestartPolicyNameEnum, VolumeCreateRequest, }; use bollard::query_parameters::{ CreateContainerOptions, CreateImageOptions, InspectContainerOptions, InspectNetworkOptions, @@ -22,6 +23,29 @@ use std::collections::HashMap; const REGISTRY_NAMESPACE_DEFAULT: &str = "openshell"; +/// Resolve the raw GPU device-ID list, replacing the `"auto"` sentinel with a +/// concrete device ID based on whether CDI is enabled on the daemon. 
+/// +/// | Input | Output | +/// |--------------|--------------------------------------------------------------| +/// | `[]` | `[]` — no GPU | +/// | `["legacy"]` | `["legacy"]` — pass through to the non-CDI fallback path | +/// | `["auto"]` | `["nvidia.com/gpu=all"]` if CDI enabled, else `["legacy"]` | +/// | `[cdi-ids…]` | unchanged | +pub(crate) fn resolve_gpu_device_ids(gpu: &[String], cdi_enabled: bool) -> Vec<String> { + match gpu { + [] => vec![], + [v] if v == "auto" => { + if cdi_enabled { + vec!["nvidia.com/gpu=all".to_string()] + } else { + vec!["legacy".to_string()] + } + } + other => other.to_vec(), + } +} + const REGISTRY_MODE_EXTERNAL: &str = "external"; fn env_non_empty(key: &str) -> Option<String> { @@ -443,6 +467,9 @@ pub async fn ensure_image( Ok(()) } +/// Returns the actual host port the container is using. When an existing +/// container is reused (same image), this may differ from `gateway_port` +/// because the container was originally created with a different port. pub async fn ensure_container( docker: &Docker, name: &str, @@ -454,8 +481,9 @@ disable_gateway_auth: bool, registry_username: Option<&str>, registry_token: Option<&str>, - gpu: bool, -) -> Result<()> { + device_ids: &[String], + resume: bool, +) -> Result<u16> { let container_name = container_name(name); // Check if the container already exists @@ -464,33 +492,69 @@ .await { Ok(info) => { - // Container exists — verify it is using the expected image. - // Resolve the desired image ref to its content-addressable ID so we - // can compare against the container's image field (which Docker - // stores as an ID).
- let desired_id = docker - .inspect_image(image_ref) - .await - .ok() - .and_then(|img| img.id); + // On resume we always reuse the existing container — the persistent + // volume holds k3s etcd state, and recreating the container with + // different env vars would cause the entrypoint to rewrite the + // HelmChart manifest, triggering a Helm upgrade that changes the + // StatefulSet image reference while the old pod still runs with the + // previous image. Reusing the container avoids this entirely. + // + // On a non-resume path we check whether the image changed and + // recreate only when necessary. + let reuse = if resume { + true + } else { + let desired_id = docker + .inspect_image(image_ref) + .await + .ok() + .and_then(|img| img.id); - let container_image_id = info.image; + let container_image_id = info.image.clone(); - let image_matches = match (&desired_id, &container_image_id) { - (Some(desired), Some(current)) => desired == current, - _ => false, + match (&desired_id, &container_image_id) { + (Some(desired), Some(current)) => desired == current, + _ => false, + } }; - if image_matches { - return Ok(()); + if reuse { + // The container exists and should be reused. Its network + // attachment may be stale. When the gateway is resumed after a + // container kill, `ensure_network` destroys and recreates the + // Docker network (giving it a new ID). The stopped container + // still references the old network ID, so `docker start` would + // fail with "network not found". + // + // Fix: disconnect from any existing networks and reconnect to + // the current (just-created) network before returning. + let expected_net = network_name(name); + reconcile_container_network(docker, &container_name, &expected_net).await?; + + // Read the actual host port from the container's port bindings + // as a cross-check. 
The caller should already pass the correct + // port (from stored metadata), but this catches mismatches if + // the container was recreated with a different port externally. + let actual_port = info + .host_config + .as_ref() + .and_then(|hc| hc.port_bindings.as_ref()) + .and_then(|pb| pb.get("30051/tcp")) + .and_then(|bindings| bindings.as_ref()) + .and_then(|bindings| bindings.first()) + .and_then(|b| b.host_port.as_ref()) + .and_then(|p| p.parse::<u16>().ok()) + .unwrap_or(gateway_port); + + return Ok(actual_port); } - // Image changed — remove the stale container so we can recreate it + // Image changed — remove the stale container so we can recreate it. tracing::info!( "Container {} exists but uses a different image (container={}, desired={}), recreating", container_name, - container_image_id.as_deref().map_or("unknown", truncate_id), - desired_id.as_deref().map_or("unknown", truncate_id), + info.image.as_deref().map_or("unknown", truncate_id), + image_ref, ); let _ = docker.stop_container(&container_name, None).await; @@ -532,6 +596,12 @@ pub async fn ensure_container( port_bindings: Some(port_bindings), binds: Some(vec![format!("{}:/var/lib/rancher/k3s", volume_name(name))]), network_mode: Some(network_name(name)), + // Automatically restart the container when Docker restarts, unless the + // user explicitly stopped it with `gateway stop`. + restart_policy: Some(RestartPolicy { + name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), + maximum_retry_count: None, + }), // Add host gateway aliases for DNS resolution. // This allows both the entrypoint script and the running gateway // process to reach services on the Docker host. @@ -542,21 +612,35 @@ pub async fn ensure_container( ..Default::default() }; - // When GPU support is requested, add NVIDIA device requests. - // This is the programmatic equivalent of `docker run --gpus all`.
- // The NVIDIA Container Toolkit runtime hook injects /dev/nvidia* devices - // and GPU driver libraries from the host into the container. - if gpu { - host_config.device_requests = Some(vec![DeviceRequest { - driver: Some("nvidia".to_string()), - count: Some(-1), // all GPUs - capabilities: Some(vec![vec![ - "gpu".to_string(), - "utility".to_string(), - "compute".to_string(), - ]]), - ..Default::default() - }]); + // Inject GPU devices into the container based on the resolved device ID list. + // + // The list is pre-resolved by `resolve_gpu_device_ids` before reaching here: + // [] — no GPU passthrough + // ["legacy"] — internal non-CDI fallback path: `driver="nvidia"`, + // `count=-1`; relies on the NVIDIA Container Runtime hook + // [cdi-ids…] — CDI DeviceRequest (driver="cdi") with the given device IDs; + // Docker resolves them against the host CDI spec at /etc/cdi/ + match device_ids { + [] => {} + [id] if id == "legacy" => { + host_config.device_requests = Some(vec![DeviceRequest { + driver: Some("nvidia".to_string()), + count: Some(-1), // all GPUs + capabilities: Some(vec![vec![ + "gpu".to_string(), + "utility".to_string(), + "compute".to_string(), + ]]), + ..Default::default() + }]); + } + ids => { + host_config.device_requests = Some(vec![DeviceRequest { + driver: Some("cdi".to_string()), + device_ids: Some(ids.to_vec()), + ..Default::default() + }]); + } } let mut cmd = vec![ @@ -610,6 +694,11 @@ pub async fn ensure_container( format!("REGISTRY_HOST={registry_host}"), format!("REGISTRY_INSECURE={registry_insecure}"), format!("IMAGE_REPO_BASE={image_repo_base}"), + // Deterministic k3s node name so the node identity survives container + // recreation (e.g. after an image upgrade). Without this, k3s uses + // the container ID as the hostname/node name, which changes on every + // container recreate and triggers stale-node PVC cleanup. 
+ format!("OPENSHELL_NODE_NAME={}", node_name(name)), ]; if let Some(endpoint) = registry_endpoint { env_vars.push(format!("REGISTRY_ENDPOINT={endpoint}")); @@ -671,7 +760,7 @@ pub async fn ensure_container( // GPU support: tell the entrypoint to deploy the NVIDIA device plugin // HelmChart CR so k8s workloads can request nvidia.com/gpu resources. - if gpu { + if !device_ids.is_empty() { env_vars.push("GPU_ENABLED=true".to_string()); } @@ -679,6 +768,14 @@ pub async fn ensure_container( let config = ContainerCreateBody { image: Some(image_ref.to_string()), + // Set the container hostname to the deterministic node name. + // k3s uses the container hostname as its default node name. Without + // this, Docker defaults to the container ID (first 12 hex chars), + // which changes on every container recreation and can cause + // `clean_stale_nodes` to delete the wrong node on resume. The + // hostname persists across container stop/start cycles, ensuring a + // stable node identity. + hostname: Some(node_name(name)), cmd: Some(cmd), env, exposed_ports: Some(exposed_ports), @@ -697,7 +794,7 @@ pub async fn ensure_container( .await .into_diagnostic() .wrap_err("failed to create gateway container")?; - Ok(()) + Ok(gateway_port) } /// Information about a container that is holding a port we need. @@ -919,6 +1016,48 @@ pub async fn destroy_gateway_resources(docker: &Docker, name: &str) -> Result<() Ok(()) } +/// Clean up the gateway container and network, preserving the persistent volume. +/// +/// Used when a resume attempt fails — we want to remove the container we may +/// have just created but keep the volume so the user can retry without losing +/// their k3s/etcd state and sandbox data. 
+pub async fn cleanup_gateway_container(docker: &Docker, name: &str) -> Result<()> { + let container_name = container_name(name); + let net_name = network_name(name); + + // Disconnect container from network + let _ = docker + .disconnect_network( + &net_name, + NetworkDisconnectRequest { + container: container_name.clone(), + force: Some(true), + }, + ) + .await; + + let _ = stop_container(docker, &container_name).await; + + let remove_container = docker + .remove_container( + &container_name, + Some(RemoveContainerOptions { + force: true, + ..Default::default() + }), + ) + .await; + if let Err(err) = remove_container + && !is_not_found(&err) + { + return Err(err).into_diagnostic(); + } + + force_remove_network(docker, &net_name).await?; + + Ok(()) +} + /// Forcefully remove a Docker network, disconnecting any remaining /// containers first. This ensures that stale Docker network endpoints /// cannot prevent port bindings from being released. @@ -956,6 +1095,71 @@ async fn force_remove_network(docker: &Docker, net_name: &str) -> Result<()> { } } +/// Ensure a stopped container is connected to the expected Docker network. +/// +/// When a gateway is resumed after the container was killed (but not removed), +/// `ensure_network` destroys and recreates the network with a new ID. The +/// stopped container still holds a reference to the old network ID in its +/// config, so `docker start` would fail with a 404 "network not found" error. +/// +/// This function disconnects the container from any networks that no longer +/// match the expected network name and connects it to the correct one. +async fn reconcile_container_network( + docker: &Docker, + container_name: &str, + expected_network: &str, +) -> Result<()> { + let info = docker + .inspect_container(container_name, None::) + .await + .into_diagnostic() + .wrap_err("failed to inspect container for network reconciliation")?; + + // Check the container's current network attachments via NetworkSettings. 
+ let attached_networks: Vec = info + .network_settings + .as_ref() + .and_then(|ns| ns.networks.as_ref()) + .map(|nets| nets.keys().cloned().collect()) + .unwrap_or_default(); + + // If the container is already attached to the expected network (by name), + // Docker will resolve the name to the current network ID on start. + // However, when the network was destroyed and recreated, the container's + // stored endpoint references the old ID. Disconnect and reconnect to + // pick up the new network ID. + for net_name in &attached_networks { + let _ = docker + .disconnect_network( + net_name, + NetworkDisconnectRequest { + container: container_name.to_string(), + force: Some(true), + }, + ) + .await; + } + + // Connect to the (freshly created) expected network. + docker + .connect_network( + expected_network, + NetworkConnectRequest { + container: container_name.to_string(), + endpoint_config: Some(EndpointSettings::default()), + }, + ) + .await + .into_diagnostic() + .wrap_err("failed to connect container to gateway network")?; + + tracing::debug!( + "Reconciled network for container {container_name}: disconnected from {attached_networks:?}, connected to {expected_network}" + ); + + Ok(()) +} + fn is_not_found(err: &BollardError) -> bool { matches!( err, @@ -1195,4 +1399,53 @@ mod tests { "should return a reasonable number of sockets" ); } + + // --- resolve_gpu_device_ids --- + + #[test] + fn resolve_gpu_empty_returns_empty() { + assert_eq!(resolve_gpu_device_ids(&[], true), Vec::::new()); + assert_eq!(resolve_gpu_device_ids(&[], false), Vec::::new()); + } + + #[test] + fn resolve_gpu_auto_cdi_enabled() { + assert_eq!( + resolve_gpu_device_ids(&["auto".to_string()], true), + vec!["nvidia.com/gpu=all"], + ); + } + + #[test] + fn resolve_gpu_auto_cdi_disabled() { + assert_eq!( + resolve_gpu_device_ids(&["auto".to_string()], false), + vec!["legacy"], + ); + } + + #[test] + fn resolve_gpu_legacy_passthrough() { + assert_eq!( + 
resolve_gpu_device_ids(&["legacy".to_string()], true), + vec!["legacy"], + ); + assert_eq!( + resolve_gpu_device_ids(&["legacy".to_string()], false), + vec!["legacy"], + ); + } + + #[test] + fn resolve_gpu_cdi_ids_passthrough() { + let ids = vec!["nvidia.com/gpu=all".to_string()]; + assert_eq!(resolve_gpu_device_ids(&ids, true), ids); + assert_eq!(resolve_gpu_device_ids(&ids, false), ids); + + let multi = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + assert_eq!(resolve_gpu_device_ids(&multi, true), multi); + } } diff --git a/crates/openshell-bootstrap/src/errors.rs b/crates/openshell-bootstrap/src/errors.rs index b487c94a6..9e385c680 100644 --- a/crates/openshell-bootstrap/src/errors.rs +++ b/crates/openshell-bootstrap/src/errors.rs @@ -169,6 +169,21 @@ const FAILURE_PATTERNS: &[FailurePattern] = &[ match_mode: MatchMode::Any, diagnose: diagnose_docker_not_running, }, + // CDI specs missing — Docker daemon has CDI configured but no spec files exist + // or the requested device ID (nvidia.com/gpu=all) is not in any spec. + // Matches errors from Docker 25+ and containerd CDI injection paths. + FailurePattern { + matchers: &[ + "CDI device not found", + "unknown CDI device", + "failed to inject CDI devices", + "no CDI devices found", + "CDI device injection failed", + "unresolvable CDI devices", + ], + match_mode: MatchMode::Any, + diagnose: diagnose_cdi_specs_missing, + }, ]; fn diagnose_corrupted_state(gateway_name: &str) -> GatewayFailureDiagnosis { @@ -396,6 +411,29 @@ fn diagnose_certificate_issue(gateway_name: &str) -> GatewayFailureDiagnosis { } } +fn diagnose_cdi_specs_missing(_gateway_name: &str) -> GatewayFailureDiagnosis { + GatewayFailureDiagnosis { + summary: "CDI specs not found on host".to_string(), + explanation: "GPU passthrough via CDI was selected (your Docker daemon has CDI spec \ + directories configured) but no CDI device specs were found on the host. 
\ + Specs must be pre-generated before OpenShell can inject the GPU into the \ + cluster container." + .to_string(), + recovery_steps: vec![ + RecoveryStep::with_command( + "Generate CDI specs on the host (nvidia-ctk creates /etc/cdi/ if it does not exist)", + "sudo nvidia-ctk cdi generate --output=/etc/cdi/nvidia.yaml", + ), + RecoveryStep::with_command( + "Verify the specs were generated and include nvidia.com/gpu entries", + "nvidia-ctk cdi list", + ), + RecoveryStep::new("Then retry: openshell gateway start --gpu"), + ], + retryable: false, + } +} + fn diagnose_docker_not_running(_gateway_name: &str) -> GatewayFailureDiagnosis { GatewayFailureDiagnosis { summary: "Docker is not running".to_string(), @@ -925,4 +963,76 @@ mod tests { "commands should include gateway name, got: {all_commands}" ); } + + #[test] + fn test_diagnose_cdi_device_not_found() { + let diagnosis = diagnose_failure( + "test", + "could not run container: CDI device not found: nvidia.com/gpu=all", + None, + ); + assert!(diagnosis.is_some()); + let d = diagnosis.unwrap(); + assert!( + d.summary.contains("CDI"), + "expected CDI diagnosis, got: {}", + d.summary + ); + assert!(!d.retryable); + } + + #[test] + fn test_diagnose_cdi_injection_failed_unresolvable() { + // Exact error observed from Docker 500 response + let diagnosis = diagnose_failure( + "test", + "Docker responded with status code 500: CDI device injection failed: unresolvable CDI devices nvidia.com/gpu=all", + None, + ); + assert!(diagnosis.is_some()); + let d = diagnosis.unwrap(); + assert!( + d.summary.contains("CDI"), + "expected CDI diagnosis, got: {}", + d.summary + ); + assert!(!d.retryable); + } + + #[test] + fn test_diagnose_unknown_cdi_device() { + // containerd error path + let diagnosis = diagnose_failure( + "test", + "unknown CDI device requested: nvidia.com/gpu=all", + None, + ); + assert!(diagnosis.is_some()); + let d = diagnosis.unwrap(); + assert!( + d.summary.contains("CDI"), + "expected CDI diagnosis, got: {}", + 
d.summary + ); + } + + #[test] + fn test_diagnose_cdi_recovery_mentions_nvidia_ctk() { + let d = diagnose_failure("test", "CDI device not found", None) + .expect("should match CDI pattern"); + let all_steps: String = d + .recovery_steps + .iter() + .map(|s| format!("{} {}", s.description, s.command.as_deref().unwrap_or(""))) + .collect::>() + .join("\n"); + assert!( + all_steps.contains("nvidia-ctk cdi generate"), + "recovery steps should mention nvidia-ctk cdi generate, got: {all_steps}" + ); + assert!( + all_steps.contains("/etc/cdi/"), + "recovery steps should mention /etc/cdi/, got: {all_steps}" + ); + } } diff --git a/crates/openshell-bootstrap/src/lib.rs b/crates/openshell-bootstrap/src/lib.rs index 938986757..8ce10703e 100644 --- a/crates/openshell-bootstrap/src/lib.rs +++ b/crates/openshell-bootstrap/src/lib.rs @@ -26,12 +26,13 @@ use miette::{IntoDiagnostic, Result}; use std::sync::{Arc, Mutex}; use crate::constants::{ - CLIENT_TLS_SECRET_NAME, SERVER_CLIENT_CA_SECRET_NAME, SERVER_TLS_SECRET_NAME, network_name, - volume_name, + CLIENT_TLS_SECRET_NAME, SERVER_CLIENT_CA_SECRET_NAME, SERVER_TLS_SECRET_NAME, + SSH_HANDSHAKE_SECRET_NAME, network_name, volume_name, }; use crate::docker::{ - check_existing_gateway, check_port_conflicts, destroy_gateway_resources, ensure_container, - ensure_image, ensure_network, ensure_volume, start_container, stop_container, + check_existing_gateway, check_port_conflicts, cleanup_gateway_container, + destroy_gateway_resources, ensure_container, ensure_image, ensure_network, ensure_volume, + resolve_gpu_device_ids, start_container, stop_container, }; use crate::metadata::{ create_gateway_metadata, create_gateway_metadata_with_host, local_gateway_host, @@ -111,10 +112,13 @@ pub struct DeployOptions { /// bootstrap pull and inside the k3s cluster at runtime. Only needed /// for private registries. pub registry_token: Option, - /// Enable NVIDIA GPU passthrough. 
When true, the Docker container is - /// created with GPU device requests (`--gpus all`) and the NVIDIA - /// k8s-device-plugin is deployed inside the k3s cluster. - pub gpu: bool, + /// GPU device IDs to inject into the gateway container. + /// + /// - `[]` — no GPU passthrough (default) + /// - `["legacy"]` — internal non-CDI fallback path (`driver="nvidia"`, `count=-1`) + /// - `["auto"]` — resolved at deploy time: CDI if enabled on the daemon, else the non-CDI fallback + /// - `[cdi-ids…]` — CDI DeviceRequest with the given device IDs + pub gpu: Vec, /// When true, destroy any existing gateway resources before deploying. /// When false, an existing gateway is left as-is and deployment is /// skipped (the caller is responsible for prompting the user first). @@ -133,7 +137,7 @@ impl DeployOptions { disable_gateway_auth: false, registry_username: None, registry_token: None, - gpu: false, + gpu: vec![], recreate: false, } } @@ -187,9 +191,13 @@ impl DeployOptions { self } - /// Enable NVIDIA GPU passthrough for the cluster container. + /// Set GPU device IDs for the cluster container. + /// + /// Pass `vec!["auto"]` to auto-select between CDI and the non-CDI fallback + /// based on daemon capabilities at deploy time. The `legacy` sentinel is an + /// internal implementation detail for the fallback path. #[must_use] - pub fn with_gpu(mut self, gpu: bool) -> Self { + pub fn with_gpu(mut self, gpu: Vec) -> Self { self.gpu = gpu; self } @@ -288,49 +296,77 @@ where (preflight.docker, None) }; - // If an existing gateway is found, either tear it down (when recreate is - // requested) or bail out so the caller can prompt the user / reuse it. + // CDI is considered enabled when the daemon reports at least one CDI spec + // directory via `GET /info` (`SystemInfo.CDISpecDirs`). An empty list or + // missing field means CDI is not configured and we fall back to the legacy + // NVIDIA `DeviceRequest` (driver="nvidia"). 
Detection is best-effort — + // failure to query daemon info is non-fatal. + let cdi_supported = target_docker + .info() + .await + .ok() + .and_then(|info| info.cdi_spec_dirs) + .is_some_and(|dirs| !dirs.is_empty()); + + // If an existing gateway is found, decide how to proceed: + // - recreate: destroy everything and start fresh + // - otherwise: auto-resume from existing state (the ensure_* calls are + // idempotent and will reuse the volume, create a container if needed, + // and start it) + let mut resume = false; + let mut resume_container_exists = false; if let Some(existing) = check_existing_gateway(&target_docker, &name).await? { if recreate { log("[status] Removing existing gateway".to_string()); destroy_gateway_resources(&target_docker, &name).await?; + } else if existing.container_running { + log("[status] Gateway is already running".to_string()); + resume = true; + resume_container_exists = true; } else { - return Err(miette::miette!( - "Gateway '{name}' already exists (container_running={}).\n\ - Use --recreate to destroy and redeploy, or destroy it first with:\n\n \ - openshell gateway destroy --name {name}", - existing.container_running, - )); + log("[status] Resuming gateway from existing state".to_string()); + resume = true; + resume_container_exists = existing.container_exists; } } - // Ensure the image is available on the target Docker daemon - if remote_opts.is_some() { - log("[status] Downloading gateway".to_string()); - let on_log_clone = Arc::clone(&on_log); - let progress_cb = move |msg: String| { - if let Ok(mut f) = on_log_clone.lock() { - f(msg); - } - }; - image::pull_remote_image( - &target_docker, - &image_ref, - registry_username.as_deref(), - registry_token.as_deref(), - progress_cb, - ) - .await?; - } else { - // Local deployment: ensure image exists (pull if needed) - log("[status] Downloading gateway".to_string()); - ensure_image( - &target_docker, - &image_ref, - registry_username.as_deref(), - registry_token.as_deref(), - ) - 
.await?; + // Ensure the image is available on the target Docker daemon. + // When both the container and volume exist we can skip the pull entirely + // — the container already references a valid local image. This avoids + // failures when the original image tag (e.g. a local-only + // `openshell/cluster:dev`) is not available from the default registry. + // + // When only the volume survives (container was removed), we still need + // the image to recreate the container, so the pull must happen. + let need_image = !resume || !resume_container_exists; + if need_image { + if remote_opts.is_some() { + log("[status] Downloading gateway".to_string()); + let on_log_clone = Arc::clone(&on_log); + let progress_cb = move |msg: String| { + if let Ok(mut f) = on_log_clone.lock() { + f(msg); + } + }; + image::pull_remote_image( + &target_docker, + &image_ref, + registry_username.as_deref(), + registry_token.as_deref(), + progress_cb, + ) + .await?; + } else { + // Local deployment: ensure image exists (pull if needed) + log("[status] Downloading gateway".to_string()); + ensure_image( + &target_docker, + &image_ref, + registry_username.as_deref(), + registry_token.as_deref(), + ) + .await?; + } } // All subsequent operations use the target Docker (remote or local) @@ -405,7 +441,11 @@ where // leaving an orphaned volume in a corrupted state that blocks retries. // See: https://github.com/NVIDIA/OpenShell/issues/463 let deploy_result: Result = async { - ensure_container( + let device_ids = resolve_gpu_device_ids(&gpu, cdi_supported); + // ensure_container returns the actual host port — which may differ from + // the requested `port` when reusing an existing container that was + // originally created with a different port. 
+ let actual_port = ensure_container( &target_docker, &name, &image_ref, @@ -416,19 +456,26 @@ where disable_gateway_auth, registry_username.as_deref(), registry_token.as_deref(), - gpu, + &device_ids, + resume, ) .await?; + let port = actual_port; start_container(&target_docker, &name).await?; // Clean up stale k3s nodes left over from previous container instances that - // used the same persistent volume. Without this, pods remain scheduled on + // used the same persistent volume. Without this, pods remain scheduled on // NotReady ghost nodes and the health check will time out. + // + // The function retries internally until kubectl becomes available (k3s may + // still be initialising after the container start). It also force-deletes + // pods stuck in Terminating on the removed nodes so that StatefulSets can + // reschedule replacements immediately. match clean_stale_nodes(&target_docker, &name).await { Ok(0) => {} - Ok(n) => tracing::debug!("removed {n} stale node(s)"), + Ok(n) => tracing::info!("removed {n} stale node(s) and their orphaned pods"), Err(err) => { - tracing::debug!("stale node cleanup failed (non-fatal): {err}"); + tracing::warn!("stale node cleanup failed (non-fatal): {err}"); } } @@ -455,6 +502,11 @@ where store_pki_bundle(&name, &pki_bundle)?; + // Reconcile SSH handshake secret: reuse existing K8s secret if present, + // generate and persist a new one otherwise. This secret is stored in etcd + // (on the persistent volume) so it survives container restarts. + reconcile_ssh_handshake_secret(&target_docker, &name, &log).await?; + // Push locally-built component images into the k3s containerd runtime. 
// This is the "push" path for local development — images are exported from // the local Docker daemon and streamed into the cluster's containerd so @@ -524,15 +576,30 @@ where docker: target_docker, }), Err(deploy_err) => { - // Automatically clean up Docker resources (volume, container, network, - // image) so the environment is left in a retryable state. - tracing::info!("deploy failed, cleaning up gateway resources for '{name}'"); - if let Err(cleanup_err) = destroy_gateway_resources(&target_docker, &name).await { - tracing::warn!( - "automatic cleanup after failed deploy also failed: {cleanup_err}. \ - Manual cleanup may be required: \ - openshell gateway destroy --name {name}" + if resume { + // When resuming, preserve the volume so the user can retry. + // Only clean up the container and network that we may have created. + tracing::info!( + "resume failed, cleaning up container for '{name}' (preserving volume)" ); + if let Err(cleanup_err) = cleanup_gateway_container(&target_docker, &name).await { + tracing::warn!( + "automatic cleanup after failed resume also failed: {cleanup_err}. \ + Manual cleanup may be required: \ + openshell gateway destroy --name {name}" + ); + } + } else { + // Automatically clean up Docker resources (volume, container, network, + // image) so the environment is left in a retryable state. + tracing::info!("deploy failed, cleaning up gateway resources for '{name}'"); + if let Err(cleanup_err) = destroy_gateway_resources(&target_docker, &name).await { + tracing::warn!( + "automatic cleanup after failed deploy also failed: {cleanup_err}. \ + Manual cleanup may be required: \ + openshell gateway destroy --name {name}" + ); + } } Err(deploy_err) } @@ -809,6 +876,14 @@ where let cname = container_name(name); let kubeconfig = constants::KUBECONFIG_PATH; + // Wait for the k3s API server and openshell namespace before attempting + // to read secrets. 
Without this, kubectl fails transiently on resume
+    // (k3s hasn't booted yet), the code assumes secrets are gone, and
+    // regenerates PKI unnecessarily — triggering a server rollout restart
+    // and TLS errors for in-flight connections.
+    log("[progress] Waiting for openshell namespace".to_string());
+    wait_for_namespace(docker, &cname, kubeconfig, "openshell").await?;
+
     // Try to load existing secrets.
     match load_existing_pki_bundle(docker, &cname, kubeconfig).await {
         Ok(bundle) => {
@@ -823,10 +898,6 @@ where
     }
 
     // Generate fresh PKI and apply to cluster.
-    // Namespace may still be creating on first bootstrap, so wait here only
-    // when rotation is actually needed.
-    log("[progress] Waiting for openshell namespace".to_string());
-    wait_for_namespace(docker, &cname, kubeconfig, "openshell").await?;
     log("[progress] Generating TLS certificates".to_string());
     let bundle = generate_pki(extra_sans)?;
     log("[progress] Applying TLS secrets to gateway".to_string());
@@ -837,6 +908,72 @@ where
     Ok((bundle, true))
 }
 
+/// Reconcile the SSH handshake HMAC secret as a Kubernetes Secret.
+///
+/// If the secret already exists in the cluster, this is a no-op. Otherwise a
+/// fresh 32-byte hex secret is generated and applied. Because the secret lives
+/// in etcd (backed by the persistent Docker volume), it survives container
+/// restarts without regeneration — existing sandbox SSH sessions remain valid.
+async fn reconcile_ssh_handshake_secret<F>(docker: &Docker, name: &str, log: &F) -> Result<()>
+where
+    F: Fn(String) + Sync,
+{
+    use miette::WrapErr;
+
+    let cname = container_name(name);
+    let kubeconfig = constants::KUBECONFIG_PATH;
+
+    // Check if the secret already exists. 
+ let (output, exit_code) = exec_capture_with_exit( + docker, + &cname, + vec![ + "sh".to_string(), + "-c".to_string(), + format!( + "KUBECONFIG={kubeconfig} kubectl -n openshell get secret {SSH_HANDSHAKE_SECRET_NAME} -o jsonpath='{{.data.secret}}' 2>/dev/null" + ), + ], + ) + .await?; + + if exit_code == 0 && !output.trim().is_empty() { + tracing::debug!( + "existing SSH handshake secret found ({} bytes encoded)", + output.trim().len() + ); + log("[progress] Reusing existing SSH handshake secret".to_string()); + return Ok(()); + } + + // Generate a new 32-byte hex secret and create the K8s secret. + log("[progress] Generating SSH handshake secret".to_string()); + let (output, exit_code) = exec_capture_with_exit( + docker, + &cname, + vec![ + "sh".to_string(), + "-c".to_string(), + format!( + "SECRET=$(head -c 32 /dev/urandom | od -A n -t x1 | tr -d ' \\n') && \ + KUBECONFIG={kubeconfig} kubectl -n openshell create secret generic {SSH_HANDSHAKE_SECRET_NAME} \ + --from-literal=secret=$SECRET --dry-run=client -o yaml | \ + KUBECONFIG={kubeconfig} kubectl apply -f -" + ), + ], + ) + .await?; + + if exit_code != 0 { + return Err(miette::miette!( + "failed to create SSH handshake secret (exit {exit_code}): {output}" + )) + .wrap_err("failed to apply SSH handshake secret"); + } + + Ok(()) +} + /// Load existing TLS secrets from the cluster and reconstruct a [`PkiBundle`]. /// /// Returns an error string describing why secrets couldn't be loaded (for logging). diff --git a/crates/openshell-bootstrap/src/metadata.rs b/crates/openshell-bootstrap/src/metadata.rs index 15f79c089..20680f4c0 100644 --- a/crates/openshell-bootstrap/src/metadata.rs +++ b/crates/openshell-bootstrap/src/metadata.rs @@ -47,6 +47,34 @@ pub struct GatewayMetadata { pub edge_auth_url: Option, } +impl GatewayMetadata { + /// Extract the host portion from the stored `gateway_endpoint` URL. 
+ /// + /// Returns `None` if the endpoint is malformed or uses a default loopback + /// address (`127.0.0.1`, `localhost`, `::1`) — those are never meaningful + /// as a `--gateway-host` override. + pub fn gateway_host(&self) -> Option<&str> { + // Endpoint format: "https://host:port" or "http://host:port" + let after_scheme = self + .gateway_endpoint + .strip_prefix("https://") + .or_else(|| self.gateway_endpoint.strip_prefix("http://"))?; + // Strip port suffix (":8082") + let host = after_scheme + .rsplit_once(':') + .map_or(after_scheme, |(h, _)| h); + if host.is_empty() + || host == "127.0.0.1" + || host == "localhost" + || host == "::1" + || host == "[::1]" + { + return None; + } + Some(host) + } +} + pub fn create_gateway_metadata( name: &str, remote: Option<&RemoteOptions>, @@ -500,6 +528,61 @@ mod tests { assert_eq!(meta.gateway_endpoint, "http://host.docker.internal:8080"); } + // ── GatewayMetadata::gateway_host() ────────────────────────────── + + #[test] + fn gateway_host_returns_custom_host() { + let meta = + create_gateway_metadata_with_host("t", None, 8082, Some("host.docker.internal"), false); + assert_eq!(meta.gateway_host(), Some("host.docker.internal")); + } + + #[test] + fn gateway_host_returns_none_for_loopback() { + let meta = create_gateway_metadata("t", None, 8080); + // Default endpoint is https://127.0.0.1:8080 + assert_eq!(meta.gateway_host(), None); + } + + #[test] + fn gateway_host_returns_none_for_localhost() { + let meta = GatewayMetadata { + name: "t".into(), + gateway_endpoint: "https://localhost:8080".into(), + is_remote: false, + gateway_port: 8080, + remote_host: None, + resolved_host: None, + auth_mode: None, + edge_team_domain: None, + edge_auth_url: None, + }; + assert_eq!(meta.gateway_host(), None); + } + + #[test] + fn gateway_host_returns_ip_for_remote() { + let meta = GatewayMetadata { + name: "t".into(), + gateway_endpoint: "https://10.0.0.5:8080".into(), + is_remote: true, + gateway_port: 8080, + remote_host: 
Some("user@10.0.0.5".into()), + resolved_host: Some("10.0.0.5".into()), + auth_mode: None, + edge_team_domain: None, + edge_auth_url: None, + }; + assert_eq!(meta.gateway_host(), Some("10.0.0.5")); + } + + #[test] + fn gateway_host_handles_http_scheme() { + let meta = + create_gateway_metadata_with_host("t", None, 8080, Some("host.docker.internal"), true); + assert_eq!(meta.gateway_host(), Some("host.docker.internal")); + } + #[test] fn remote_gateway_metadata_with_tls_disabled() { let opts = RemoteOptions::new("user@10.0.0.5"); diff --git a/crates/openshell-bootstrap/src/push.rs b/crates/openshell-bootstrap/src/push.rs index 0dcbaa6da..336d46c3e 100644 --- a/crates/openshell-bootstrap/src/push.rs +++ b/crates/openshell-bootstrap/src/push.rs @@ -8,15 +8,23 @@ //! uploaded into the gateway container as a tar file via the Docker //! `put_archive` API, and then imported into containerd via `ctr images import`. //! +//! To avoid unbounded memory usage with large images, the export is streamed +//! to a temporary file on disk, then streamed back through a tar wrapper into +//! the Docker upload API. Peak memory usage is `O(chunk_size)` regardless of +//! image size. +//! //! The standalone `ctr` binary is used (not `k3s ctr` which may not work in //! all k3s versions) with the k3s containerd socket. The default containerd //! namespace in k3s is already `k8s.io`, which is what kubelet uses. +use std::pin::Pin; + use bollard::Docker; use bollard::query_parameters::UploadToContainerOptionsBuilder; use bytes::Bytes; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use miette::{IntoDiagnostic, Result, WrapErr}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::runtime::exec_capture_with_exit; @@ -26,11 +34,19 @@ const CONTAINERD_SOCK: &str = "/run/k3s/containerd/containerd.sock"; /// Path inside the container where the image tar is staged. 
const IMPORT_TAR_PATH: &str = "/tmp/openshell-images.tar"; +/// Size of chunks read from the temp file during streaming upload (8 MiB). +const UPLOAD_CHUNK_SIZE: usize = 8 * 1024 * 1024; + +/// Report export progress every N bytes (100 MiB). +const PROGRESS_INTERVAL_BYTES: u64 = 100 * 1024 * 1024; + /// Push a list of images from the local Docker daemon into a k3s gateway's /// containerd runtime. /// /// All images are exported as a single tar (shared layers are deduplicated), -/// uploaded to the container filesystem, and imported into containerd. +/// streamed to a temporary file, then uploaded to the container filesystem +/// and imported into containerd. Memory usage is bounded to `O(chunk_size)` +/// regardless of image size. pub async fn push_local_images( local_docker: &Docker, gateway_docker: &Docker, @@ -42,17 +58,30 @@ pub async fn push_local_images( return Ok(()); } - // 1. Export all images from the local Docker daemon as a single tar. - let image_tar = collect_export(local_docker, images).await?; + // 1. Export all images from the local Docker daemon to a temp file. + let (tmp_file, file_size) = export_to_tempfile(local_docker, images, on_log).await?; on_log(format!( "[progress] Exported {} MiB", - image_tar.len() / (1024 * 1024) + file_size / (1024 * 1024) )); - // 2. Wrap the image tar as a file inside an outer tar archive and upload - // it into the container filesystem via the Docker put_archive API. - let outer_tar = wrap_in_tar(IMPORT_TAR_PATH, &image_tar)?; - upload_archive(gateway_docker, container_name, &outer_tar).await?; + // 2. Stream the image tar wrapped in an outer tar archive into the + // container filesystem via the Docker put_archive API. 
+ let parent_dir = IMPORT_TAR_PATH.rsplit_once('/').map_or("/", |(dir, _)| dir); + let options = UploadToContainerOptionsBuilder::default() + .path(parent_dir) + .build(); + + let upload_stream = streaming_tar_upload(IMPORT_TAR_PATH, tmp_file, file_size); + gateway_docker + .upload_to_container( + container_name, + Some(options), + bollard::body_try_stream(upload_stream), + ) + .await + .into_diagnostic() + .wrap_err("failed to upload image tar into container")?; on_log("[progress] Uploaded to gateway".to_string()); // 3. Import the tar into containerd via ctr. @@ -93,59 +122,115 @@ pub async fn push_local_images( Ok(()) } -/// Collect the full export tar from `docker.export_images()` into memory. -async fn collect_export(docker: &Docker, images: &[&str]) -> Result> { +/// Stream the Docker image export directly to a temporary file. +/// +/// Returns the temp file handle and the total number of bytes written. +/// Memory usage is `O(chunk_size)` — only one chunk is held at a time. +/// Progress is reported every [`PROGRESS_INTERVAL_BYTES`] bytes. +async fn export_to_tempfile( + docker: &Docker, + images: &[&str], + on_log: &mut impl FnMut(String), +) -> Result<(tempfile::NamedTempFile, u64)> { + let tmp = tempfile::NamedTempFile::new() + .into_diagnostic() + .wrap_err("failed to create temp file for image export")?; + + // Open a second handle for async writing; the NamedTempFile retains + // ownership and ensures cleanup on drop. 
+ let std_file = tmp + .reopen() + .into_diagnostic() + .wrap_err("failed to reopen temp file for writing")?; + let mut async_file = tokio::fs::File::from_std(std_file); + let mut stream = docker.export_images(images); - let mut buf = Vec::new(); + let mut total_bytes: u64 = 0; + let mut last_reported: u64 = 0; + while let Some(chunk) = stream.next().await { let bytes = chunk .into_diagnostic() .wrap_err("failed to read image export stream")?; - buf.extend_from_slice(&bytes); + async_file + .write_all(&bytes) + .await + .into_diagnostic() + .wrap_err("failed to write image data to temp file")?; + total_bytes += bytes.len() as u64; + + // Report progress periodically. + if total_bytes >= last_reported + PROGRESS_INTERVAL_BYTES { + let mb = total_bytes / (1024 * 1024); + on_log(format!("[progress] Exported {mb} MiB")); + last_reported = total_bytes; + } } - Ok(buf) -} -/// Wrap raw bytes as a single file inside a tar archive. -/// -/// The Docker `put_archive` API expects a tar that is extracted at a target -/// directory. We create a tar containing one entry whose name is the basename -/// of `file_path`, and upload it to the parent directory. -fn wrap_in_tar(file_path: &str, data: &[u8]) -> Result> { - let file_name = file_path.rsplit('/').next().unwrap_or(file_path); - - let mut builder = tar::Builder::new(Vec::new()); - let mut header = tar::Header::new_gnu(); - header.set_path(file_name).into_diagnostic()?; - header.set_size(data.len() as u64); - header.set_mode(0o644); - header.set_cksum(); - builder - .append(&header, data) - .into_diagnostic() - .wrap_err("failed to build tar archive for image upload")?; - builder - .into_inner() + async_file + .flush() + .await .into_diagnostic() - .wrap_err("failed to finalize tar archive") -} - -/// Upload a tar archive into the container at the parent directory of -/// [`IMPORT_TAR_PATH`]. 
-async fn upload_archive(docker: &Docker, container_name: &str, archive: &[u8]) -> Result<()> {
-    let parent_dir = IMPORT_TAR_PATH.rsplit_once('/').map_or("/", |(dir, _)| dir);
+        .wrap_err("failed to flush temp file")?;
 
-    let options = UploadToContainerOptionsBuilder::default()
-        .path(parent_dir)
-        .build();
+    Ok((tmp, total_bytes))
+}
 
-    docker
-        .upload_to_container(
-            container_name,
-            Some(options),
-            bollard::body_full(Bytes::copy_from_slice(archive)),
-        )
-        .await
-        .into_diagnostic()
-        .wrap_err("failed to upload image tar into container")
+/// Create a stream that yields an outer tar archive containing the image tar
+/// as a single entry, reading the image data from the temp file in chunks.
+///
+/// The Docker `put_archive` API expects a tar that is extracted at a target
+/// directory. We construct a tar with one entry whose name is the basename
+/// of `file_path`. The stream yields:
+/// 1. A 512-byte GNU tar header
+/// 2. The file content in [`UPLOAD_CHUNK_SIZE`] chunks
+/// 3. Padding to a 512-byte boundary + two 512-byte zero EOF blocks
+///
+/// Memory usage is O([`UPLOAD_CHUNK_SIZE`]) regardless of file size.
+fn streaming_tar_upload(
+    file_path: &str,
+    tmp_file: tempfile::NamedTempFile,
+    file_size: u64,
+) -> Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>> {
+    let file_name = file_path
+        .rsplit('/')
+        .next()
+        .unwrap_or(file_path)
+        .to_string();
+
+    Box::pin(async_stream::try_stream! {
+        // 1. Build and yield the tar header.
+        let mut header = tar::Header::new_gnu();
+        header.set_path(&file_name)?;
+        header.set_size(file_size);
+        header.set_mode(0o644);
+        header.set_cksum();
+        yield Bytes::copy_from_slice(header.as_bytes());
+
+        // 2. Stream the temp file content in chunks.
+        let std_file = tmp_file.reopen()?;
+        let mut async_file = tokio::fs::File::from_std(std_file);
+        let mut buf = vec![0u8; UPLOAD_CHUNK_SIZE];
+        loop {
+            let n = async_file.read(&mut buf).await?;
+            if n == 0 {
+                break;
+            }
+            yield Bytes::copy_from_slice(&buf[..n]);
+        }
+
+        // 3. 
Yield tar padding and EOF blocks. + // Tar entries must be padded to a 512-byte boundary, followed by + // two 512-byte zero blocks to signal end-of-archive. + let padding_len = if file_size.is_multiple_of(512) { + 0 + } else { + 512 - (file_size % 512) as usize + }; + let footer = vec![0u8; padding_len + 1024]; + yield Bytes::from(footer); + + // The NamedTempFile is dropped here, cleaning up the temp file. + drop(tmp_file); + }) } diff --git a/crates/openshell-bootstrap/src/runtime.rs b/crates/openshell-bootstrap/src/runtime.rs index 271fde8d4..0f9a96e6b 100644 --- a/crates/openshell-bootstrap/src/runtime.rs +++ b/crates/openshell-bootstrap/src/runtime.rs @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::constants::{KUBECONFIG_PATH, container_name}; +use crate::constants::{KUBECONFIG_PATH, container_name, node_name}; use bollard::Docker; use bollard::container::LogOutput; use bollard::exec::CreateExecOptions; @@ -362,57 +362,154 @@ pub async fn fetch_recent_logs(docker: &Docker, container_name: &str, n: usize) rendered } -/// Remove stale k3s nodes from a cluster with a reused persistent volume. +/// Remove stale k3s nodes and their orphaned pods from a resumed cluster. /// /// When a cluster container is recreated but the volume is reused, k3s registers /// a new node (using the container ID as the hostname) while old node entries /// persist in etcd. Pods scheduled on those stale `NotReady` nodes will never run, /// causing health checks to fail. /// -/// This function identifies all `NotReady` nodes and deletes them so k3s can -/// reschedule workloads onto the current (Ready) node. +/// This function retries with backoff until `kubectl` becomes available (k3s may +/// still be initialising), then: +/// 1. Deletes all `NotReady` nodes so k3s stops tracking them. +/// 2. 
Force-deletes any pods stuck in `Terminating` so `StatefulSets` and
+///    Deployments can reschedule replacements on the current (Ready) node.
 ///
 /// Returns the number of stale nodes removed.
 pub async fn clean_stale_nodes(docker: &Docker, name: &str) -> Result<usize> {
+    // Retry until kubectl is responsive. k3s can take 10-20 s to start the
+    // API server after a container restart, so we allow up to ~45 s.
+    const MAX_ATTEMPTS: u32 = 15;
+    const RETRY_DELAY: Duration = Duration::from_secs(3);
+
     let container_name = container_name(name);
+    let mut stale_nodes: Vec<String> = Vec::new();
+
+    // Determine the current node name. With the deterministic `--node-name`
+    // entrypoint change the k3s node is `openshell-{gateway}`. However, older
+    // cluster images (built before that change) still use the container hostname
+    // (= Docker container ID) as the node name. We must handle both:
+    //
+    // 1. If the expected deterministic name appears in the node list, use it.
+    // 2. Otherwise fall back to the container hostname (old behaviour).
+    //
+    // This ensures backward compatibility during upgrades where the bootstrap
+    // CLI is newer than the cluster image.
+    let deterministic_node = node_name(name);
+
+    for attempt in 1..=MAX_ATTEMPTS {
+        let (output, exit_code) = exec_capture_with_exit(
+            docker,
+            &container_name,
+            vec![
+                "sh".to_string(),
+                "-c".to_string(),
+                format!(
+                    "KUBECONFIG={KUBECONFIG_PATH} kubectl get nodes \
+                     --no-headers -o custom-columns=NAME:.metadata.name \
+                     2>/dev/null"
+                ),
+            ],
+        )
+        .await?;
+
+        if exit_code == 0 {
+            let all_nodes: Vec<&str> = output
+                .lines()
+                .map(str::trim)
+                .filter(|l| !l.is_empty())
+                .collect();
+
+            // Pick the current node identity: prefer the deterministic name,
+            // fall back to the container hostname for older cluster images.
+            let current_node = if all_nodes.contains(&deterministic_node.as_str()) {
+                deterministic_node.clone()
+            } else {
+                // Older cluster image without --node-name: read hostname. 
+ let (hostname_out, _) = + exec_capture_with_exit(docker, &container_name, vec!["hostname".to_string()]) + .await?; + hostname_out.trim().to_string() + }; + + stale_nodes = all_nodes + .into_iter() + .filter(|n| *n != current_node) + .map(ToString::to_string) + .collect(); + break; + } + + if attempt < MAX_ATTEMPTS { + tracing::debug!( + "kubectl not ready yet (attempt {attempt}/{MAX_ATTEMPTS}), retrying in {}s", + RETRY_DELAY.as_secs() + ); + tokio::time::sleep(RETRY_DELAY).await; + } + } + + if stale_nodes.is_empty() { + return Ok(0); + } + + let node_list = stale_nodes.join(" "); + let count = stale_nodes.len(); + tracing::info!("removing {} stale node(s): {}", count, node_list); - // Get the list of NotReady nodes. - // The last condition on a node is always type=Ready; we need to check its - // **status** (True/False/Unknown), not its type. Nodes where the Ready - // condition status is not "True" are stale and should be removed. - let (output, exit_code) = exec_capture_with_exit( + // Step 1: delete the stale node objects. + let (_output, exit_code) = exec_capture_with_exit( docker, &container_name, vec![ "sh".to_string(), "-c".to_string(), format!( - "KUBECONFIG={KUBECONFIG_PATH} kubectl get nodes \ - --no-headers -o custom-columns=NAME:.metadata.name,STATUS:.status.conditions[-1].status \ - 2>/dev/null | grep -v '\\bTrue$' | awk '{{print $1}}'" + "KUBECONFIG={KUBECONFIG_PATH} kubectl delete node {node_list} --ignore-not-found" ), ], ) .await?; if exit_code != 0 { - // kubectl not ready yet or no nodes — nothing to clean - return Ok(0); + tracing::warn!("failed to delete stale nodes (exit code {exit_code})"); } - let stale_nodes: Vec<&str> = output - .lines() - .map(str::trim) - .filter(|l| !l.is_empty()) - .collect(); - if stale_nodes.is_empty() { - return Ok(0); - } + // Step 2: force-delete pods stuck in Terminating. 
After the stale node is + // removed, pods that were scheduled on it transition to Terminating but + // will never complete graceful shutdown (the node is gone). StatefulSets + // will not create a replacement until the old pod is fully deleted. + let (_output, exit_code) = exec_capture_with_exit( + docker, + &container_name, + vec![ + "sh".to_string(), + "-c".to_string(), + format!( + "KUBECONFIG={KUBECONFIG_PATH} kubectl get pods --all-namespaces \ + --field-selector=status.phase=Running -o name 2>/dev/null; \ + for pod_line in $(KUBECONFIG={KUBECONFIG_PATH} kubectl get pods --all-namespaces \ + --no-headers 2>/dev/null | awk '$4 == \"Terminating\" {{print $1\"/\"$2}}'); do \ + ns=${{pod_line%%/*}}; pod=${{pod_line#*/}}; \ + KUBECONFIG={KUBECONFIG_PATH} kubectl delete pod \"$pod\" -n \"$ns\" \ + --force --grace-period=0 --ignore-not-found 2>/dev/null; \ + done" + ), + ], + ) + .await?; - let node_list = stale_nodes.join(" "); - let count = stale_nodes.len(); - tracing::info!("removing {} stale node(s): {}", count, node_list); + if exit_code != 0 { + tracing::debug!( + "force-delete of terminating pods returned exit code {exit_code} (non-fatal)" + ); + } + // Step 3: delete PersistentVolumeClaims in the openshell namespace whose + // backing PV has node affinity for a stale node. local-path-provisioner + // creates PVs tied to the original node; when the node changes, the PV is + // unschedulable and the `StatefulSet` pod stays Pending. Deleting the PVC + // (and its PV) lets the provisioner create a fresh one on the current node. 
let (_output, exit_code) = exec_capture_with_exit( docker, &container_name, @@ -420,14 +517,24 @@ pub async fn clean_stale_nodes(docker: &Docker, name: &str) -> Result { "sh".to_string(), "-c".to_string(), format!( - "KUBECONFIG={KUBECONFIG_PATH} kubectl delete node {node_list} --ignore-not-found" + r#"KUBECONFIG={KUBECONFIG_PATH}; export KUBECONFIG; \ + CURRENT_NODE=$(kubectl get nodes --no-headers -o custom-columns=NAME:.metadata.name 2>/dev/null | head -1); \ + [ -z "$CURRENT_NODE" ] && exit 0; \ + for pv in $(kubectl get pv -o jsonpath='{{.items[*].metadata.name}}' 2>/dev/null); do \ + NODE=$(kubectl get pv "$pv" -o jsonpath='{{.spec.nodeAffinity.required.nodeSelectorTerms[0].matchExpressions[0].values[0]}}' 2>/dev/null); \ + [ "$NODE" = "$CURRENT_NODE" ] && continue; \ + NS=$(kubectl get pv "$pv" -o jsonpath='{{.spec.claimRef.namespace}}' 2>/dev/null); \ + PVC=$(kubectl get pv "$pv" -o jsonpath='{{.spec.claimRef.name}}' 2>/dev/null); \ + [ -n "$PVC" ] && kubectl delete pvc "$PVC" -n "$NS" --ignore-not-found 2>/dev/null; \ + kubectl delete pv "$pv" --ignore-not-found 2>/dev/null; \ + done"# ), ], ) .await?; if exit_code != 0 { - tracing::warn!("failed to delete stale nodes (exit code {exit_code})"); + tracing::debug!("PV/PVC cleanup returned exit code {exit_code} (non-fatal)"); } Ok(count) diff --git a/crates/openshell-cli/src/bootstrap.rs b/crates/openshell-cli/src/bootstrap.rs index e976061fa..ee9a481aa 100644 --- a/crates/openshell-cli/src/bootstrap.rs +++ b/crates/openshell-cli/src/bootstrap.rs @@ -144,43 +144,62 @@ pub async fn run_bootstrap( ); eprintln!(); - // Auto-bootstrap always recreates if stale Docker resources are found - // (e.g. metadata was deleted but container/volume still exist). 
- let mut options = openshell_bootstrap::DeployOptions::new(&gateway_name).with_recreate(true); - if let Some(dest) = remote { - let mut remote_opts = openshell_bootstrap::RemoteOptions::new(dest); - if let Some(key) = ssh_key { - remote_opts = remote_opts.with_ssh_key(key); + // Build deploy options. The deploy flow auto-resumes from existing state + // (preserving sandboxes and secrets) when it finds an existing gateway. + // If the initial attempt fails, fall back to a full recreate. + let build_options = |recreate: bool| { + let mut opts = openshell_bootstrap::DeployOptions::new(&gateway_name) + .with_recreate(recreate) + .with_gpu(if gpu { + vec!["auto".to_string()] + } else { + vec![] + }); + if let Some(dest) = remote { + let mut remote_opts = openshell_bootstrap::RemoteOptions::new(dest); + if let Some(key) = ssh_key { + remote_opts = remote_opts.with_ssh_key(key); + } + opts = opts.with_remote(remote_opts); } - options = options.with_remote(remote_opts); - } - // Read registry credentials from environment for the auto-bootstrap path. - // The explicit `--registry-username` / `--registry-token` flags are only - // on `gateway start`; when bootstrapping via `sandbox create`, the env - // vars are the mechanism. - if let Ok(username) = std::env::var("OPENSHELL_REGISTRY_USERNAME") - && !username.trim().is_empty() - { - options = options.with_registry_username(username); - } - if let Ok(token) = std::env::var("OPENSHELL_REGISTRY_TOKEN") - && !token.trim().is_empty() - { - options = options.with_registry_token(token); - } - // Read gateway host override from environment. Needed whenever the - // client cannot reach the Docker host at 127.0.0.1 — CI containers, - // WSL, remote Docker hosts, etc. The explicit `--gateway-host` flag - // is only on `gateway start`; this env var covers the auto-bootstrap - // path triggered by `sandbox create`. 
- if let Ok(host) = std::env::var("OPENSHELL_GATEWAY_HOST") - && !host.trim().is_empty() - { - options = options.with_gateway_host(host); - } - options = options.with_gpu(gpu); + // Read registry credentials from environment for the auto-bootstrap path. + // The explicit `--registry-username` / `--registry-token` flags are only + // on `gateway start`; when bootstrapping via `sandbox create`, the env + // vars are the mechanism. + if let Ok(username) = std::env::var("OPENSHELL_REGISTRY_USERNAME") + && !username.trim().is_empty() + { + opts = opts.with_registry_username(username); + } + if let Ok(token) = std::env::var("OPENSHELL_REGISTRY_TOKEN") + && !token.trim().is_empty() + { + opts = opts.with_registry_token(token); + } + // Read gateway host override from environment. Needed whenever the + // client cannot reach the Docker host at 127.0.0.1 — CI containers, + // WSL, remote Docker hosts, etc. The explicit `--gateway-host` flag + // is only on `gateway start`; this env var covers the auto-bootstrap + // path triggered by `sandbox create`. + if let Ok(host) = std::env::var("OPENSHELL_GATEWAY_HOST") + && !host.trim().is_empty() + { + opts = opts.with_gateway_host(host); + } + opts + }; - let handle = deploy_gateway_with_panel(options, &gateway_name, location).await?; + // Deploy the gateway. The deploy flow auto-resumes from existing state + // when it finds one. If that fails, fall back to a full recreate. + let handle = match deploy_gateway_with_panel(build_options(false), &gateway_name, location) + .await + { + Ok(handle) => handle, + Err(resume_err) => { + tracing::warn!("auto-bootstrap resume failed, falling back to recreate: {resume_err}"); + deploy_gateway_with_panel(build_options(true), &gateway_name, location).await? 
+ } + }; let server = handle.gateway_endpoint().to_string(); print_deploy_summary(&gateway_name, &handle); @@ -206,9 +225,13 @@ pub async fn run_bootstrap( /// Retry connecting to the gateway gRPC endpoint until it succeeds or a /// timeout is reached. Uses exponential backoff starting at 500 ms, doubling -/// up to 4 s, with a total deadline of 30 s. -async fn wait_for_grpc_ready(server: &str, tls: &TlsOptions) -> Result<()> { - const MAX_WAIT: Duration = Duration::from_secs(30); +/// up to 4 s, with a total deadline of 90 s. +/// +/// The generous timeout accounts for gateway resume scenarios where stale k3s +/// nodes must be cleaned up and workload pods rescheduled before the gRPC +/// endpoint becomes available. +pub(crate) async fn wait_for_grpc_ready(server: &str, tls: &TlsOptions) -> Result<()> { + const MAX_WAIT: Duration = Duration::from_secs(90); const INITIAL_BACKOFF: Duration = Duration::from_millis(500); let start = std::time::Instant::now(); @@ -232,7 +255,7 @@ async fn wait_for_grpc_ready(server: &str, tls: &TlsOptions) -> Result<()> { Err(last_err .unwrap_or_else(|| miette::miette!("timed out waiting for gateway")) - .wrap_err("gateway deployed but not accepting connections after 30 s")) + .wrap_err("gateway deployed but not accepting connections after 90 s")) } #[cfg(test)] diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 5de31c79c..87d377b39 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -597,6 +597,7 @@ enum CliProviderType { Claude, Opencode, Codex, + Copilot, Generic, Openai, Anthropic, @@ -627,6 +628,7 @@ impl CliProviderType { Self::Claude => "claude", Self::Opencode => "opencode", Self::Codex => "codex", + Self::Copilot => "copilot", Self::Generic => "generic", Self::Openai => "openai", Self::Anthropic => "anthropic", @@ -807,6 +809,10 @@ enum GatewayCommands { /// NVIDIA k8s-device-plugin so Kubernetes workloads can request /// `nvidia.com/gpu` resources. 
Requires NVIDIA drivers and the
     /// NVIDIA Container Toolkit on the host.
+    ///
+    /// When enabled, OpenShell auto-selects CDI when the Docker daemon has
+    /// CDI enabled and falls back to Docker's NVIDIA GPU request path
+    /// (`--gpus all`) otherwise.
     #[arg(long)]
     gpu: bool,
 },
@@ -937,6 +943,10 @@ enum InferenceCommands {
         /// Skip endpoint verification before saving the route.
         #[arg(long)]
         no_verify: bool,
+
+        /// Request timeout in seconds for inference calls (0 = default 60s).
+        #[arg(long, default_value_t = 0)]
+        timeout: u64,
     },
 
     /// Update gateway-level inference configuration (partial update).
@@ -957,6 +967,10 @@ enum InferenceCommands {
         /// Skip endpoint verification before saving the route.
         #[arg(long)]
         no_verify: bool,
+
+        /// Request timeout in seconds for inference calls (0 = default 60s, unchanged if omitted).
+        #[arg(long)]
+        timeout: Option<u64>,
     },
 
     /// Get gateway-level inference provider and model.
@@ -1077,8 +1091,8 @@ enum SandboxCommands {
         /// Upload local files into the sandbox before running.
         ///
         /// Format: `<LOCAL_PATH>[:<SANDBOX_PATH>]`.
-        /// When `SANDBOX_PATH` is omitted, files are uploaded to the container
-        /// working directory (`/sandbox`).
+        /// When `SANDBOX_PATH` is omitted, files are uploaded to the container's
+        /// working directory.
         /// `.gitignore` rules are applied by default; use `--no-git-ignore` to
         /// upload everything.
         #[arg(long, value_hint = ValueHint::AnyPath, help_heading = "UPLOAD FLAGS")]
@@ -1104,8 +1118,10 @@ enum SandboxCommands {
         /// Request GPU resources for the sandbox.
         ///
         /// When no gateway is running, auto-bootstrap starts a GPU-enabled
-        /// gateway. GPU intent is also inferred automatically for known
-        /// GPU-designated image names such as `nvidia-gpu`.
+        /// gateway using the same automatic injection selection as
+        /// `openshell gateway start --gpu`. GPU intent is also inferred
+        /// automatically for known GPU-designated image names such as
+        /// `nvidia-gpu`.
#[arg(long)] gpu: bool, @@ -1213,6 +1229,48 @@ enum SandboxCommands { all: bool, }, + /// Execute a command in a running sandbox. + /// + /// Runs a command inside an existing sandbox using the gRPC exec endpoint. + /// Output is streamed to the terminal in real-time. The CLI exits with the + /// remote command's exit code. + /// + /// For interactive shell sessions, use `sandbox connect` instead. + /// + /// Examples: + /// openshell sandbox exec --name my-sandbox -- ls -la /workspace + /// openshell sandbox exec -n my-sandbox --workdir /app -- python script.py + /// echo "hello" | openshell sandbox exec -n my-sandbox -- cat + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Exec { + /// Sandbox name (defaults to last-used sandbox). + #[arg(long, short = 'n', add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option, + + /// Working directory inside the sandbox. + #[arg(long)] + workdir: Option, + + /// Timeout in seconds (0 = no timeout). + #[arg(long, default_value_t = 0)] + timeout: u32, + + /// Allocate a pseudo-terminal for the remote command. + /// Defaults to auto-detection (on when stdin and stdout are terminals). + /// Use --tty to force a PTY even when auto-detection fails, or + /// --no-tty to disable. + #[arg(long, overrides_with = "no_tty")] + tty: bool, + + /// Disable pseudo-terminal allocation. + #[arg(long, overrides_with = "tty")] + no_tty: bool, + + /// Command and arguments to execute. + #[arg(required = true, trailing_var_arg = true, allow_hyphen_values = true)] + command: Vec, + }, + /// Connect to a sandbox. /// /// When no name is given, reconnects to the last-used sandbox. @@ -1239,7 +1297,7 @@ enum SandboxCommands { #[arg(value_hint = ValueHint::AnyPath)] local_path: String, - /// Destination path in the sandbox (defaults to `/sandbox`). + /// Destination path in the sandbox (defaults to the container's working directory). 
dest: Option, /// Disable `.gitignore` filtering (uploads everything). @@ -1562,6 +1620,11 @@ async fn main() -> Result<()> { registry_token, gpu, } => { + let gpu = if gpu { + vec!["auto".to_string()] + } else { + vec![] + }; run::gateway_admin_deploy( &name, remote.as_deref(), @@ -2026,10 +2089,11 @@ async fn main() -> Result<()> { model, system, no_verify, + timeout, } => { let route_name = if system { "sandbox-system" } else { "" }; run::gateway_inference_set( - endpoint, &provider, &model, route_name, no_verify, &tls, + endpoint, &provider, &model, route_name, no_verify, timeout, &tls, ) .await?; } @@ -2038,6 +2102,7 @@ async fn main() -> Result<()> { model, system, no_verify, + timeout, } => { let route_name = if system { "sandbox-system" } else { "" }; run::gateway_inference_update( @@ -2046,6 +2111,7 @@ async fn main() -> Result<()> { model.as_deref(), route_name, no_verify, + timeout, &tls, ) .await?; @@ -2200,7 +2266,7 @@ async fn main() -> Result<()> { let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?; let mut tls = tls.with_gateway_name(&ctx.name); apply_edge_auth(&mut tls, &ctx.name); - let sandbox_dest = dest.as_deref().unwrap_or("/sandbox"); + let sandbox_dest = dest.as_deref(); let local = std::path::Path::new(&local_path); if !local.exists() { return Err(miette::miette!( @@ -2208,7 +2274,8 @@ async fn main() -> Result<()> { local.display() )); } - eprintln!("Uploading {} -> sandbox:{}", local.display(), sandbox_dest); + let dest_display = sandbox_dest.unwrap_or("~"); + eprintln!("Uploading {} -> sandbox:{}", local.display(), dest_display); if !no_git_ignore && let Ok((base_dir, files)) = run::git_sync_files(local) { run::sandbox_sync_up_files( &ctx.endpoint, @@ -2282,6 +2349,38 @@ async fn main() -> Result<()> { } let _ = save_last_sandbox(&ctx.name, &name); } + SandboxCommands::Exec { + name, + workdir, + timeout, + tty, + no_tty, + command, + } => { + let name = resolve_sandbox_name(name, &ctx.name)?; + // Resolve --tty / --no-tty 
into an Option override. + let tty_override = if no_tty { + Some(false) + } else if tty { + Some(true) + } else { + None // auto-detect + }; + let exit_code = run::sandbox_exec_grpc( + endpoint, + &name, + &command, + workdir.as_deref(), + timeout, + tty_override, + &tls, + ) + .await?; + let _ = save_last_sandbox(&ctx.name, &name); + if exit_code != 0 { + std::process::exit(exit_code); + } + } SandboxCommands::SshConfig { name } => { let name = resolve_sandbox_name(name, &ctx.name)?; run::print_ssh_config(&ctx.name, &name); @@ -3115,4 +3214,29 @@ mod tests { other => panic!("expected settings delete command, got: {other:?}"), } } + + /// Ensure every provider registered in `ProviderRegistry` has a + /// corresponding `CliProviderType` variant (and vice-versa). + /// This test would have caught the missing `Copilot` variant from #707. + #[test] + fn cli_provider_types_match_registry() { + let registry = openshell_providers::ProviderRegistry::new(); + let registry_types: std::collections::BTreeSet<&str> = + registry.known_types().into_iter().collect(); + + let cli_types: std::collections::BTreeSet<&str> = + ::value_variants() + .iter() + .map(CliProviderType::as_str) + .collect(); + + assert_eq!( + cli_types, + registry_types, + "CliProviderType variants must match ProviderRegistry.known_types(). 
\
+            CLI-only: {:?}, Registry-only: {:?}",
+            cli_types.difference(&registry_types).collect::<Vec<_>>(),
+            registry_types.difference(&cli_types).collect::<Vec<_>>(),
+        );
+    }
 }
diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs
index e32eec2a4..c40640c30 100644
--- a/crates/openshell-cli/src/run.rs
+++ b/crates/openshell-cli/src/run.rs
@@ -24,13 +24,13 @@ use openshell_bootstrap::{
 use openshell_core::proto::{
     ApproveAllDraftChunksRequest, ApproveDraftChunkRequest, ClearDraftChunksRequest,
     CreateProviderRequest, CreateSandboxRequest, DeleteProviderRequest, DeleteSandboxRequest,
-    GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest,
+    ExecSandboxRequest, GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest,
     GetGatewayConfigRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest,
     GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ListProvidersRequest,
     ListSandboxPoliciesRequest, ListSandboxesRequest, PolicyStatus, Provider,
     RejectDraftChunkRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate,
     SetClusterInferenceRequest, SettingScope, SettingValue, UpdateConfigRequest,
-    UpdateProviderRequest, WatchSandboxRequest, setting_value,
+    UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value,
 };
 use openshell_core::settings::{self, SettingValueKind};
 use openshell_providers::{
@@ -38,7 +38,7 @@ use openshell_providers::{
 };
 use owo_colors::OwoColorize;
 use std::collections::{HashMap, HashSet, VecDeque};
-use std::io::{IsTerminal, Write};
+use std::io::{IsTerminal, Read, Write};
 use std::path::{Path, PathBuf};
 use std::process::Command;
 use std::time::{Duration, Instant};
@@ -1355,7 +1355,7 @@ pub async fn gateway_admin_deploy(
     disable_gateway_auth: bool,
     registry_username: Option<&str>,
     registry_token: Option<&str>,
-    gpu: bool,
+    gpu: Vec<String>,
 ) -> Result<()> {
     let location = if remote.is_some() { "remote" } else { "local" };
 
@@ -1369,62 
@@ pub async fn gateway_admin_deploy( opts }); - // Check whether a gateway already exists. If so, prompt the user (unless - // --recreate was passed or we're in non-interactive mode). - let mut should_recreate = recreate; - if let Some(existing) = - openshell_bootstrap::check_existing_deployment(name, remote_opts.as_ref()).await? - { - if !should_recreate { - let interactive = std::io::stdin().is_terminal() && std::io::stderr().is_terminal(); - if interactive { - let status = if existing.container_running { - "running" - } else if existing.container_exists { - "stopped" - } else { - "volume only" - }; - eprintln!(); + // If the gateway is already running and we're not recreating, short-circuit. + if !recreate { + if let Some(existing) = + openshell_bootstrap::check_existing_deployment(name, remote_opts.as_ref()).await? + { + if existing.container_running { eprintln!( - "{} Gateway '{name}' already exists ({status}).", - "!".yellow().bold() + "{} Gateway '{name}' is already running.", + "✓".green().bold() ); - if let Some(image) = &existing.container_image { - eprintln!(" {} {}", "Image:".dimmed(), image); - } - eprintln!(); - eprint!("Destroy and recreate? [y/N] "); - std::io::stderr().flush().ok(); - let mut input = String::new(); - std::io::stdin() - .read_line(&mut input) - .into_diagnostic() - .wrap_err("failed to read user input")?; - let choice = input.trim().to_lowercase(); - should_recreate = choice == "y" || choice == "yes"; - if !should_recreate { - eprintln!("Keeping existing gateway."); - return Ok(()); - } - } else { - // Non-interactive mode: reuse existing gateway silently. - eprintln!("Gateway '{name}' already exists, reusing."); return Ok(()); } } } + // When resuming an existing gateway (not recreating), prefer the port + // and gateway host from stored metadata over the CLI defaults. The user + // may have originally bootstrapped on a non-default port (e.g. 
`--port
+    // 8082`) or with `--gateway-host host.docker.internal`, and a bare
+    // `gateway start` without those flags should honour the original values.
+    let stored_metadata = if !recreate {
+        openshell_bootstrap::load_gateway_metadata(name).ok()
+    } else {
+        None
+    };
+    let effective_port = stored_metadata
+        .as_ref()
+        .filter(|m| m.gateway_port > 0)
+        .map_or(port, |m| m.gateway_port);
+    let effective_gateway_host: Option<String> = gateway_host.map(String::from).or_else(|| {
+        stored_metadata
+            .as_ref()
+            .and_then(|m| m.gateway_host().map(String::from))
+    });
+
     let mut options = DeployOptions::new(name)
-        .with_port(port)
+        .with_port(effective_port)
         .with_disable_tls(disable_tls)
         .with_disable_gateway_auth(disable_gateway_auth)
         .with_gpu(gpu)
-        .with_recreate(should_recreate);
+        .with_recreate(recreate);
     if let Some(opts) = remote_opts {
         options = options.with_remote(opts);
     }
-    if let Some(host) = gateway_host {
+    if let Some(host) = effective_gateway_host {
         options = options.with_gateway_host(host);
     }
     if let Some(username) = registry_username {
@@ -1436,6 +1425,15 @@ pub async fn gateway_admin_deploy(
 
     let handle = deploy_gateway_with_panel(options, name, location).await?;
 
+    // Wait for the gRPC endpoint to actually accept connections before
+    // declaring the gateway ready. The Docker health check may pass before
+    // the gRPC listener inside the pod is fully bound.
+    let server = handle.gateway_endpoint().to_string();
+    let tls = TlsOptions::default()
+        .with_gateway_name(name)
+        .with_default_paths(&server);
+    crate::bootstrap::wait_for_grpc_ready(&server, &tls).await?;
+
     print_deploy_summary(name, &handle);
 
     // Auto-activate: set this gateway as the active gateway.
@@ -2045,7 +2043,16 @@ pub async fn sandbox_create( name: name.unwrap_or_default().to_string(), }; - let response = client.create_sandbox(request).await.into_diagnostic()?; + let response = match client.create_sandbox(request).await { + Ok(resp) => resp, + Err(status) if status.code() == Code::AlreadyExists => { + return Err(miette::miette!( + "{}\n\nhint: delete it first with: openshell sandbox delete \n or use a different name", + status.message() + )); + } + Err(status) => return Err(status).into_diagnostic(), + }; let sandbox = response .into_inner() .sandbox @@ -2300,8 +2307,12 @@ pub async fn sandbox_create( drop(client); if let Some((local_path, sandbox_path, git_ignore)) = upload { - let dest = sandbox_path.as_deref().unwrap_or("/sandbox"); - eprintln!(" {} Uploading files to {dest}...", "\u{2022}".dimmed(),); + let dest = sandbox_path.as_deref(); + let dest_display = dest.unwrap_or("~"); + eprintln!( + " {} Uploading files to {dest_display}...", + "\u{2022}".dimmed(), + ); let local = Path::new(local_path); if *git_ignore && let Ok((base_dir, files)) = git_sync_files(local) { sandbox_sync_up_files( @@ -2619,7 +2630,6 @@ pub async fn sandbox_sync_command( ) -> Result<()> { match (up, down) { (Some(local_path), None) => { - let sandbox_dest = dest.unwrap_or("/sandbox"); let local = Path::new(local_path); if !local.exists() { return Err(miette::miette!( @@ -2627,8 +2637,9 @@ pub async fn sandbox_sync_command( local.display() )); } - eprintln!("Syncing {} -> sandbox:{}", local.display(), sandbox_dest); - sandbox_sync_up(server, name, local, sandbox_dest, tls).await?; + let dest_display = dest.unwrap_or("~"); + eprintln!("Syncing {} -> sandbox:{}", local.display(), dest_display); + sandbox_sync_up(server, name, local, dest, tls).await?; eprintln!("{} Sync complete", "✓".green().bold()); } (None, Some(sandbox_path)) => { @@ -2682,6 +2693,116 @@ pub async fn sandbox_get(server: &str, name: &str, tls: &TlsOptions) -> Result<( Ok(()) } +/// Maximum stdin payload 
size (4 MiB). Prevents the CLI from reading unbounded
+/// data into memory before the server rejects an oversized message.
+const MAX_STDIN_PAYLOAD: usize = 4 * 1024 * 1024;
+
+/// Execute a command in a running sandbox via gRPC, streaming output to the terminal.
+///
+/// Returns the remote command's exit code.
+pub async fn sandbox_exec_grpc(
+    server: &str,
+    name: &str,
+    command: &[String],
+    workdir: Option<&str>,
+    timeout_seconds: u32,
+    tty_override: Option<bool>,
+    tls: &TlsOptions,
+) -> Result<i32> {
+    let mut client = grpc_client(server, tls).await?;
+
+    // Resolve sandbox name to id.
+    let sandbox = client
+        .get_sandbox(GetSandboxRequest {
+            name: name.to_string(),
+        })
+        .await
+        .into_diagnostic()?
+        .into_inner()
+        .sandbox
+        .ok_or_else(|| miette::miette!("sandbox not found"))?;
+
+    // Verify the sandbox is ready before issuing the exec.
+    if SandboxPhase::try_from(sandbox.phase) != Ok(SandboxPhase::Ready) {
+        return Err(miette::miette!(
+            "sandbox '{}' is not ready (phase: {}); wait for it to reach Ready state",
+            name,
+            phase_name(sandbox.phase)
+        ));
+    }
+
+    // Read stdin if piped (not a TTY), using spawn_blocking to avoid blocking
+    // the async runtime. Cap the read at MAX_STDIN_PAYLOAD + 1 so we never
+    // buffer more than the limit into memory.
+    let stdin_payload = if !std::io::stdin().is_terminal() {
+        tokio::task::spawn_blocking(|| {
+            let limit = (MAX_STDIN_PAYLOAD + 1) as u64;
+            let mut buf = Vec::new();
+            std::io::stdin()
+                .take(limit)
+                .read_to_end(&mut buf)
+                .into_diagnostic()?;
+            if buf.len() > MAX_STDIN_PAYLOAD {
+                return Err(miette::miette!(
+                    "stdin payload exceeds {} byte limit; pipe smaller inputs or use `sandbox upload`",
+                    MAX_STDIN_PAYLOAD
+                ));
+            }
+            Ok(buf)
+        })
+        .await
+        .into_diagnostic()?? // first ? unwraps JoinError, second ? unwraps Result
+    } else {
+        Vec::new()
+    };
+
+    // Resolve TTY mode: explicit --tty / --no-tty wins, otherwise auto-detect.
+ let tty = tty_override + .unwrap_or_else(|| std::io::stdin().is_terminal() && std::io::stdout().is_terminal()); + + // Make the streaming gRPC call. + let mut stream = client + .exec_sandbox(ExecSandboxRequest { + sandbox_id: sandbox.id, + command: command.to_vec(), + workdir: workdir.unwrap_or_default().to_string(), + environment: HashMap::new(), + timeout_seconds, + stdin: stdin_payload, + tty, + }) + .await + .into_diagnostic()? + .into_inner(); + + // Stream output to terminal in real-time. + let mut exit_code = 0i32; + let stdout = std::io::stdout(); + let stderr = std::io::stderr(); + + while let Some(event) = stream.next().await { + let event = event.into_diagnostic()?; + match event.payload { + Some(exec_sandbox_event::Payload::Stdout(out)) => { + let mut handle = stdout.lock(); + handle.write_all(&out.data).into_diagnostic()?; + handle.flush().into_diagnostic()?; + } + Some(exec_sandbox_event::Payload::Stderr(err)) => { + let mut handle = stderr.lock(); + handle.write_all(&err.data).into_diagnostic()?; + handle.flush().into_diagnostic()?; + } + Some(exec_sandbox_event::Payload::Exit(exit)) => { + exit_code = exit.exit_code; + } + None => {} + } + } + + Ok(exit_code) +} + /// Print a single YAML line with dimmed keys and regular values. 
fn print_yaml_line(line: &str) { // Find leading whitespace @@ -3481,6 +3602,7 @@ pub async fn gateway_inference_set( model_id: &str, route_name: &str, no_verify: bool, + timeout_secs: u64, tls: &TlsOptions, ) -> Result<()> { let progress = if std::io::stdout().is_terminal() { @@ -3504,6 +3626,7 @@ pub async fn gateway_inference_set( route_name: route_name.to_string(), verify: false, no_verify, + timeout_secs, }) .await; @@ -3525,6 +3648,7 @@ pub async fn gateway_inference_set( println!(" {} {}", "Provider:".dimmed(), configured.provider_name); println!(" {} {}", "Model:".dimmed(), configured.model_id); println!(" {} {}", "Version:".dimmed(), configured.version); + print_timeout(configured.timeout_secs); if configured.validation_performed { println!(" {}", "Validated Endpoints:".dimmed()); for endpoint in configured.validated_endpoints { @@ -3540,11 +3664,12 @@ pub async fn gateway_inference_update( model_id: Option<&str>, route_name: &str, no_verify: bool, + timeout_secs: Option, tls: &TlsOptions, ) -> Result<()> { - if provider_name.is_none() && model_id.is_none() { + if provider_name.is_none() && model_id.is_none() && timeout_secs.is_none() { return Err(miette::miette!( - "at least one of --provider or --model must be specified" + "at least one of --provider, --model, or --timeout must be specified" )); } @@ -3561,6 +3686,7 @@ pub async fn gateway_inference_update( let provider = provider_name.unwrap_or(¤t.provider_name); let model = model_id.unwrap_or(¤t.model_id); + let timeout = timeout_secs.unwrap_or(current.timeout_secs); let progress = if std::io::stdout().is_terminal() { let spinner = ProgressBar::new_spinner(); @@ -3582,6 +3708,7 @@ pub async fn gateway_inference_update( route_name: route_name.to_string(), verify: false, no_verify, + timeout_secs: timeout, }) .await; @@ -3603,6 +3730,7 @@ pub async fn gateway_inference_update( println!(" {} {}", "Provider:".dimmed(), configured.provider_name); println!(" {} {}", "Model:".dimmed(), configured.model_id); 
println!(" {} {}", "Version:".dimmed(), configured.version); + print_timeout(configured.timeout_secs); if configured.validation_performed { println!(" {}", "Validated Endpoints:".dimmed()); for endpoint in configured.validated_endpoints { @@ -3639,6 +3767,7 @@ pub async fn gateway_inference_get( println!(" {} {}", "Provider:".dimmed(), configured.provider_name); println!(" {} {}", "Model:".dimmed(), configured.model_id); println!(" {} {}", "Version:".dimmed(), configured.version); + print_timeout(configured.timeout_secs); } else { // Show both routes by default. print_inference_route(&mut client, "Gateway inference", "").await; @@ -3666,6 +3795,7 @@ async fn print_inference_route( println!(" {} {}", "Provider:".dimmed(), configured.provider_name); println!(" {} {}", "Model:".dimmed(), configured.model_id); println!(" {} {}", "Version:".dimmed(), configured.version); + print_timeout(configured.timeout_secs); } Err(e) if e.code() == Code::NotFound => { println!("{}", format!("{label}:").cyan().bold()); @@ -3680,6 +3810,14 @@ async fn print_inference_route( } } +fn print_timeout(timeout_secs: u64) { + if timeout_secs == 0 { + println!(" {} {}s (default)", "Timeout:".dimmed(), 60); + } else { + println!(" {} {}s", "Timeout:".dimmed(), timeout_secs); + } +} + fn format_inference_status(status: Status) -> miette::Report { let message = status.message().trim(); diff --git a/crates/openshell-cli/src/ssh.rs b/crates/openshell-cli/src/ssh.rs index 4b284bff1..ebcbbeb4f 100644 --- a/crates/openshell-cli/src/ssh.rs +++ b/crates/openshell-cli/src/ssh.rs @@ -447,33 +447,51 @@ pub(crate) async fn sandbox_exec_without_exec( sandbox_exec_with_mode(server, name, command, tty, tls, false).await } -/// Push a list of files from a local directory into a sandbox using tar-over-SSH. +/// What to pack into the tar archive streamed to the sandbox. +enum UploadSource { + /// A single local file or directory. `tar_name` controls the entry name + /// inside the archive (e.g. 
the target basename for file-to-file uploads). + SinglePath { + local_path: PathBuf, + tar_name: std::ffi::OsString, + }, + /// A set of files relative to a base directory (git-filtered uploads). + FileList { + base_dir: PathBuf, + files: Vec, + }, +} + +/// Core tar-over-SSH upload: streams a tar archive into `dest_dir` on the +/// sandbox. Callers are responsible for splitting the destination path so +/// that `dest_dir` is always a directory. /// -/// This replaces the old rsync-based sync. Files are streamed as a tar archive -/// to `ssh ... tar xf - -C ` on the sandbox side. -pub async fn sandbox_sync_up_files( +/// When `dest_dir` is `None`, the sandbox user's home directory (`$HOME`) is +/// used as the extraction target. This avoids hard-coding any particular +/// path and works for custom container images with non-default `WORKDIR`. +async fn ssh_tar_upload( server: &str, name: &str, - base_dir: &Path, - files: &[String], - dest: &str, + dest_dir: Option<&str>, + source: UploadSource, tls: &TlsOptions, ) -> Result<()> { - if files.is_empty() { - return Ok(()); - } - let session = ssh_session_config(server, name, tls).await?; + // When no explicit destination is given, use the unescaped `$HOME` shell + // variable so the remote shell resolves it at runtime. + let escaped_dest = match dest_dir { + Some(d) => shell_escape(d), + None => "$HOME".to_string(), + }; + let mut ssh = ssh_base_command(&session.proxy_command); ssh.arg("-T") .arg("-o") .arg("RequestTTY=no") .arg("sandbox") .arg(format!( - "mkdir -p {} && cat | tar xf - -C {}", - shell_escape(dest), - shell_escape(dest) + "mkdir -p {escaped_dest} && cat | tar xf - -C {escaped_dest}", )) .stdin(Stdio::piped()) .stdout(Stdio::inherit()) @@ -486,22 +504,43 @@ pub async fn sandbox_sync_up_files( .ok_or_else(|| miette::miette!("failed to open stdin for ssh process"))?; // Build the tar archive in a blocking task since the tar crate is synchronous. 
- let base_dir = base_dir.to_path_buf(); - let files = files.to_vec(); tokio::task::spawn_blocking(move || -> Result<()> { let mut archive = tar::Builder::new(stdin); - for file in &files { - let full_path = base_dir.join(file); - if full_path.is_file() { - archive - .append_path_with_name(&full_path, file) - .into_diagnostic() - .wrap_err_with(|| format!("failed to add {file} to tar archive"))?; - } else if full_path.is_dir() { - archive - .append_dir_all(file, &full_path) - .into_diagnostic() - .wrap_err_with(|| format!("failed to add directory {file} to tar archive"))?; + match source { + UploadSource::SinglePath { + local_path, + tar_name, + } => { + if local_path.is_file() { + archive + .append_path_with_name(&local_path, &tar_name) + .into_diagnostic()?; + } else if local_path.is_dir() { + archive.append_dir_all(".", &local_path).into_diagnostic()?; + } else { + return Err(miette::miette!( + "local path does not exist: {}", + local_path.display() + )); + } + } + UploadSource::FileList { base_dir, files } => { + for file in &files { + let full_path = base_dir.join(file); + if full_path.is_file() { + archive + .append_path_with_name(&full_path, file) + .into_diagnostic() + .wrap_err_with(|| format!("failed to add {file} to tar archive"))?; + } else if full_path.is_dir() { + archive + .append_dir_all(file, &full_path) + .into_diagnostic() + .wrap_err_with(|| { + format!("failed to add directory {file} to tar archive") + })?; + } + } } } archive.finish().into_diagnostic()?; @@ -524,72 +563,112 @@ pub async fn sandbox_sync_up_files( Ok(()) } +/// Split a sandbox path into (parent_directory, basename). 
+/// +/// Examples: +/// `"/sandbox/.bashrc"` -> `("/sandbox", ".bashrc")` +/// `"/sandbox/sub/file"` -> `("/sandbox/sub", "file")` +/// `"file.txt"` -> `(".", "file.txt")` +fn split_sandbox_path(path: &str) -> (&str, &str) { + match path.rfind('/') { + Some(0) => ("/", &path[1..]), + Some(pos) => (&path[..pos], &path[pos + 1..]), + None => (".", path), + } +} + +/// Push a list of files from a local directory into a sandbox using tar-over-SSH. +/// +/// Files are streamed as a tar archive to `ssh ... tar xf - -C ` on +/// the sandbox side. When `dest` is `None`, files are uploaded to the +/// sandbox user's home directory. +pub async fn sandbox_sync_up_files( + server: &str, + name: &str, + base_dir: &Path, + files: &[String], + dest: Option<&str>, + tls: &TlsOptions, +) -> Result<()> { + if files.is_empty() { + return Ok(()); + } + ssh_tar_upload( + server, + name, + dest, + UploadSource::FileList { + base_dir: base_dir.to_path_buf(), + files: files.to_vec(), + }, + tls, + ) + .await +} + /// Push a local path (file or directory) into a sandbox using tar-over-SSH. +/// +/// When `sandbox_path` is `None`, files are uploaded to the sandbox user's +/// home directory. When uploading a single file to an explicit destination +/// that does not end with `/`, the destination is treated as a file path: +/// the parent directory is created and the file is written with the +/// destination's basename. This matches `cp` / `scp` semantics. 
pub async fn sandbox_sync_up( server: &str, name: &str, local_path: &Path, - sandbox_path: &str, + sandbox_path: Option<&str>, tls: &TlsOptions, ) -> Result<()> { - let session = ssh_session_config(server, name, tls).await?; - - let mut ssh = ssh_base_command(&session.proxy_command); - ssh.arg("-T") - .arg("-o") - .arg("RequestTTY=no") - .arg("sandbox") - .arg(format!( - "mkdir -p {} && cat | tar xf - -C {}", - shell_escape(sandbox_path), - shell_escape(sandbox_path) - )) - .stdin(Stdio::piped()) - .stdout(Stdio::inherit()) - .stderr(Stdio::inherit()); - - let mut child = ssh.spawn().into_diagnostic()?; - let stdin = child - .stdin - .take() - .ok_or_else(|| miette::miette!("failed to open stdin for ssh process"))?; - - let local_path = local_path.to_path_buf(); - tokio::task::spawn_blocking(move || -> Result<()> { - let mut archive = tar::Builder::new(stdin); - if local_path.is_file() { - let file_name = local_path - .file_name() - .ok_or_else(|| miette::miette!("path has no file name"))?; - archive - .append_path_with_name(&local_path, file_name) - .into_diagnostic()?; - } else if local_path.is_dir() { - archive.append_dir_all(".", &local_path).into_diagnostic()?; - } else { - return Err(miette::miette!( - "local path does not exist: {}", - local_path.display() - )); + // When an explicit destination is given and looks like a file path (does + // not end with '/'), split into parent directory + target basename so that + // `mkdir -p` creates the parent and tar extracts the file with the right + // name. + // + // Exception: if splitting would yield "/" as the parent (e.g. the user + // passed "/sandbox"), fall through to directory semantics instead. The + // sandbox user cannot write to "/" and the intent is almost certainly + // "put the file inside /sandbox", not "create a file named sandbox in /". 
+ if let Some(path) = sandbox_path { + if local_path.is_file() && !path.ends_with('/') { + let (parent, target_name) = split_sandbox_path(path); + if parent != "/" { + return ssh_tar_upload( + server, + name, + Some(parent), + UploadSource::SinglePath { + local_path: local_path.to_path_buf(), + tar_name: target_name.into(), + }, + tls, + ) + .await; + } } - archive.finish().into_diagnostic()?; - Ok(()) - }) - .await - .into_diagnostic()??; - - let status = tokio::task::spawn_blocking(move || child.wait()) - .await - .into_diagnostic()? - .into_diagnostic()?; - - if !status.success() { - return Err(miette::miette!( - "ssh tar extract exited with status {status}" - )); } - Ok(()) + let tar_name = if local_path.is_file() { + local_path + .file_name() + .ok_or_else(|| miette::miette!("path has no file name"))? + .to_os_string() + } else { + // For directories the tar_name is unused — append_dir_all uses "." + ".".into() + }; + + ssh_tar_upload( + server, + name, + sandbox_path, + UploadSource::SinglePath { + local_path: local_path.to_path_buf(), + tar_name, + }, + tls, + ) + .await } /// Pull a path from a sandbox to a local destination using tar-over-SSH. @@ -693,27 +772,50 @@ pub async fn sandbox_ssh_proxy( .ok_or_else(|| miette::miette!("gateway URL missing port"))?; let connect_path = url.path(); - let mut stream: Box = - connect_gateway(scheme, gateway_host, gateway_port, tls).await?; - let request = format!( "CONNECT {connect_path} HTTP/1.1\r\nHost: {gateway_host}\r\nX-Sandbox-Id: {sandbox_id}\r\nX-Sandbox-Token: {token}\r\n\r\n" ); - stream - .write_all(request.as_bytes()) - .await - .into_diagnostic()?; - // Wrap in a BufReader **before** reading the HTTP response. The gateway - // may send the 200 OK response and the first SSH protocol bytes in the - // same TCP segment / WebSocket frame. A plain `read()` would consume - // those SSH bytes into our buffer and discard them, causing SSH to see a - // truncated protocol banner and exit with code 255. 
BufReader ensures - // any bytes read past the `\r\n\r\n` header boundary stay buffered and - // are returned by subsequent reads during the bidirectional copy phase. - let mut buf_stream = BufReader::new(stream); - let status = read_connect_status(&mut buf_stream).await?; - if status != 200 { + // The gateway returns 412 (Precondition Failed) when the sandbox pod + // exists but hasn't reached Ready phase yet. This is a transient state + // after sandbox allocation — retry with backoff instead of failing + // immediately. + const MAX_CONNECT_WAIT: Duration = Duration::from_secs(60); + const INITIAL_BACKOFF: Duration = Duration::from_secs(1); + + let start = std::time::Instant::now(); + let mut backoff = INITIAL_BACKOFF; + let mut buf_stream; + + loop { + let mut stream: Box = + connect_gateway(scheme, gateway_host, gateway_port, tls).await?; + stream + .write_all(request.as_bytes()) + .await + .into_diagnostic()?; + + // Wrap in a BufReader **before** reading the HTTP response. The gateway + // may send the 200 OK response and the first SSH protocol bytes in the + // same TCP segment / WebSocket frame. A plain `read()` would consume + // those SSH bytes into our buffer and discard them, causing SSH to see a + // truncated protocol banner and exit with code 255. BufReader ensures + // any bytes read past the `\r\n\r\n` header boundary stay buffered and + // are returned by subsequent reads during the bidirectional copy phase. 
+ buf_stream = BufReader::new(stream); + let status = read_connect_status(&mut buf_stream).await?; + if status == 200 { + break; + } + if status == 412 && start.elapsed() < MAX_CONNECT_WAIT { + tracing::debug!( + elapsed = ?start.elapsed(), + "sandbox not yet ready (HTTP 412), retrying in {backoff:?}" + ); + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(Duration::from_secs(8)); + continue; + } return Err(miette::miette!( "gateway CONNECT failed with status {status}" )); @@ -1149,4 +1251,25 @@ mod tests { assert!(message.contains("Forwarding port 3000 to sandbox demo")); assert!(message.contains("Access at: http://localhost:3000/")); } + + #[test] + fn split_sandbox_path_separates_parent_and_basename() { + assert_eq!( + split_sandbox_path("/sandbox/.bashrc"), + ("/sandbox", ".bashrc") + ); + assert_eq!( + split_sandbox_path("/sandbox/sub/file"), + ("/sandbox/sub", "file") + ); + assert_eq!(split_sandbox_path("/a/b/c/d.txt"), ("/a/b/c", "d.txt")); + } + + #[test] + fn split_sandbox_path_handles_root_and_bare_names() { + // File directly under root + assert_eq!(split_sandbox_path("/.bashrc"), ("/", ".bashrc")); + // No directory component at all + assert_eq!(split_sandbox_path("file.txt"), (".", "file.txt")); + } } diff --git a/crates/openshell-core/src/proto/openshell.datamodel.v1.rs b/crates/openshell-core/src/proto/openshell.datamodel.v1.rs deleted file mode 100644 index 310497d1a..000000000 --- a/crates/openshell-core/src/proto/openshell.datamodel.v1.rs +++ /dev/null @@ -1,146 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -// This file is @generated by prost-build. -/// Sandbox model stored by OpenShell. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Sandbox { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub name: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub namespace: ::prost::alloc::string::String, - #[prost(message, optional, tag = "4")] - pub spec: ::core::option::Option, - #[prost(message, optional, tag = "5")] - pub status: ::core::option::Option, - #[prost(enumeration = "SandboxPhase", tag = "6")] - pub phase: i32, -} -/// OpenShell-level sandbox spec. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxSpec { - #[prost(string, tag = "1")] - pub log_level: ::prost::alloc::string::String, - #[prost(map = "string, string", tag = "5")] - pub environment: - ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - #[prost(message, optional, tag = "6")] - pub template: ::core::option::Option, - /// Required sandbox policy configuration. - #[prost(message, optional, tag = "7")] - pub policy: ::core::option::Option, - /// Provider names to attach to this sandbox. - #[prost(string, repeated, tag = "8")] - pub providers: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, -} -/// Sandbox template mapped onto Kubernetes pod template inputs. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxTemplate { - #[prost(string, tag = "1")] - pub image: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub runtime_class_name: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub agent_socket: ::prost::alloc::string::String, - #[prost(map = "string, string", tag = "4")] - pub labels: - ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - #[prost(map = "string, string", tag = "5")] - pub annotations: - ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - #[prost(map = "string, string", tag = "6")] - pub environment: - ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - #[prost(message, optional, tag = "7")] - pub resources: ::core::option::Option<::prost_types::Struct>, - #[prost(message, optional, tag = "9")] - pub volume_claim_templates: ::core::option::Option<::prost_types::Struct>, -} -/// Sandbox status captured from Kubernetes. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxStatus { - #[prost(string, tag = "1")] - pub sandbox_name: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub agent_pod: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub agent_fd: ::prost::alloc::string::String, - #[prost(string, tag = "4")] - pub sandbox_fd: ::prost::alloc::string::String, - #[prost(message, repeated, tag = "5")] - pub conditions: ::prost::alloc::vec::Vec, -} -/// Sandbox condition mirrors Kubernetes conditions. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxCondition { - #[prost(string, tag = "1")] - pub r#type: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub status: ::prost::alloc::string::String, - #[prost(string, tag = "3")] - pub reason: ::prost::alloc::string::String, - #[prost(string, tag = "4")] - pub message: ::prost::alloc::string::String, - #[prost(string, tag = "5")] - pub last_transition_time: ::prost::alloc::string::String, -} -/// Provider model stored by OpenShell. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Provider { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub name: ::prost::alloc::string::String, - /// Canonical provider type slug (for example: "claude", "gitlab"). - #[prost(string, tag = "3")] - pub r#type: ::prost::alloc::string::String, - /// Secret values used for authentication. - #[prost(map = "string, string", tag = "4")] - pub credentials: - ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, - /// Non-secret provider configuration. - #[prost(map = "string, string", tag = "5")] - pub config: - ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, -} -/// High-level sandbox lifecycle phase. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum SandboxPhase { - Unspecified = 0, - Provisioning = 1, - Ready = 2, - Error = 3, - Deleting = 4, - Unknown = 5, -} -impl SandboxPhase { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
- pub fn as_str_name(&self) -> &'static str { - match self { - Self::Unspecified => "SANDBOX_PHASE_UNSPECIFIED", - Self::Provisioning => "SANDBOX_PHASE_PROVISIONING", - Self::Ready => "SANDBOX_PHASE_READY", - Self::Error => "SANDBOX_PHASE_ERROR", - Self::Deleting => "SANDBOX_PHASE_DELETING", - Self::Unknown => "SANDBOX_PHASE_UNKNOWN", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "SANDBOX_PHASE_UNSPECIFIED" => Some(Self::Unspecified), - "SANDBOX_PHASE_PROVISIONING" => Some(Self::Provisioning), - "SANDBOX_PHASE_READY" => Some(Self::Ready), - "SANDBOX_PHASE_ERROR" => Some(Self::Error), - "SANDBOX_PHASE_DELETING" => Some(Self::Deleting), - "SANDBOX_PHASE_UNKNOWN" => Some(Self::Unknown), - _ => None, - } - } -} diff --git a/crates/openshell-core/src/proto/openshell.sandbox.v1.rs b/crates/openshell-core/src/proto/openshell.sandbox.v1.rs deleted file mode 100644 index c7fbb178b..000000000 --- a/crates/openshell-core/src/proto/openshell.sandbox.v1.rs +++ /dev/null @@ -1,160 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -// This file is @generated by prost-build. -/// Sandbox security policy configuration. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxPolicy { - /// Policy version. - #[prost(uint32, tag = "1")] - pub version: u32, - /// Filesystem access policy. - #[prost(message, optional, tag = "2")] - pub filesystem: ::core::option::Option, - /// Network access policy. - #[prost(message, optional, tag = "3")] - pub network: ::core::option::Option, - /// Landlock configuration. - #[prost(message, optional, tag = "4")] - pub landlock: ::core::option::Option, - /// Process execution policy. - #[prost(message, optional, tag = "5")] - pub process: ::core::option::Option, -} -/// Filesystem access policy. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FilesystemPolicy { - /// Read-only directory allow list. - #[prost(string, repeated, tag = "1")] - pub read_only: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// Read-write directory allow list. - #[prost(string, repeated, tag = "2")] - pub read_write: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// Automatically include the workdir as read-write. - #[prost(bool, tag = "3")] - pub include_workdir: bool, -} -/// Network access policy. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct NetworkPolicy { - /// Network access mode. - #[prost(enumeration = "NetworkMode", tag = "1")] - pub mode: i32, - /// Proxy configuration (required when mode is PROXY). - #[prost(message, optional, tag = "2")] - pub proxy: ::core::option::Option, -} -/// Proxy configuration for network policy. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ProxyPolicy { - /// Unix socket path for a local proxy (preferred for strict seccomp rules). - #[prost(string, tag = "1")] - pub unix_socket: ::prost::alloc::string::String, - /// TCP address for a local HTTP proxy (loopback-only). - #[prost(string, tag = "2")] - pub http_addr: ::prost::alloc::string::String, - /// Allowed hostnames for proxy traffic. Empty means allow all. - #[prost(string, repeated, tag = "3")] - pub allow_hosts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, -} -/// Landlock policy configuration. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct LandlockPolicy { - /// Compatibility mode. - #[prost(enumeration = "LandlockCompatibility", tag = "1")] - pub compatibility: i32, -} -/// Process execution policy. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ProcessPolicy { - /// User name to run the sandboxed process as. - #[prost(string, tag = "1")] - pub run_as_user: ::prost::alloc::string::String, - /// Group name to run the sandboxed process as. 
- #[prost(string, tag = "2")] - pub run_as_group: ::prost::alloc::string::String, -} -/// Request to get sandbox policy by sandbox ID. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct GetSandboxPolicyRequest { - /// The sandbox ID. - #[prost(string, tag = "1")] - pub sandbox_id: ::prost::alloc::string::String, -} -/// Response containing sandbox policy. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct GetSandboxPolicyResponse { - /// The sandbox policy configuration. - #[prost(message, optional, tag = "1")] - pub policy: ::core::option::Option, -} -/// Network access mode. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum NetworkMode { - /// Unspecified defaults to BLOCK. - Unspecified = 0, - /// Block all network access. - Block = 1, - /// Route traffic through a proxy. - Proxy = 2, - /// Allow all network access. - Allow = 3, -} -impl NetworkMode { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - Self::Unspecified => "NETWORK_MODE_UNSPECIFIED", - Self::Block => "NETWORK_MODE_BLOCK", - Self::Proxy => "NETWORK_MODE_PROXY", - Self::Allow => "NETWORK_MODE_ALLOW", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "NETWORK_MODE_UNSPECIFIED" => Some(Self::Unspecified), - "NETWORK_MODE_BLOCK" => Some(Self::Block), - "NETWORK_MODE_PROXY" => Some(Self::Proxy), - "NETWORK_MODE_ALLOW" => Some(Self::Allow), - _ => None, - } - } -} -/// Landlock compatibility mode. 
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum LandlockCompatibility { - /// Unspecified defaults to BEST_EFFORT. - Unspecified = 0, - /// Use best effort - degrade gracefully on older kernels. - BestEffort = 1, - /// Require full Landlock support or fail. - HardRequirement = 2, -} -impl LandlockCompatibility { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - Self::Unspecified => "LANDLOCK_COMPATIBILITY_UNSPECIFIED", - Self::BestEffort => "LANDLOCK_COMPATIBILITY_BEST_EFFORT", - Self::HardRequirement => "LANDLOCK_COMPATIBILITY_HARD_REQUIREMENT", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "LANDLOCK_COMPATIBILITY_UNSPECIFIED" => Some(Self::Unspecified), - "LANDLOCK_COMPATIBILITY_BEST_EFFORT" => Some(Self::BestEffort), - "LANDLOCK_COMPATIBILITY_HARD_REQUIREMENT" => Some(Self::HardRequirement), - _ => None, - } - } -} diff --git a/crates/openshell-core/src/proto/openshell.test.v1.rs b/crates/openshell-core/src/proto/openshell.test.v1.rs deleted file mode 100644 index 319b3fd3a..000000000 --- a/crates/openshell-core/src/proto/openshell.test.v1.rs +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -// This file is @generated by prost-build. -/// Simple object for persistence tests. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ObjectForTest { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub name: ::prost::alloc::string::String, - #[prost(uint32, tag = "3")] - pub count: u32, -} diff --git a/crates/openshell-core/src/proto/openshell.v1.rs b/crates/openshell-core/src/proto/openshell.v1.rs deleted file mode 100644 index a2735b076..000000000 --- a/crates/openshell-core/src/proto/openshell.v1.rs +++ /dev/null @@ -1,1188 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -// This file is @generated by prost-build. -/// Health check request. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct HealthRequest {} -/// Health check response. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct HealthResponse { - /// Service status. - #[prost(enumeration = "ServiceStatus", tag = "1")] - pub status: i32, - /// Service version. - #[prost(string, tag = "2")] - pub version: ::prost::alloc::string::String, -} -/// Create sandbox request. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CreateSandboxRequest { - #[prost(message, optional, tag = "1")] - pub spec: ::core::option::Option, - /// Optional user-supplied sandbox name. When empty the server generates one. - #[prost(string, tag = "2")] - pub name: ::prost::alloc::string::String, -} -/// Get sandbox request. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct GetSandboxRequest { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, -} -/// List sandboxes request. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct ListSandboxesRequest { - #[prost(uint32, tag = "1")] - pub limit: u32, - #[prost(uint32, tag = "2")] - pub offset: u32, -} -/// Delete sandbox request. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DeleteSandboxRequest { - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, -} -/// Sandbox response. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxResponse { - #[prost(message, optional, tag = "1")] - pub sandbox: ::core::option::Option, -} -/// List sandboxes response. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ListSandboxesResponse { - #[prost(message, repeated, tag = "1")] - pub sandboxes: ::prost::alloc::vec::Vec, -} -/// Delete sandbox response. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct DeleteSandboxResponse { - #[prost(bool, tag = "1")] - pub deleted: bool, -} -/// Create SSH session request. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CreateSshSessionRequest { - /// Sandbox id. - #[prost(string, tag = "1")] - pub sandbox_id: ::prost::alloc::string::String, -} -/// Create SSH session response. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CreateSshSessionResponse { - /// Sandbox id. - #[prost(string, tag = "1")] - pub sandbox_id: ::prost::alloc::string::String, - /// Session token for the gateway tunnel. - #[prost(string, tag = "2")] - pub token: ::prost::alloc::string::String, - /// Gateway host for SSH proxy connection. - #[prost(string, tag = "3")] - pub gateway_host: ::prost::alloc::string::String, - /// Gateway port for SSH proxy connection. - #[prost(uint32, tag = "4")] - pub gateway_port: u32, - /// Gateway scheme (http or https). - #[prost(string, tag = "5")] - pub gateway_scheme: ::prost::alloc::string::String, - /// HTTP path for the CONNECT/upgrade endpoint. - #[prost(string, tag = "6")] - pub connect_path: ::prost::alloc::string::String, - /// Optional host key fingerprint. - #[prost(string, tag = "7")] - pub host_key_fingerprint: ::prost::alloc::string::String, -} -/// Revoke SSH session request. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RevokeSshSessionRequest { - /// Session token to revoke. - #[prost(string, tag = "1")] - pub token: ::prost::alloc::string::String, -} -/// Revoke SSH session response. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct RevokeSshSessionResponse { - /// True when a session was revoked. - #[prost(bool, tag = "1")] - pub revoked: bool, -} -/// SSH session record stored in persistence. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SshSession { - /// Unique id (token). - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - /// Sandbox id. - #[prost(string, tag = "2")] - pub sandbox_id: ::prost::alloc::string::String, - /// Session token. - #[prost(string, tag = "3")] - pub token: ::prost::alloc::string::String, - /// Creation timestamp in milliseconds since epoch. - #[prost(int64, tag = "4")] - pub created_at_ms: i64, - /// Revoked flag. - #[prost(bool, tag = "5")] - pub revoked: bool, -} -/// Watch sandbox request. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct WatchSandboxRequest { - /// Sandbox id. - #[prost(string, tag = "1")] - pub id: ::prost::alloc::string::String, - /// Stream sandbox status snapshots. - #[prost(bool, tag = "2")] - pub follow_status: bool, - /// Stream openshell-server process logs correlated to this sandbox. - #[prost(bool, tag = "3")] - pub follow_logs: bool, - /// Stream platform events correlated to this sandbox. - #[prost(bool, tag = "4")] - pub follow_events: bool, - /// Replay the last N log lines (best-effort) before following. - #[prost(uint32, tag = "5")] - pub log_tail_lines: u32, - /// Replay the last N platform events (best-effort) before following. - #[prost(uint32, tag = "6")] - pub event_tail: u32, - /// Stop streaming once the sandbox reaches a terminal phase (READY or ERROR). - #[prost(bool, tag = "7")] - pub stop_on_terminal: bool, -} -/// One event in a sandbox watch stream. 
-#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxStreamEvent { - #[prost(oneof = "sandbox_stream_event::Payload", tags = "1, 2, 3, 4")] - pub payload: ::core::option::Option, -} -/// Nested message and enum types in `SandboxStreamEvent`. -pub mod sandbox_stream_event { - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Payload { - /// Latest sandbox snapshot. - #[prost(message, tag = "1")] - Sandbox(super::super::datamodel::v1::Sandbox), - /// One server log line/event. - #[prost(message, tag = "2")] - Log(super::SandboxLogLine), - /// One platform event. - #[prost(message, tag = "3")] - Event(super::PlatformEvent), - /// Warning from the server (e.g. missed messages due to lag). - #[prost(message, tag = "4")] - Warning(super::SandboxStreamWarning), - } -} -/// OpenShell server process log line correlated to a sandbox. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxLogLine { - #[prost(string, tag = "1")] - pub sandbox_id: ::prost::alloc::string::String, - #[prost(int64, tag = "2")] - pub timestamp_ms: i64, - #[prost(string, tag = "3")] - pub level: ::prost::alloc::string::String, - #[prost(string, tag = "4")] - pub target: ::prost::alloc::string::String, - #[prost(string, tag = "5")] - pub message: ::prost::alloc::string::String, -} -/// Platform event correlated to a sandbox. -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PlatformEvent { - /// Event timestamp in milliseconds since epoch. - #[prost(int64, tag = "1")] - pub timestamp_ms: i64, - /// Event source (e.g. "kubernetes", "docker", "process"). - #[prost(string, tag = "2")] - pub source: ::prost::alloc::string::String, - /// Event type/severity (e.g. "Normal", "Warning"). - #[prost(string, tag = "3")] - pub r#type: ::prost::alloc::string::String, - /// Short reason code (e.g. "Started", "Pulled", "Failed"). - #[prost(string, tag = "4")] - pub reason: ::prost::alloc::string::String, - /// Human-readable event message. 
- #[prost(string, tag = "5")] - pub message: ::prost::alloc::string::String, - /// Optional metadata as key-value pairs. - #[prost(map = "string, string", tag = "6")] - pub metadata: ::std::collections::HashMap< - ::prost::alloc::string::String, - ::prost::alloc::string::String, - >, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct SandboxStreamWarning { - #[prost(string, tag = "1")] - pub message: ::prost::alloc::string::String, -} -/// Service status enum. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum ServiceStatus { - Unspecified = 0, - Healthy = 1, - Degraded = 2, - Unhealthy = 3, -} -impl ServiceStatus { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - Self::Unspecified => "SERVICE_STATUS_UNSPECIFIED", - Self::Healthy => "SERVICE_STATUS_HEALTHY", - Self::Degraded => "SERVICE_STATUS_DEGRADED", - Self::Unhealthy => "SERVICE_STATUS_UNHEALTHY", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "SERVICE_STATUS_UNSPECIFIED" => Some(Self::Unspecified), - "SERVICE_STATUS_HEALTHY" => Some(Self::Healthy), - "SERVICE_STATUS_DEGRADED" => Some(Self::Degraded), - "SERVICE_STATUS_UNHEALTHY" => Some(Self::Unhealthy), - _ => None, - } - } -} -/// Generated client implementations. -pub mod open_shell_client { - #![allow( - unused_variables, - dead_code, - missing_docs, - clippy::wildcard_imports, - clippy::let_unit_value, - )] - use tonic::codegen::*; - use tonic::codegen::http::Uri; - /// OpenShell service provides agent execution and management capabilities. 
- #[derive(Debug, Clone)] - pub struct OpenShellClient { - inner: tonic::client::Grpc, - } - impl OpenShellClient { - /// Attempt to create a new client by connecting to a given endpoint. - pub async fn connect(dst: D) -> Result - where - D: TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) - } - } - impl OpenShellClient - where - T: tonic::client::GrpcService, - T::Error: Into, - T::ResponseBody: Body + std::marker::Send + 'static, - ::Error: Into + std::marker::Send, - { - pub fn new(inner: T) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_origin(inner: T, origin: Uri) -> Self { - let inner = tonic::client::Grpc::with_origin(inner, origin); - Self { inner } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> OpenShellClient> - where - F: tonic::service::Interceptor, - T::ResponseBody: Default, - T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, - >, - , - >>::Error: Into + std::marker::Send + std::marker::Sync, - { - OpenShellClient::new(InterceptedService::new(inner, interceptor)) - } - /// Compress requests with the given encoding. - /// - /// This requires the server to support it otherwise it might respond with an - /// error. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.send_compressed(encoding); - self - } - /// Enable decompressing responses. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.accept_compressed(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_decoding_message_size(limit); - self - } - /// Limits the maximum size of an encoded message. 
- /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_encoding_message_size(limit); - self - } - /// Check the health of the service. - pub async fn health( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/Health", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "Health")); - self.inner.unary(req, path, codec).await - } - /// Create a new sandbox. - pub async fn create_sandbox( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/CreateSandbox", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "CreateSandbox")); - self.inner.unary(req, path, codec).await - } - /// Fetch a sandbox by id. 
- pub async fn get_sandbox( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/GetSandbox", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "GetSandbox")); - self.inner.unary(req, path, codec).await - } - /// List sandboxes. - pub async fn list_sandboxes( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/ListSandboxes", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "ListSandboxes")); - self.inner.unary(req, path, codec).await - } - /// Delete a sandbox by id. - pub async fn delete_sandbox( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/DeleteSandbox", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "DeleteSandbox")); - self.inner.unary(req, path, codec).await - } - /// Create a short-lived SSH session for a sandbox. 
- pub async fn create_ssh_session( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/CreateSshSession", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "CreateSshSession")); - self.inner.unary(req, path, codec).await - } - /// Revoke a previously issued SSH session. - pub async fn revoke_ssh_session( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/RevokeSshSession", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "RevokeSshSession")); - self.inner.unary(req, path, codec).await - } - /// Get sandbox policy by id (called by sandbox entrypoint at startup). 
- pub async fn get_sandbox_policy( - &mut self, - request: impl tonic::IntoRequest< - super::super::sandbox::v1::GetSandboxPolicyRequest, - >, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/GetSandboxPolicy", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "GetSandboxPolicy")); - self.inner.unary(req, path, codec).await - } - /// Watch a sandbox and stream updates. - /// - /// This stream can include: - /// - Sandbox status snapshots (phase/status) - /// - OpenShell server process logs correlated by sandbox_id - /// - Platform events correlated to the sandbox - pub async fn watch_sandbox( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner - .ready() - .await - .map_err(|e| { - tonic::Status::unknown( - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static( - "/openshell.v1.OpenShell/WatchSandbox", - ); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("openshell.v1.OpenShell", "WatchSandbox")); - self.inner.server_streaming(req, path, codec).await - } - } -} -/// Generated server implementations. -pub mod open_shell_server { - #![allow( - unused_variables, - dead_code, - missing_docs, - clippy::wildcard_imports, - clippy::let_unit_value, - )] - use tonic::codegen::*; - /// Generated trait containing gRPC methods that should be implemented for use with OpenShellServer. 
- #[async_trait] - pub trait OpenShell: std::marker::Send + std::marker::Sync + 'static { - /// Check the health of the service. - async fn health( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// Create a new sandbox. - async fn create_sandbox( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// Fetch a sandbox by id. - async fn get_sandbox( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// List sandboxes. - async fn list_sandboxes( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Delete a sandbox by id. - async fn delete_sandbox( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Create a short-lived SSH session for a sandbox. - async fn create_ssh_session( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Revoke a previously issued SSH session. - async fn revoke_ssh_session( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Get sandbox policy by id (called by sandbox entrypoint at startup). - async fn get_sandbox_policy( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - /// Server streaming response type for the WatchSandbox method. - type WatchSandboxStream: tonic::codegen::tokio_stream::Stream< - Item = std::result::Result, - > - + std::marker::Send - + 'static; - /// Watch a sandbox and stream updates. 
- /// - /// This stream can include: - /// - Sandbox status snapshots (phase/status) - /// - OpenShell server process logs correlated by sandbox_id - /// - Platform events correlated to the sandbox - async fn watch_sandbox( - &self, - request: tonic::Request, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - >; - } - /// OpenShell service provides agent execution and management capabilities. - #[derive(Debug)] - pub struct OpenShellServer { - inner: Arc, - accept_compression_encodings: EnabledCompressionEncodings, - send_compression_encodings: EnabledCompressionEncodings, - max_decoding_message_size: Option, - max_encoding_message_size: Option, - } - impl OpenShellServer { - pub fn new(inner: T) -> Self { - Self::from_arc(Arc::new(inner)) - } - pub fn from_arc(inner: Arc) -> Self { - Self { - inner, - accept_compression_encodings: Default::default(), - send_compression_encodings: Default::default(), - max_decoding_message_size: None, - max_encoding_message_size: None, - } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> InterceptedService - where - F: tonic::service::Interceptor, - { - InterceptedService::new(Self::new(inner), interceptor) - } - /// Enable decompressing requests with the given encoding. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.accept_compression_encodings.enable(encoding); - self - } - /// Compress responses with the given encoding, if the client supports it. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.send_compression_encodings.enable(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.max_decoding_message_size = Some(limit); - self - } - /// Limits the maximum size of an encoded message. 
- /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.max_encoding_message_size = Some(limit); - self - } - } - impl tonic::codegen::Service> for OpenShellServer - where - T: OpenShell, - B: Body + std::marker::Send + 'static, - B::Error: Into + std::marker::Send + 'static, - { - type Response = http::Response; - type Error = std::convert::Infallible; - type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - fn call(&mut self, req: http::Request) -> Self::Future { - match req.uri().path() { - "/openshell.v1.OpenShell/Health" => { - #[allow(non_camel_case_types)] - struct HealthSvc(pub Arc); - impl tonic::server::UnaryService - for HealthSvc { - type Response = super::HealthResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::health(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = HealthSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/CreateSandbox" => { - #[allow(non_camel_case_types)] - struct CreateSandboxSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService 
- for CreateSandboxSvc { - type Response = super::SandboxResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::create_sandbox(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = CreateSandboxSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/GetSandbox" => { - #[allow(non_camel_case_types)] - struct GetSandboxSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService - for GetSandboxSvc { - type Response = super::SandboxResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::get_sandbox(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = GetSandboxSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - 
let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/ListSandboxes" => { - #[allow(non_camel_case_types)] - struct ListSandboxesSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService - for ListSandboxesSvc { - type Response = super::ListSandboxesResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::list_sandboxes(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = ListSandboxesSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/DeleteSandbox" => { - #[allow(non_camel_case_types)] - struct DeleteSandboxSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService - for DeleteSandboxSvc { - type Response = super::DeleteSandboxResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - 
let fut = async move { - ::delete_sandbox(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = DeleteSandboxSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/CreateSshSession" => { - #[allow(non_camel_case_types)] - struct CreateSshSessionSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService - for CreateSshSessionSvc { - type Response = super::CreateSshSessionResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::create_ssh_session(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = CreateSshSessionSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, 
- max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/RevokeSshSession" => { - #[allow(non_camel_case_types)] - struct RevokeSshSessionSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService - for RevokeSshSessionSvc { - type Response = super::RevokeSshSessionResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::revoke_ssh_session(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = RevokeSshSessionSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/GetSandboxPolicy" => { - #[allow(non_camel_case_types)] - struct GetSandboxPolicySvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::UnaryService< - super::super::sandbox::v1::GetSandboxPolicyRequest, - > for GetSandboxPolicySvc { - type Response = super::super::sandbox::v1::GetSandboxPolicyResponse; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request< - super::super::sandbox::v1::GetSandboxPolicyRequest, - >, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - 
::get_sandbox_policy(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = GetSandboxPolicySvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/openshell.v1.OpenShell/WatchSandbox" => { - #[allow(non_camel_case_types)] - struct WatchSandboxSvc(pub Arc); - impl< - T: OpenShell, - > tonic::server::ServerStreamingService - for WatchSandboxSvc { - type Response = super::SandboxStreamEvent; - type ResponseStream = T::WatchSandboxStream; - type Future = BoxFuture< - tonic::Response, - tonic::Status, - >; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::watch_sandbox(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = WatchSandboxSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - 
max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.server_streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - _ => { - Box::pin(async move { - let mut response = http::Response::new(empty_body()); - let headers = response.headers_mut(); - headers - .insert( - tonic::Status::GRPC_STATUS, - (tonic::Code::Unimplemented as i32).into(), - ); - headers - .insert( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ); - Ok(response) - }) - } - } - } - } - impl Clone for OpenShellServer { - fn clone(&self) -> Self { - let inner = self.inner.clone(); - Self { - inner, - accept_compression_encodings: self.accept_compression_encodings, - send_compression_encodings: self.send_compression_encodings, - max_decoding_message_size: self.max_decoding_message_size, - max_encoding_message_size: self.max_encoding_message_size, - } - } - } - /// Generated gRPC service name - pub const SERVICE_NAME: &str = "openshell.v1.OpenShell"; - impl tonic::server::NamedService for OpenShellServer { - const NAME: &'static str = SERVICE_NAME; - } -} diff --git a/crates/openshell-core/src/settings.rs b/crates/openshell-core/src/settings.rs index b94c08fc3..995fe6e2a 100644 --- a/crates/openshell-core/src/settings.rs +++ b/crates/openshell-core/src/settings.rs @@ -49,8 +49,13 @@ pub struct RegisteredSetting { /// keys are accepted. /// 5. Add a unit test in this module's `tests` section to cover the new key. pub const REGISTERED_SETTINGS: &[RegisteredSetting] = &[ - // Production settings go here. Add entries following the steps above. - // + // When true the sandbox writes OCSF v1.7.0 JSONL records to + // `/var/log/openshell-ocsf*.log` (daily rotation, 3 files) in addition + // to the human-readable shorthand log. Defaults to false (no JSONL written). 
+ RegisteredSetting { + key: "ocsf_json_enabled", + kind: SettingValueKind::Bool, + }, // Test-only keys live behind the `dev-settings` feature flag so they // don't appear in production builds. #[cfg(feature = "dev-settings")] diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index e9c99ab5e..3b245e10e 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -36,6 +36,29 @@ pub fn severity_char(severity_id: u8) -> char { } } +/// Format the severity as a bracketed tag placed after the `CLASS:ACTIVITY`. +/// +/// Placed as a suffix so the class name always starts at column 0, keeping +/// logs vertically scannable: +/// +/// ```text +/// NET:OPEN [INFO] ALLOWED python3(42) -> api.example.com:443 +/// NET:OPEN [MED] DENIED python3(42) -> blocked.com:443 +/// FINDING:BLOCKED [HIGH] "NSSH1 Nonce Replay Attack" +/// ``` +#[must_use] +pub fn severity_tag(severity_id: u8) -> &'static str { + match severity_id { + 1 => "[INFO]", + 2 => "[LOW]", + 3 => "[MED]", + 4 => "[HIGH]", + 5 => "[CRIT]", + 6 => "[FATAL]", + _ => "[INFO]", + } +} + impl OcsfEvent { /// Produce the single-line shorthand for `openshell.log` and gRPC log push. 
/// @@ -43,8 +66,7 @@ impl OcsfEvent { #[must_use] pub fn format_shorthand(&self) -> String { let base = self.base(); - let ts = format_ts(base.time); - let sev = severity_char(base.severity.as_u8()); + let sev = severity_tag(base.severity.as_u8()); match self { Self::NetworkActivity(e) => { @@ -85,7 +107,13 @@ impl OcsfEvent { format!(" {actor_str} -> {dst}") }; - format!("{ts} {sev} NET:{activity} {action}{arrow}{rule_ctx}") + let detail = match (action.is_empty(), arrow.is_empty()) { + (true, true) => String::new(), + (true, false) => arrow, + (false, true) => format!(" {action}"), + (false, false) => format!(" {action}{arrow}"), + }; + format!("NET:{activity} {sev}{detail}{rule_ctx}") } Self::HttpActivity(e) => { @@ -116,7 +144,13 @@ impl OcsfEvent { format!(" {actor_str} -> {method} {url_str}") }; - format!("{ts} {sev} HTTP:{method} {action}{arrow}{rule_ctx}") + let detail = match (action.is_empty(), arrow.is_empty()) { + (true, true) => String::new(), + (true, false) => arrow, + (false, true) => format!(" {action}"), + (false, false) => format!(" {action}{arrow}"), + }; + format!("HTTP:{method} {sev}{detail}{rule_ctx}") } Self::SshActivity(e) => { @@ -143,7 +177,21 @@ impl OcsfEvent { }) .unwrap_or_default(); - format!("{ts} {sev} SSH:{activity} {action} {peer}{auth_ctx}") + let detail = [ + if action.is_empty() { "" } else { &action }, + if peer.is_empty() { "" } else { &peer }, + ] + .iter() + .filter(|s| !s.is_empty()) + .copied() + .collect::<Vec<_>>() + .join(" "); + let detail = if detail.is_empty() { + String::new() + } else { + format!(" {detail}") + }; + format!("SSH:{activity} {sev}{detail}{auth_ctx}") } Self::ProcessActivity(e) => { @@ -160,7 +208,7 @@ impl OcsfEvent { .map(|c| format!(" [cmd:{c}]")) .unwrap_or_default(); - format!("{ts} {sev} PROC:{activity} {proc_str}{exit_ctx}{cmd_ctx}") + format!("PROC:{activity} {sev} {proc_str}{exit_ctx}{cmd_ctx}") } Self::DetectionFinding(e) => { @@ -173,7 +221,7 @@ impl OcsfEvent { .map(|c| format!(" 
[confidence:{}]", c.label().to_lowercase())) .unwrap_or_default(); - format!("{ts} {sev} FINDING:{disposition} \"{title}\"{confidence_ctx}") + format!("FINDING:{disposition} {sev} \"{title}\"{confidence_ctx}") } Self::ApplicationLifecycle(e) => { @@ -185,7 +233,7 @@ impl OcsfEvent { .map(|s| s.label().to_lowercase()) .unwrap_or_default(); - format!("{ts} {sev} LIFECYCLE:{activity} {app} {status}") + format!("LIFECYCLE:{activity} {sev} {app} {status}") } Self::DeviceConfigStateChange(e) => { @@ -214,7 +262,7 @@ impl OcsfEvent { }) .unwrap_or_default(); - format!("{ts} {sev} CONFIG:{state} {what}{version_ctx}") + format!("CONFIG:{state} {sev} {what}{version_ctx}") } Self::Base(e) => { @@ -240,7 +288,7 @@ impl OcsfEvent { }) .unwrap_or_default(); - format!("{ts} {sev} EVENT {message}{unmapped_ctx}") + format!("EVENT {sev} {message}{unmapped_ctx}") } } } @@ -337,7 +385,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I NET:OPEN ALLOWED python3(42) -> api.example.com:443 [policy:default-egress engine:mechanistic]" + "NET:OPEN [INFO] ALLOWED python3(42) -> api.example.com:443 [policy:default-egress engine:mechanistic]" ); } @@ -366,7 +414,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 M NET:REFUSE DENIED node(1234) -> 93.184.216.34:443/tcp [policy:bypass-detect engine:iptables]" + "NET:REFUSE [MED] DENIED node(1234) -> 93.184.216.34:443/tcp [policy:bypass-detect engine:iptables]" ); } @@ -395,7 +443,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I HTTP:GET ALLOWED curl(88) -> GET https://api.example.com/v1/data [policy:default-egress]" + "HTTP:GET [INFO] ALLOWED curl(88) -> GET https://api.example.com/v1/data [policy:default-egress]" ); } @@ -416,7 +464,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I SSH:OPEN ALLOWED 10.42.0.1:48201 [auth:NSSH1]" + "SSH:OPEN [INFO] 
ALLOWED 10.42.0.1:48201 [auth:NSSH1]" ); } @@ -435,7 +483,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I PROC:LAUNCH python3(42) [cmd:python3 /app/main.py]" + "PROC:LAUNCH [INFO] python3(42) [cmd:python3 /app/main.py]" ); } @@ -459,10 +507,7 @@ mod tests { }); let shorthand = event.format_shorthand(); - assert_eq!( - shorthand, - "14:00:00.000 I PROC:TERMINATE python3(42) [exit:0]" - ); + assert_eq!(shorthand, "PROC:TERMINATE [INFO] python3(42) [exit:0]"); } #[test] @@ -487,7 +532,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 H FINDING:BLOCKED \"NSSH1 Nonce Replay Attack\" [confidence:high]" + "FINDING:BLOCKED [HIGH] \"NSSH1 Nonce Replay Attack\" [confidence:high]" ); } @@ -514,7 +559,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I LIFECYCLE:START openshell-sandbox success" + "LIFECYCLE:START [INFO] openshell-sandbox success" ); } @@ -536,7 +581,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I CONFIG:LOADED policy reloaded [version:v3 hash:sha256:abc123def456]" + "CONFIG:LOADED [INFO] policy reloaded [version:v3 hash:sha256:abc123def456]" ); } @@ -551,7 +596,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "14:00:00.000 I EVENT Network namespace created [ns:openshell-sandbox-abc123]" + "EVENT [INFO] Network namespace created [ns:openshell-sandbox-abc123]" ); } } diff --git a/crates/openshell-ocsf/src/lib.rs b/crates/openshell-ocsf/src/lib.rs index a70a344ba..b9000afcf 100644 --- a/crates/openshell-ocsf/src/lib.rs +++ b/crates/openshell-ocsf/src/lib.rs @@ -62,4 +62,6 @@ pub use builders::{ }; // --- Tracing layers --- -pub use tracing_layers::{OcsfJsonlLayer, OcsfShorthandLayer, emit_ocsf_event}; +pub use tracing_layers::{ + OCSF_TARGET, OcsfJsonlLayer, OcsfShorthandLayer, clone_current_event, emit_ocsf_event, +}; diff --git 
a/crates/openshell-ocsf/src/tracing_layers/jsonl_layer.rs b/crates/openshell-ocsf/src/tracing_layers/jsonl_layer.rs index 4466c0ab2..1f7022ef8 100644 --- a/crates/openshell-ocsf/src/tracing_layers/jsonl_layer.rs +++ b/crates/openshell-ocsf/src/tracing_layers/jsonl_layer.rs @@ -4,7 +4,9 @@ //! Tracing layer that writes OCSF JSONL to a writer. use std::io::Write; +use std::sync::Arc; use std::sync::Mutex; +use std::sync::atomic::{AtomicBool, Ordering}; use tracing::Subscriber; use tracing_subscriber::Layer; @@ -15,8 +17,15 @@ use crate::tracing_layers::event_bridge::{OCSF_TARGET, clone_current_event}; /// A tracing `Layer` that intercepts OCSF events and writes JSONL output. /// /// Only events with `target: "ocsf"` are processed; non-OCSF events are ignored. +/// +/// An optional enabled flag (`Arc<AtomicBool>`) can be set via +/// [`with_enabled_flag`](Self::with_enabled_flag). When the flag is present and +/// `false`, the layer short-circuits without writing. This allows the sandbox +/// to hot-toggle OCSF JSONL output at runtime via the `ocsf_json_enabled` +/// setting without rebuilding the subscriber. pub struct OcsfJsonlLayer<W> { writer: Mutex<W>, + enabled: Option<Arc<AtomicBool>>, } impl<W> OcsfJsonlLayer<W> { @@ -25,8 +34,19 @@ pub fn new(writer: W) -> Self { Self { writer: Mutex::new(writer), + enabled: None, } } + + /// Attach a shared boolean flag that controls whether the layer writes. + /// + /// When the flag is `false`, the layer receives events but discards them. + /// When the flag is absent (the default), the layer always writes. + #[must_use] + pub fn with_enabled_flag(mut self, flag: Arc<AtomicBool>) -> Self { + self.enabled = Some(flag); + self + } } impl<S, W> Layer<S> for OcsfJsonlLayer<W> where @@ -39,6 +59,13 @@ return; } + // If an enabled flag is set and it reads `false`, skip writing. 
+ if let Some(ref flag) = self.enabled { + if !flag.load(Ordering::Relaxed) { + return; + } + } + if let Some(ocsf_event) = clone_current_event() && let Ok(line) = ocsf_event.to_json_line() && let Ok(mut w) = self.writer.lock() diff --git a/crates/openshell-ocsf/src/tracing_layers/mod.rs b/crates/openshell-ocsf/src/tracing_layers/mod.rs index a8213a299..c8e5d9f2e 100644 --- a/crates/openshell-ocsf/src/tracing_layers/mod.rs +++ b/crates/openshell-ocsf/src/tracing_layers/mod.rs @@ -11,6 +11,6 @@ pub(crate) mod event_bridge; mod jsonl_layer; mod shorthand_layer; -pub use event_bridge::emit_ocsf_event; +pub use event_bridge::{OCSF_TARGET, clone_current_event, emit_ocsf_event}; pub use jsonl_layer::OcsfJsonlLayer; pub use shorthand_layer::OcsfShorthandLayer; diff --git a/crates/openshell-ocsf/src/tracing_layers/shorthand_layer.rs b/crates/openshell-ocsf/src/tracing_layers/shorthand_layer.rs index f8a39f6aa..ea75cf0dc 100644 --- a/crates/openshell-ocsf/src/tracing_layers/shorthand_layer.rs +++ b/crates/openshell-ocsf/src/tracing_layers/shorthand_layer.rs @@ -6,6 +6,7 @@ use std::io::Write; use std::sync::Mutex; +use chrono::Utc; use tracing::Subscriber; use tracing_subscriber::Layer; use tracing_subscriber::layer::Context; @@ -16,6 +17,10 @@ use crate::tracing_layers::event_bridge::{OCSF_TARGET, clone_current_event}; /// /// Events with `target: "ocsf"` are formatted via `format_shorthand()`. /// Non-OCSF events are formatted with a simple fallback. +/// +/// Each line is prefixed with a UTC timestamp (`YYYY-MM-DDTHH:MM:SS.mmmZ`) +/// since this layer writes directly to a file with no outer display layer +/// to supply timestamps. pub struct OcsfShorthandLayer<W> { writer: Mutex<W>, /// Whether to include non-OCSF events in the output. 
@@ -48,12 +53,14 @@ where
     fn on_event(&self, event: &tracing::Event<'_>, _ctx: Context<'_, S>) {
         let meta = event.metadata();
 
+        let ts = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ");
+
         if meta.target() == OCSF_TARGET {
             // This is an OCSF event — clone from thread-local (non-consuming)
             if let Some(ocsf_event) = clone_current_event() {
                 let line = ocsf_event.format_shorthand();
                 if let Ok(mut w) = self.writer.lock() {
-                    let _ = writeln!(w, "{line}");
+                    let _ = writeln!(w, "{ts} OCSF {line}");
                 }
             }
         } else if self.include_non_ocsf {
@@ -64,7 +71,7 @@ where
             let mut message = String::new();
             event.record(&mut MessageVisitor(&mut message));
             if let Ok(mut w) = self.writer.lock() {
-                let _ = writeln!(w, "{level} {target}: {message}");
+                let _ = writeln!(w, "{ts} {level} {target}: {message}");
             }
         }
     }
@@ -103,4 +110,46 @@ mod tests {
         let layer = OcsfShorthandLayer::new(buffer).with_non_ocsf(false);
         assert!(!layer.include_non_ocsf);
     }
+
+    #[test]
+    fn test_non_ocsf_fallback_includes_timestamp() {
+        use std::sync::Arc;
+        use tracing_subscriber::layer::SubscriberExt;
+        use tracing_subscriber::util::SubscriberInitExt;
+
+        let buffer = Arc::new(Mutex::new(Vec::<u8>::new()));
+        let writer = SyncWriter(buffer.clone());
+        let layer = OcsfShorthandLayer::new(writer).with_non_ocsf(true);
+
+        let subscriber = tracing_subscriber::registry().with(layer);
+        let _guard = subscriber.set_default();
+
+        tracing::info!("test message");
+
+        let output = buffer.lock().unwrap();
+        let line = String::from_utf8_lossy(&output);
+        // Should start with a timestamp like 2026-04-01T...
+        assert!(
+            line.contains('T') && line.contains('Z'),
+            "Expected timestamp in output, got: {line}"
+        );
+        assert!(
+            line.contains("test message"),
+            "Expected message, got: {line}"
+        );
+    }
+}
+
+/// Test helper: wraps `Arc<Mutex<Vec<u8>>>` so it implements `Write + Send`.
+#[cfg(test)]
+struct SyncWriter(std::sync::Arc<Mutex<Vec<u8>>>);
+
+#[cfg(test)]
+impl Write for SyncWriter {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        self.0.lock().unwrap().write(buf)
+    }
+    fn flush(&mut self) -> std::io::Result<()> {
+        self.0.lock().unwrap().flush()
+    }
 }
diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml
index 311bb4e86..f26136c6b 100644
--- a/crates/openshell-policy/Cargo.toml
+++ b/crates/openshell-policy/Cargo.toml
@@ -13,7 +13,7 @@ repository.workspace = true
 [dependencies]
 openshell-core = { path = "../openshell-core" }
 serde = { workspace = true }
-serde_yaml = { workspace = true }
+serde_yml = { workspace = true }
 miette = { workspace = true }
 
 [lints]
diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs
index f1c15539e..9cf543bdf 100644
--- a/crates/openshell-policy/src/lib.rs
+++ b/crates/openshell-policy/src/lib.rs
@@ -15,8 +15,8 @@ use std::path::Path;
 use miette::{IntoDiagnostic, Result, WrapErr};
 use openshell_core::proto::{
-    FilesystemPolicy, L7Allow, L7Rule, LandlockPolicy, NetworkBinary, NetworkEndpoint,
-    NetworkPolicyRule, ProcessPolicy, SandboxPolicy,
+    FilesystemPolicy, L7Allow, L7QueryMatcher, L7Rule, LandlockPolicy, NetworkBinary,
+    NetworkEndpoint, NetworkPolicyRule, ProcessPolicy, SandboxPolicy,
 };
 use serde::{Deserialize, Serialize};
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    ports: Vec<u32>,
+    ports: Vec<u16>,
     #[serde(default, skip_serializing_if = "String::is_empty")]
     protocol: String,
     #[serde(default, skip_serializing_if = "String::is_empty")]
@@ -101,7 +102,7 @@ struct NetworkEndpointDef {
     allowed_ips: Vec<String>,
 }
 
-fn is_zero(v: &u32) -> bool {
+fn is_zero(v: &u16) -> bool {
     *v == 0
 }
 
@@ -120,6 +121,22 @@ struct L7AllowDef {
     path: String,
     #[serde(default, skip_serializing_if = "String::is_empty")]
     command: String,
+    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
+    query: BTreeMap<String, QueryMatcherDef>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+enum QueryMatcherDef {
+    Glob(String),
+    Any(QueryAnyDef),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct QueryAnyDef {
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    any: Vec<String>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -153,10 +170,10 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy {
                 .map(|e| {
                     // Normalize port/ports: ports takes precedence, else
                     // single port is promoted to ports array.
-                    let normalized_ports = if !e.ports.is_empty() {
-                        e.ports
+                    let normalized_ports: Vec<u32> = if !e.ports.is_empty() {
+                        e.ports.into_iter().map(u32::from).collect()
                     } else if e.port > 0 {
-                        vec![e.port]
+                        vec![u32::from(e.port)]
                     } else {
                         vec![]
                     };
@@ -176,6 +193,23 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy {
                                 method: r.allow.method,
                                 path: r.allow.path,
                                 command: r.allow.command,
+                                query: r
+                                    .allow
+                                    .query
+                                    .into_iter()
+                                    .map(|(key, matcher)| {
+                                        let proto = match matcher {
+                                            QueryMatcherDef::Glob(glob) => {
+                                                L7QueryMatcher { glob, any: vec![] }
+                                            }
+                                            QueryMatcherDef::Any(any) => L7QueryMatcher {
+                                                glob: String::new(),
+                                                any: any.any,
+                                            },
+                                        };
+                                        (key, proto)
+                                    })
+                                    .collect(),
                             }),
                         })
                         .collect(),
@@ -252,10 +286,12 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile {
                 .map(|e| {
                     // Use compact form: if ports has exactly 1 element,
                     // emit port (scalar). If >1, emit ports (array).
+ // Proto uses u32; YAML uses u16. Clamp at boundary. + let clamp = |v: u32| -> u16 { v.min(65535) as u16 }; let (port, ports) = if e.ports.len() > 1 { - (0, e.ports.clone()) + (0, e.ports.iter().map(|&p| clamp(p)).collect()) } else { - (e.ports.first().copied().unwrap_or(e.port), vec![]) + (clamp(e.ports.first().copied().unwrap_or(e.port)), vec![]) }; NetworkEndpointDef { host: e.host.clone(), @@ -275,6 +311,20 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { method: a.method, path: a.path, command: a.command, + query: a + .query + .into_iter() + .map(|(key, matcher)| { + let yaml_matcher = if !matcher.any.is_empty() { + QueryMatcherDef::Any(QueryAnyDef { + any: matcher.any, + }) + } else { + QueryMatcherDef::Glob(matcher.glob) + }; + (key, yaml_matcher) + }) + .collect(), }, } }) @@ -311,7 +361,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { /// Parse a sandbox policy from a YAML string. pub fn parse_sandbox_policy(yaml: &str) -> Result { - let raw: PolicyFile = serde_yaml::from_str(yaml) + let raw: PolicyFile = serde_yml::from_str(yaml) .into_diagnostic() .wrap_err("failed to parse sandbox policy YAML")?; Ok(to_proto(raw)) @@ -324,7 +374,7 @@ pub fn parse_sandbox_policy(yaml: &str) -> Result { /// and is round-trippable through `parse_sandbox_policy`. 
pub fn serialize_sandbox_policy(policy: &SandboxPolicy) -> Result<String> {
     let yaml_repr = from_proto(policy);
-    serde_yaml::to_string(&yaml_repr)
+    serde_yml::to_string(&yaml_repr)
         .into_diagnostic()
         .wrap_err("failed to serialize policy to YAML")
 }
@@ -754,6 +804,49 @@ network_policies:
         assert_eq!(rule.binaries[0].path, "/usr/bin/curl");
     }
 
+    #[test]
+    fn parse_l7_query_matchers_and_round_trip() {
+        let yaml = r#"
+version: 1
+network_policies:
+  query_test:
+    name: query_test
+    endpoints:
+      - host: api.example.com
+        port: 8080
+        protocol: rest
+        rules:
+          - allow:
+              method: GET
+              path: /download
+              query:
+                slug: "my-*"
+                tag:
+                  any: ["foo-*", "bar-*"]
+        binaries:
+          - path: /usr/bin/curl
+"#;
+        let proto = parse_sandbox_policy(yaml).expect("parse failed");
+        let allow = proto.network_policies["query_test"].endpoints[0].rules[0]
+            .allow
+            .as_ref()
+            .expect("allow");
+        assert_eq!(allow.query["slug"].glob, "my-*");
+        assert_eq!(allow.query["slug"].any, Vec::<String>::new());
+        assert_eq!(allow.query["tag"].any, vec!["foo-*", "bar-*"]);
+        assert!(allow.query["tag"].glob.is_empty());
+
+        let yaml_out = serialize_sandbox_policy(&proto).expect("serialize failed");
+        let proto_round_trip = parse_sandbox_policy(&yaml_out).expect("re-parse failed");
+        let allow_round_trip = proto_round_trip.network_policies["query_test"].endpoints[0].rules
+            [0]
+            .allow
+            .as_ref()
+            .expect("allow");
+        assert_eq!(allow_round_trip.query["slug"].glob, "my-*");
+        assert_eq!(allow_round_trip.query["tag"].any, vec!["foo-*", "bar-*"]);
+    }
+
     #[test]
     fn parse_rejects_unknown_fields() {
         let yaml = "version: 1\nbogus_field: true\n";
@@ -1117,4 +1210,20 @@ network_policies:
             proto2.network_policies["test"].endpoints[0].host
         );
     }
+
+    #[test]
+    fn rejects_port_above_65535() {
+        let yaml = r#"
+version: 1
+network_policies:
+  test:
+    endpoints:
+      - host: example.com
+        port: 70000
+"#;
+        assert!(
+            parse_sandbox_policy(yaml).is_err(),
+            "port >65535 should fail to parse"
+        );
+    }
 }
diff --git
a/crates/openshell-router/Cargo.toml b/crates/openshell-router/Cargo.toml index dc8e9c924..e4c3d5ea7 100644 --- a/crates/openshell-router/Cargo.toml +++ b/crates/openshell-router/Cargo.toml @@ -19,7 +19,7 @@ serde_json = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } tokio = { workspace = true } -serde_yaml = { workspace = true } +serde_yml = { workspace = true } uuid = { workspace = true } [dev-dependencies] diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index 0708cfb03..d1d7092c0 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -31,6 +31,11 @@ struct ValidationProbe { path: &'static str, protocol: &'static str, body: bytes::Bytes, + /// Alternate body to try when the primary probe fails with HTTP 400. + /// Used for OpenAI chat completions where newer models require + /// `max_completion_tokens` while legacy/self-hosted backends only + /// accept `max_tokens`. + fallback_body: Option, } /// Response from a proxied HTTP request to a backend (fully buffered). 
@@ -144,7 +149,7 @@ async fn send_backend_request( } Err(_) => body, }; - builder = builder.body(body); + builder = builder.body(body).timeout(route.timeout); builder.send().await.map_err(|e| { if e.is_timeout() { @@ -163,12 +168,17 @@ fn validation_probe(route: &ResolvedRoute) -> Result Result Result Result, + body: bytes::Bytes, +) -> Result { + let response = send_backend_request(client, route, "POST", path, headers, body) .await .map_err(|err| match err { RouterError::UpstreamUnavailable(details) => ValidationFailure { @@ -253,12 +306,12 @@ pub async fn verify_backend_endpoint( details, }, })?; - let url = build_backend_url(&route.endpoint, probe.path); + let url = build_backend_url(&route.endpoint, path); if response.status().is_success() { return Ok(ValidatedEndpoint { url, - protocol: probe.protocol.to_string(), + protocol: protocol.to_string(), }); } @@ -376,7 +429,7 @@ fn build_backend_url(endpoint: &str, path: &str) -> String { #[cfg(test)] mod tests { - use super::{build_backend_url, verify_backend_endpoint}; + use super::{ValidationFailureKind, build_backend_url, verify_backend_endpoint}; use crate::config::ResolvedRoute; use openshell_core::inference::AuthHeader; use wiremock::matchers::{body_partial_json, header, method, path}; @@ -415,6 +468,7 @@ mod tests { protocols: protocols.iter().map(|p| (*p).to_string()).collect(), auth, default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + timeout: crate::config::DEFAULT_ROUTE_TIMEOUT, } } @@ -463,4 +517,102 @@ mod tests { assert_eq!(validated.protocol, "openai_chat_completions"); assert_eq!(validated.url, "mock://test-backend/v1/chat/completions"); } + + /// GPT-5+ models reject `max_tokens` — the primary probe uses + /// `max_completion_tokens` so validation should succeed directly. 
+ #[tokio::test] + async fn verify_openai_chat_uses_max_completion_tokens() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions"], + AuthHeader::Bearer, + ); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(body_partial_json(serde_json::json!({ + "max_completion_tokens": 32, + }))) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-1"})), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::builder().build().unwrap(); + let validated = verify_backend_endpoint(&client, &route).await.unwrap(); + + assert_eq!(validated.protocol, "openai_chat_completions"); + } + + /// Legacy/self-hosted backends that reject `max_completion_tokens` + /// should succeed on the fallback probe using `max_tokens`. + #[tokio::test] + async fn verify_openai_chat_falls_back_to_max_tokens() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions"], + AuthHeader::Bearer, + ); + + // Reject the primary probe (max_completion_tokens) with 400. + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(body_partial_json(serde_json::json!({ + "max_completion_tokens": 32, + }))) + .respond_with(ResponseTemplate::new(400).set_body_string( + r#"{"error":{"message":"Unsupported parameter: 'max_completion_tokens'"}}"#, + )) + .expect(1) + .mount(&mock_server) + .await; + + // Accept the fallback probe (max_tokens). 
+ Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(body_partial_json(serde_json::json!({ + "max_tokens": 32, + }))) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "chatcmpl-2"})), + ) + .expect(1) + .mount(&mock_server) + .await; + + let client = reqwest::Client::builder().build().unwrap(); + let validated = verify_backend_endpoint(&client, &route).await.unwrap(); + + assert_eq!(validated.protocol, "openai_chat_completions"); + } + + /// Non-chat-completions probes (e.g. anthropic_messages) should not + /// have a fallback — a 400 remains a hard failure. + #[tokio::test] + async fn verify_non_chat_completions_no_fallback() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["anthropic_messages"], + AuthHeader::Custom("x-api-key"), + ); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .respond_with(ResponseTemplate::new(400).set_body_string("bad request")) + .mount(&mock_server) + .await; + + let client = reqwest::Client::builder().build().unwrap(); + let result = verify_backend_endpoint(&client, &route).await; + + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().kind, + ValidationFailureKind::RequestShape + ); + } } diff --git a/crates/openshell-router/src/config.rs b/crates/openshell-router/src/config.rs index d9c081d60..b531e091d 100644 --- a/crates/openshell-router/src/config.rs +++ b/crates/openshell-router/src/config.rs @@ -3,11 +3,14 @@ use serde::Deserialize; use std::path::Path; +use std::time::Duration; pub use openshell_core::inference::AuthHeader; use crate::RouterError; +pub const DEFAULT_ROUTE_TIMEOUT: Duration = Duration::from_secs(60); + #[derive(Debug, Clone, Deserialize)] pub struct RouterConfig { pub routes: Vec, @@ -45,6 +48,8 @@ pub struct ResolvedRoute { pub auth: AuthHeader, /// Extra headers injected on every request (e.g. `anthropic-version`). 
pub default_headers: Vec<(String, String)>, + /// Per-request timeout for proxied inference calls. + pub timeout: Duration, } impl std::fmt::Debug for ResolvedRoute { @@ -57,6 +62,7 @@ impl std::fmt::Debug for ResolvedRoute { .field("protocols", &self.protocols) .field("auth", &self.auth) .field("default_headers", &self.default_headers) + .field("timeout", &self.timeout) .finish() } } @@ -69,7 +75,7 @@ impl RouterConfig { path.display() )) })?; - let config: Self = serde_yaml::from_str(&content).map_err(|e| { + let config: Self = serde_yml::from_str(&content).map_err(|e| { RouterError::Internal(format!( "failed to parse router config {}: {e}", path.display() @@ -129,6 +135,7 @@ impl RouteConfig { protocols, auth, default_headers, + timeout: DEFAULT_ROUTE_TIMEOUT, }) } } @@ -256,6 +263,7 @@ routes: protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + timeout: DEFAULT_ROUTE_TIMEOUT, }; let debug_output = format!("{route:?}"); assert!( diff --git a/crates/openshell-router/src/lib.rs b/crates/openshell-router/src/lib.rs index a5712d9a0..7deed6fc4 100644 --- a/crates/openshell-router/src/lib.rs +++ b/crates/openshell-router/src/lib.rs @@ -5,8 +5,6 @@ mod backend; pub mod config; mod mock; -use std::time::Duration; - pub use backend::{ ProxyResponse, StreamingProxyResponse, ValidatedEndpoint, ValidationFailure, ValidationFailureKind, verify_backend_endpoint, @@ -39,7 +37,6 @@ pub struct Router { impl Router { pub fn new() -> Result { let client = reqwest::Client::builder() - .timeout(Duration::from_secs(60)) .build() .map_err(|e| RouterError::Internal(format!("failed to build HTTP client: {e}")))?; Ok(Self { diff --git a/crates/openshell-router/src/mock.rs b/crates/openshell-router/src/mock.rs index 9b6accb60..a17ce486f 100644 --- a/crates/openshell-router/src/mock.rs +++ b/crates/openshell-router/src/mock.rs @@ -131,6 +131,7 @@ mod tests { protocols: protocols.iter().map(ToString::to_string).collect(), auth: 
crate::config::AuthHeader::Bearer, default_headers: Vec::new(), + timeout: crate::config::DEFAULT_ROUTE_TIMEOUT, } } diff --git a/crates/openshell-router/tests/backend_integration.rs b/crates/openshell-router/tests/backend_integration.rs index 4861bd6d0..571964aa8 100644 --- a/crates/openshell-router/tests/backend_integration.rs +++ b/crates/openshell-router/tests/backend_integration.rs @@ -15,6 +15,7 @@ fn mock_candidates(base_url: &str) -> Vec { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }] } @@ -117,6 +118,7 @@ async fn proxy_no_compatible_route_returns_error() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: Vec::new(), + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; let err = router @@ -178,6 +180,7 @@ async fn proxy_mock_route_returns_canned_response() { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; let body = serde_json::to_vec(&serde_json::json!({ @@ -312,6 +315,7 @@ async fn proxy_uses_x_api_key_for_anthropic_route() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; let body = serde_json::to_vec(&serde_json::json!({ @@ -370,6 +374,7 @@ async fn proxy_anthropic_does_not_send_bearer_auth() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; let response = router @@ -414,6 +419,7 @@ async fn proxy_forwards_client_anthropic_version_header() { protocols: 
vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; let body = serde_json::to_vec(&serde_json::json!({ diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 8a0639a7d..541784ee6 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -16,6 +16,7 @@ path = "src/main.rs" [dependencies] openshell-core = { path = "../openshell-core" } +openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } openshell-router = { path = "../openshell-router" } @@ -52,12 +53,15 @@ webpki-roots = { workspace = true } # HTTP bytes = { workspace = true } +# Encoding +base64 = { workspace = true } + # IP network / CIDR parsing ipnet = "2" # Serialization serde_json = { workspace = true } -serde_yaml = { workspace = true } +serde_yml = { workspace = true } # Logging tracing = { workspace = true } @@ -78,6 +82,8 @@ uuid = { version = "1", features = ["v4"] } [dev-dependencies] tempfile = "3" temp-env = "0.3" +tokio-tungstenite = { workspace = true } +futures = { workspace = true } [lints] workspace = true diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego index 1544dfe55..0a7a33888 100644 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -208,6 +208,7 @@ request_allowed_for_endpoint(request, endpoint) if { rule.allow.method method_matches(request.method, rule.allow.method) path_matches(request.path, rule.allow.path) + query_params_match(request, rule) } # --- L7 rule matching: SQL command --- @@ -235,6 +236,55 @@ path_matches(actual, pattern) if { glob.match(pattern, ["/"], actual) } +# Query matching: +# - If no query rules are configured, allow any query params. 
+# - For configured keys, all request values for that key must match. +# - Matcher shape supports either `glob` or `any`. +query_params_match(request, rule) if { + query_rules := object.get(rule.allow, "query", {}) + not query_mismatch(request, query_rules) +} + +query_mismatch(request, query_rules) if { + some key + matcher := query_rules[key] + not query_key_matches(request, key, matcher) +} + +query_key_matches(request, key, matcher) if { + request_query := object.get(request, "query_params", {}) + values := object.get(request_query, key, null) + values != null + count(values) > 0 + not query_value_mismatch(values, matcher) +} + +query_value_mismatch(values, matcher) if { + some i + value := values[i] + not query_value_matches(value, matcher) +} + +query_value_matches(value, matcher) if { + is_string(matcher) + glob.match(matcher, [], value) +} + +query_value_matches(value, matcher) if { + is_object(matcher) + glob_pattern := object.get(matcher, "glob", "") + glob_pattern != "" + glob.match(glob_pattern, [], value) +} + +query_value_matches(value, matcher) if { + is_object(matcher) + any_patterns := object.get(matcher, "any", []) + count(any_patterns) > 0 + some i + glob.match(any_patterns[i], [], value) +} + # SQL command matching: "*" matches any; otherwise case-insensitive. command_matches(_, "*") if true diff --git a/crates/openshell-sandbox/src/bypass_monitor.rs b/crates/openshell-sandbox/src/bypass_monitor.rs index f99d74934..d0e49c42d 100644 --- a/crates/openshell-sandbox/src/bypass_monitor.rs +++ b/crates/openshell-sandbox/src/bypass_monitor.rs @@ -17,10 +17,14 @@ //! still provide fast-fail UX — the monitor only adds diagnostic visibility. 
use crate::denial_aggregator::DenialEvent; +use openshell_ocsf::{ + ActionId, ActivityId, ConfidenceId, DetectionFindingBuilder, DispositionId, Endpoint, + FindingInfo, NetworkActivityBuilder, Process, SeverityId, ocsf_emit, +}; use std::sync::Arc; use std::sync::atomic::{AtomicU32, Ordering}; use tokio::sync::mpsc; -use tracing::{debug, warn}; +use tracing::debug; /// A parsed iptables LOG entry from `/dev/kmsg`. #[derive(Debug, Clone, PartialEq, Eq)] @@ -126,10 +130,15 @@ pub fn spawn( .status(); if !dmesg_check.is_ok_and(|s| s.success()) { - warn!( - "dmesg not available; bypass detection monitor will not run. \ - Bypass REJECT rules still provide fast-fail behavior." - ); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Low) + .message( + "dmesg not available; bypass detection monitor will not run. \ + Bypass REJECT rules still provide fast-fail behavior.", + ) + .build(); + ocsf_emit!(event); return None; } @@ -149,7 +158,14 @@ pub fn spawn( { Ok(c) => c, Err(e) => { - warn!(error = %e, "Failed to start dmesg --follow; bypass monitor will not run"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Low) + .message(format!( + "Failed to start dmesg --follow; bypass monitor will not run: {e}" + )) + .build(); + ocsf_emit!(event); return; } }; @@ -157,7 +173,12 @@ pub fn spawn( let stdout = match child.stdout.take() { Some(s) => s, None => { - warn!("dmesg --follow produced no stdout; bypass monitor will not run"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Low) + .message("dmesg --follow produced no stdout; bypass monitor will not run") + .build(); + ocsf_emit!(event); return; } }; @@ -186,19 +207,59 @@ pub fn spawn( }; let hint = hint_for_event(&event); + let reason = "direct connection bypassed HTTP CONNECT proxy"; - warn!( - dst_addr = %event.dst_addr, - 
dst_port = event.dst_port, - proto = %event.proto, - binary = %binary, - binary_pid = %binary_pid, - ancestors = %ancestors, - action = "reject", - reason = "direct connection bypassed HTTP CONNECT proxy", - hint = hint, - "BYPASS_DETECT", - ); + // Dual-emit: Network Activity [4001] + Detection Finding [2004] + { + let dst_ep = if let Ok(ip) = event.dst_addr.parse::() { + Endpoint::from_ip(ip, event.dst_port) + } else { + Endpoint::from_domain(&event.dst_addr, event.dst_port) + }; + + let net_event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .dst_endpoint(dst_ep.clone()) + .actor_process(Process::from_bypass(&binary, &binary_pid, &ancestors)) + .firewall_rule("bypass-detect", "iptables") + .observation_point(3) + .message(format!( + "BYPASS_DETECT {}:{} proto={} binary={binary} action=reject reason={reason}", + event.dst_addr, event.dst_port, event.proto, + )) + .build(); + ocsf_emit!(net_event); + + let finding_event = DetectionFindingBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .is_alert(true) + .confidence(ConfidenceId::High) + .finding_info( + FindingInfo::new("bypass-detect", "Proxy Bypass Detected") + .with_desc(reason), + ) + .remediation(hint) + .evidence_pairs(&[ + ("dst_addr", &event.dst_addr), + ("dst_port", &event.dst_port.to_string()), + ("proto", &event.proto), + ("binary", &binary), + ("binary_pid", &binary_pid), + ("ancestors", &ancestors), + ]) + .message(format!( + "BYPASS_DETECT {}:{} proto={} binary={binary} hint={hint}", + event.dst_addr, event.dst_port, event.proto, + )) + .build(); + ocsf_emit!(finding_event); + } // Send to denial aggregator if available. 
if let Some(ref tx) = denial_tx { diff --git a/crates/openshell-sandbox/src/identity.rs b/crates/openshell-sandbox/src/identity.rs index d27976ba7..49809f95b 100644 --- a/crates/openshell-sandbox/src/identity.rs +++ b/crates/openshell-sandbox/src/identity.rs @@ -16,6 +16,7 @@ use std::fs::Metadata; use std::os::unix::fs::MetadataExt; use std::path::{Path, PathBuf}; use std::sync::Mutex; +use tracing::debug; #[derive(Clone)] struct FileFingerprint { @@ -100,6 +101,7 @@ impl BinaryIdentityCache { where F: FnMut(&Path) -> Result, { + let start = std::time::Instant::now(); let metadata = std::fs::metadata(path) .map_err(|error| miette::miette!("Failed to stat {}: {error}", path.display()))?; let fingerprint = FileFingerprint::from_metadata(&metadata); @@ -114,9 +116,20 @@ impl BinaryIdentityCache { if let Some(cached_binary) = &cached && cached_binary.fingerprint == fingerprint { + debug!( + " verify_or_cache: {}ms CACHE HIT path={}", + start.elapsed().as_millis(), + path.display() + ); return Ok(cached_binary.hash.clone()); } + debug!( + " verify_or_cache: CACHE MISS size={} path={}", + metadata.len(), + path.display() + ); + let current_hash = hash_file(path)?; let mut hashes = self @@ -143,6 +156,12 @@ impl BinaryIdentityCache { }, ); + debug!( + " verify_or_cache TOTAL (cold): {}ms path={}", + start.elapsed().as_millis(), + path.display() + ); + Ok(current_hash) } } diff --git a/crates/openshell-sandbox/src/l7/inference.rs b/crates/openshell-sandbox/src/l7/inference.rs index 59dafdaba..5136c8783 100644 --- a/crates/openshell-sandbox/src/l7/inference.rs +++ b/crates/openshell-sandbox/src/l7/inference.rs @@ -96,6 +96,8 @@ pub enum ParseResult { Complete(ParsedHttpRequest, usize), /// Headers are incomplete — caller should read more data. Incomplete, + /// The request is malformed and must be rejected (e.g., duplicate Content-Length). + Invalid(String), } /// Try to parse an HTTP/1.1 request from raw bytes. 
@@ -125,6 +127,7 @@ pub fn try_parse_http_request(buf: &[u8]) -> ParseResult { let mut headers = Vec::new(); let mut content_length: usize = 0; + let mut has_content_length = false; let mut is_chunked = false; for line in lines { if line.is_empty() { @@ -134,7 +137,21 @@ pub fn try_parse_http_request(buf: &[u8]) -> ParseResult { let name = name.trim().to_string(); let value = value.trim().to_string(); if name.eq_ignore_ascii_case("content-length") { - content_length = value.parse().unwrap_or(0); + let new_len: usize = match value.parse() { + Ok(v) => v, + Err(_) => { + return ParseResult::Invalid(format!( + "invalid Content-Length value: {value}" + )); + } + }; + if has_content_length && new_len != content_length { + return ParseResult::Invalid(format!( + "duplicate Content-Length headers with differing values ({content_length} vs {new_len})" + )); + } + content_length = new_len; + has_content_length = true; } if name.eq_ignore_ascii_case("transfer-encoding") && value @@ -147,6 +164,12 @@ pub fn try_parse_http_request(buf: &[u8]) -> ParseResult { } } + if is_chunked && has_content_length { + return ParseResult::Invalid( + "Request contains both Transfer-Encoding and Content-Length headers".to_string(), + ); + } + let (body, consumed) = if is_chunked { let Some((decoded_body, consumed)) = parse_chunked_body(buf, body_start) else { return ParseResult::Incomplete; @@ -552,4 +575,94 @@ mod tests { }; assert_eq!(parsed.body.len(), 100); } + + /// SEC: Transfer-Encoding substring match must not match partial tokens. 
+ #[test] + fn te_substring_not_chunked() { + let body = r#"{"model":"m","messages":[]}"#; + let request = format!( + "POST /v1/chat/completions HTTP/1.1\r\n\ + Host: x\r\n\ + Transfer-Encoding: chunkedx\r\n\ + Content-Length: {}\r\n\ + \r\n{body}", + body.len(), + ); + let ParseResult::Complete(parsed, _) = try_parse_http_request(request.as_bytes()) else { + panic!("expected Complete for non-matching TE with valid CL"); + }; + assert_eq!(parsed.body.len(), body.len()); + } + + // ---- SEC: Content-Length validation ---- + + #[test] + fn reject_differing_duplicate_content_length() { + let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n"; + assert!(matches!( + try_parse_http_request(request), + ParseResult::Invalid(reason) if reason.contains("differing values") + )); + } + + #[test] + fn accept_identical_duplicate_content_length() { + let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: 5\r\nContent-Length: 5\r\n\r\nhello"; + let ParseResult::Complete(parsed, _) = try_parse_http_request(request) else { + panic!("expected Complete for identical duplicate CL"); + }; + assert_eq!(parsed.body, b"hello"); + } + + #[test] + fn reject_non_numeric_content_length() { + let request = + b"POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: abc\r\n\r\n"; + assert!(matches!( + try_parse_http_request(request), + ParseResult::Invalid(reason) if reason.contains("invalid Content-Length") + )); + } + + #[test] + fn reject_two_non_numeric_content_lengths() { + let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: abc\r\nContent-Length: def\r\n\r\n"; + assert!(matches!( + try_parse_http_request(request), + ParseResult::Invalid(_) + )); + } + + // ---- SEC-009: CL/TE desynchronisation ---- + + /// Reject requests with both Content-Length and Transfer-Encoding to + /// prevent CL/TE request smuggling (RFC 7230 Section 3.3.3). 
+ #[test] + fn reject_dual_content_length_and_transfer_encoding() { + let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nContent-Length: 5\r\nTransfer-Encoding: chunked\r\n\r\n"; + assert!( + matches!( + try_parse_http_request(request), + ParseResult::Invalid(reason) + if reason.contains("Transfer-Encoding") + && reason.contains("Content-Length") + ), + "Must reject request with both CL and TE" + ); + } + + /// Same rejection regardless of header order. + #[test] + fn reject_dual_transfer_encoding_and_content_length() { + let request = b"POST /v1/chat/completions HTTP/1.1\r\nHost: x\r\nTransfer-Encoding: chunked\r\nContent-Length: 5\r\n\r\n"; + assert!( + matches!( + try_parse_http_request(request), + ParseResult::Invalid(reason) + if reason.contains("Transfer-Encoding") + && reason.contains("Content-Length") + ), + "Must reject request with both TE and CL" + ); + } } diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 09e547885..ca76dc47a 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -76,6 +76,8 @@ pub struct L7RequestInfo { pub action: String, /// Target: URL path for REST, or empty for SQL. pub target: String, + /// Decoded query parameter multimap for REST requests. + pub query_params: std::collections::HashMap>, } /// Parse an L7 endpoint config from a regorus Value (returned by Rego query). @@ -89,17 +91,27 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { let tls = match get_object_str(val, "tls").as_deref() { Some("skip") => TlsMode::Skip, Some("terminate") => { - tracing::warn!( - "'tls: terminate' is deprecated; TLS termination is now automatic. \ - Use 'tls: skip' to explicitly disable. This field will be removed in a future version." 
- ); + let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Other) + .severity(openshell_ocsf::SeverityId::Medium) + .message( + "'tls: terminate' is deprecated; TLS termination is now automatic. \ + Use 'tls: skip' to explicitly disable. This field will be removed in a future version.", + ) + .build(); + openshell_ocsf::ocsf_emit!(event); TlsMode::Auto } Some("passthrough") => { - tracing::warn!( - "'tls: passthrough' is deprecated; TLS termination is now automatic. \ - Use 'tls: skip' to explicitly disable. This field will be removed in a future version." - ); + let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Other) + .severity(openshell_ocsf::SeverityId::Medium) + .message( + "'tls: passthrough' is deprecated; TLS termination is now automatic. \ + Use 'tls: skip' to explicitly disable. This field will be removed in a future version.", + ) + .build(); + openshell_ocsf::ocsf_emit!(event); TlsMode::Auto } _ => TlsMode::Auto, @@ -144,6 +156,49 @@ fn get_object_str(val: ®orus::Value, key: &str) -> Option { } } +/// Check a glob pattern for obvious syntax issues. +/// +/// Returns `Some(warning_message)` if the pattern looks malformed. +/// OPA's `glob.match` is forgiving, so these are warnings (not errors) +/// to surface likely typos without blocking policy loading. 
+fn check_glob_syntax(pattern: &str) -> Option { + let mut bracket_depth: i32 = 0; + for c in pattern.chars() { + match c { + '[' => bracket_depth += 1, + ']' => { + if bracket_depth == 0 { + return Some(format!("glob pattern '{pattern}' has unmatched ']'")); + } + bracket_depth -= 1; + } + _ => {} + } + } + if bracket_depth > 0 { + return Some(format!("glob pattern '{pattern}' has unclosed '['")); + } + + let mut brace_depth: i32 = 0; + for c in pattern.chars() { + match c { + '{' => brace_depth += 1, + '}' => { + if brace_depth == 0 { + return Some(format!("glob pattern '{pattern}' has unmatched '}}'")); + } + brace_depth -= 1; + } + _ => {} + } + } + if brace_depth > 0 { + return Some(format!("glob pattern '{pattern}' has unclosed '{{'")); + } + + None +} + /// Validate L7 policy configuration in the loaded OPA data. /// /// Returns a list of errors and warnings. Errors should prevent sandbox startup; @@ -279,7 +334,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", ]; if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { - for rule in rules { + for (rule_idx, rule) in rules.iter().enumerate() { if let Some(method) = rule .get("allow") .and_then(|a| a.get("method")) @@ -291,6 +346,110 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< "{loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." 
)); } + + let Some(query) = rule + .get("allow") + .and_then(|a| a.get("query")) + .filter(|v| !v.is_null()) + else { + continue; + }; + + let Some(query_obj) = query.as_object() else { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query: expected map of query matchers" + )); + continue; + }; + + for (param, matcher) in query_obj { + if let Some(glob_str) = matcher.as_str() { + if let Some(warning) = check_glob_syntax(glob_str) { + warnings.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: {warning}" + )); + } + continue; + } + + let Some(matcher_obj) = matcher.as_object() else { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: expected string glob or object with `any`" + )); + continue; + }; + + let has_any = matcher_obj.get("any").is_some(); + let has_glob = matcher_obj.get("glob").is_some(); + let has_unknown = matcher_obj.keys().any(|k| k != "any" && k != "glob"); + if has_unknown { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: unknown matcher keys; only `glob` or `any` are supported" + )); + continue; + } + + if has_glob && has_any { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: matcher cannot specify both `glob` and `any`" + )); + continue; + } + + if !has_glob && !has_any { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: object matcher requires `glob` string or non-empty `any` list" + )); + continue; + } + + if has_glob { + match matcher_obj.get("glob").and_then(|v| v.as_str()) { + None => { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.glob: expected glob string" + )); + } + Some(g) => { + if let Some(warning) = check_glob_syntax(g) { + warnings.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.glob: {warning}" + )); + } + } + } + continue; + } + + let any = matcher_obj.get("any").and_then(|v| v.as_array()); + let Some(any) = any else { + errors.push(format!( + 
"{loc}.rules[{rule_idx}].allow.query.{param}.any: expected array of glob strings" + )); + continue; + }; + + if any.is_empty() { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: list must not be empty" + )); + continue; + } + + if any.iter().any(|v| v.as_str().is_none()) { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: all values must be strings" + )); + } + + for item in any.iter().filter_map(|v| v.as_str()) { + if let Some(warning) = check_glob_syntax(item) { + warnings.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: {warning}" + )); + } + } + } } } } @@ -780,4 +939,204 @@ mod tests { "should have no tls warnings with auto-detect: {warnings:?}" ); } + + #[test] + fn validate_query_any_requires_non_empty_array() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": { "any": [] } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("allow.query.tag.any")), + "expected query any validation error, got: {errors:?}" + ); + } + + #[test] + fn validate_query_object_rejects_unknown_keys() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": { "mode": "foo-*" } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("unknown matcher keys")), + "expected unknown query matcher key error, got: {errors:?}" + ); + } + + #[test] + fn validate_query_glob_warns_on_unclosed_bracket() { + let data = serde_json::json!({ + 
"network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": "[unclosed" + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "malformed glob should warn, not error: {errors:?}" + ); + assert!( + warnings + .iter() + .any(|w| w.contains("unclosed '['") && w.contains("allow.query.tag")), + "expected glob syntax warning, got: {warnings:?}" + ); + } + + #[test] + fn validate_query_glob_warns_on_unclosed_brace() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "format": { "glob": "{json,xml" } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "malformed glob should warn, not error: {errors:?}" + ); + assert!( + warnings + .iter() + .any(|w| w.contains("unclosed '{'") && w.contains("allow.query.format.glob")), + "expected glob syntax warning, got: {warnings:?}" + ); + } + + #[test] + fn validate_query_any_warns_on_malformed_glob_item() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": { "any": ["valid-*", "[bad"] } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "malformed glob in any should warn, not error: {errors:?}" + ); + assert!( + warnings + .iter() + .any(|w| w.contains("unclosed '['") && w.contains("allow.query.tag.any")), + "expected glob syntax warning for any item, 
got: {warnings:?}" + ); + } + + #[test] + fn validate_query_string_and_any_matchers_are_accepted() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "slug": "my-*", + "tag": { "any": ["foo-*", "bar-*"] }, + "owner": { "glob": "org-*" } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "valid query matcher shapes should not error: {errors:?}" + ); + } } diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index a9bf8bf5f..7516aa85c 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -10,9 +10,26 @@ //! works for both plaintext TCP and TLS-terminated connections. use miette::Result; +use std::collections::HashMap; use std::future::Future; use tokio::io::{AsyncRead, AsyncWrite}; +/// Outcome of relaying a single HTTP request/response pair. +#[derive(Debug)] +pub enum RelayOutcome { + /// Connection is reusable for further HTTP requests (keep-alive). + Reusable, + /// Connection was consumed (e.g. read-until-EOF or `Connection: close`). + Consumed, + /// Server responded with 101 Switching Protocols. + /// The connection has been upgraded (e.g. to WebSocket) and must be + /// relayed as raw bidirectional TCP from this point forward. + /// Contains any overflow bytes read from upstream past the 101 response + /// headers that belong to the upgraded protocol. The 101 headers + /// themselves have already been forwarded to the client. + Upgraded { overflow: Vec }, +} + /// Body framing for HTTP requests/responses. #[derive(Debug, Clone, Copy)] pub enum BodyLength { @@ -31,6 +48,8 @@ pub struct L7Request { pub action: String, /// Target: URL path for REST, empty for SQL. 
pub target: String, + /// Decoded query parameter multimap for REST requests. + pub query_params: HashMap>, /// Raw request header bytes (request line + headers for HTTP, message for SQL). /// May include overflow body bytes read during header parsing. pub raw_header: Vec, @@ -54,14 +73,15 @@ pub trait L7Provider: Send + Sync { /// Forward an allowed request to upstream and relay the response back. /// - /// Returns `true` if the upstream connection is reusable (keep-alive), - /// `false` if it was consumed (e.g. read-until-EOF or `Connection: close`). + /// Returns a [`RelayOutcome`] indicating whether the connection is + /// reusable (keep-alive), consumed, or has been upgraded (101 Switching + /// Protocols) and must be relayed as raw bidirectional TCP. fn relay( &self, req: &L7Request, client: &mut C, upstream: &mut U, - ) -> impl Future> + Send + ) -> impl Future> + Send where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send; diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 618280475..110f777e9 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -7,12 +7,16 @@ //! Parses each request within the tunnel, evaluates it against OPA policy, //! and either forwards or denies the request. 
-use crate::l7::provider::L7Provider; +use crate::l7::provider::{L7Provider, RelayOutcome}; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; -use crate::secrets::SecretResolver; +use crate::secrets::{self, SecretResolver}; use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, + NetworkActivityBuilder, SeverityId, Url as OcsfUrl, ocsf_emit, +}; use std::sync::{Arc, Mutex}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, info, warn}; /// Context for L7 request policy evaluation. @@ -55,11 +59,15 @@ where L7Protocol::Rest => relay_rest(config, &engine, client, upstream, ctx).await, L7Protocol::Sql => { // SQL provider is Phase 3 — fall through to passthrough with warning - warn!( - host = %ctx.host, - port = ctx.port, - "SQL L7 provider not yet implemented, falling back to passthrough" - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Low) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .message("SQL L7 provider not yet implemented, falling back to passthrough") + .build(); + ocsf_emit!(event); + } tokio::io::copy_bidirectional(client, upstream) .await .into_diagnostic()?; @@ -68,6 +76,46 @@ where } } +/// Handle an upgraded connection (101 Switching Protocols). +/// +/// Forwards any overflow bytes from the upgrade response to the client, then +/// switches to raw bidirectional TCP copy for the upgraded protocol (WebSocket, +/// HTTP/2, etc.). L7 policy enforcement does not apply after the upgrade — +/// the initial HTTP request was already evaluated. 
+async fn handle_upgrade( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + ocsf_emit!( + NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .activity_name("Upgrade") + .severity(SeverityId::Informational) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!( + "101 Switching Protocols — raw bidirectional relay (L7 enforcement no longer active) \ + [host:{host} port:{port} overflow_bytes:{}]", + overflow.len() + )) + .build() + ); + if !overflow.is_empty() { + client.write_all(&overflow).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + } + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + Ok(()) +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -94,65 +142,146 @@ where "L7 connection closed" ); } else { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .message(format!("HTTP parse error in L7 relay: {e}")) + .build(); + ocsf_emit!(event); + } + return Ok(()); // Close connection on parse error + } + }; + + // Rewrite credential placeholders in the request target BEFORE OPA + // evaluation. OPA sees the redacted path; the resolved path goes only + // to the upstream write. 
+ let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { + match secrets::rewrite_target_for_eval(&req.target, resolver) { + Ok(result) => (result.resolved, result.redacted), + Err(e) => { warn!( host = %ctx.host, port = ctx.port, error = %e, - "HTTP parse error in L7 relay" + "credential resolution failed in request target, rejecting" ); + let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(()); } - return Ok(()); // Close connection on parse error } + } else { + (req.target.clone(), req.target.clone()) }; let request_info = L7RequestInfo { action: req.action.clone(), - target: req.target.clone(), + target: redacted_target.clone(), + query_params: req.query_params.clone(), }; - // Evaluate L7 policy via Rego + // Evaluate L7 policy via Rego (using redacted target) let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; - let decision_str = match (allowed, config.enforcement) { - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", + // Check if this is an upgrade request for logging purposes. 
+ let header_end = req + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(req.raw_header.len(), |p| p + 4); + let is_upgrade_request = { + let h = String::from_utf8_lossy(&req.raw_header[..header_end]); + h.lines() + .skip(1) + .any(|l| l.to_ascii_lowercase().starts_with("upgrade:")) + }; + + let decision_str = match (allowed, config.enforcement, is_upgrade_request) { + (true, _, true) => "allow_upgrade", + (true, _, false) => "allow", + (false, EnforcementMode::Audit, _) => "audit", + (false, EnforcementMode::Enforce, _) => "deny", }; - // Log every L7 decision - info!( - dst_host = %ctx.host, - dst_port = ctx.port, - policy = %ctx.policy_name, - l7_protocol = "rest", - l7_action = %request_info.action, - l7_target = %request_info.target, - l7_decision = decision_str, - l7_deny_reason = %reason, - "L7_REQUEST", - ); + // Log every L7 decision as an OCSF HTTP Activity event. + // Uses redacted_target (path only, no query params) to avoid logging secrets. + { + let (action_id, disposition_id, severity) = match decision_str { + "allow" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + &request_info.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "l7") + .message(format!( + "L7_REQUEST {decision_str} {} {}:{}{} reason={}", + request_info.action, ctx.host, ctx.port, redacted_target, reason, + )) + .build(); + ocsf_emit!(event); + } + + // 
Store the resolved target for the deny response redaction + let _ = &eval_target; if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let reusable = crate::l7::rest::relay_http_request_with_resolver( + let outcome = crate::l7::rest::relay_http_request_with_resolver( &req, client, upstream, ctx.secret_resolver.as_deref(), ) .await?; - if !reusable { - debug!( - host = %ctx.host, - port = ctx.port, - "Upstream connection not reusable, closing L7 relay" - ); - return Ok(()); + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { overflow } => { + return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + } } } else { - // Enforce mode: deny with 403 and close connection + // Enforce mode: deny with 403 and close connection (use redacted target) crate::l7::rest::RestProvider - .deny(&req, &ctx.policy_name, &reason, client) + .deny_with_redacted_target( + &req, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + ) .await?; return Ok(()); } @@ -180,7 +309,7 @@ fn is_benign_connection_error(err: &miette::Report) -> bool { /// Evaluate an L7 request against the OPA engine. /// /// Returns `(allowed, deny_reason)`. -fn evaluate_l7_request( +pub fn evaluate_l7_request( engine: &Mutex, ctx: &L7EvalContext, request: &L7RequestInfo, @@ -198,6 +327,7 @@ fn evaluate_l7_request( "request": { "method": request.action, "path": request.target, + "query_params": request.query_params.clone(), } }); @@ -263,29 +393,62 @@ where request_count += 1; - // Log for observability. + // Resolve and redact the target for logging. 
+ let redacted_target = if let Some(ref res) = ctx.secret_resolver { + match secrets::rewrite_target_for_eval(&req.target, res) { + Ok(result) => result.redacted, + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "credential resolution failed in request target, rejecting" + ); + let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(()); + } + } + } else { + req.target.clone() + }; + + // Log for observability via OCSF HTTP Activity event. + // Uses redacted_target (path only, no query params) to avoid logging secrets. let has_creds = resolver.is_some(); - info!( - host = %ctx.host, - port = ctx.port, - method = %req.action, - path = %req.target, - credentials_injected = has_creds, - request_num = request_count, - "HTTP_REQUEST", - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .http_request(HttpRequest::new( + &req.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .message(format!( + "HTTP_REQUEST {} {}:{}{} credentials_injected={has_creds} request_num={request_count}", + req.action, ctx.host, ctx.port, redacted_target, + )) + .build(); + ocsf_emit!(event); + } - // Forward request with credential rewriting. - let keep_alive = + // Forward request with credential rewriting and relay the response. + // relay_http_request_with_resolver handles both directions: it sends + // the request upstream and reads the response back to the client. + let outcome = crate::l7::rest::relay_http_request_with_resolver(&req, client, upstream, resolver) .await?; - // Relay response back to client. 
- let reusable = - crate::l7::rest::relay_response_to_client(upstream, client, &req.action).await?; - - if !keep_alive || !reusable { - break; + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => break, + RelayOutcome::Upgraded { overflow } => { + return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + } } } diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index ebb349578..6bbf7be4e 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -7,11 +7,12 @@ //! policy, and relays allowed requests to upstream. Handles Content-Length //! and chunked transfer encoding for body framing. -use crate::l7::provider::{BodyLength, L7Provider, L7Request}; +use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::secrets::rewrite_http_header_block; use miette::{IntoDiagnostic, Result, miette}; +use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tracing::debug; +use tracing::{debug, warn}; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers const RELAY_BUF_SIZE: usize = 8192; @@ -31,7 +32,12 @@ impl L7Provider for RestProvider { parse_http_request(client).await } - async fn relay(&self, req: &L7Request, client: &mut C, upstream: &mut U) -> Result + async fn relay( + &self, + req: &L7Request, + client: &mut C, + upstream: &mut U, + ) -> Result where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, @@ -46,7 +52,21 @@ impl L7Provider for RestProvider { reason: &str, client: &mut C, ) -> Result<()> { - send_deny_response(req, policy_name, reason, client).await + send_deny_response(req, policy_name, reason, client, None).await + } +} + +impl RestProvider { + /// Deny with a redacted target for the response body. 
+ pub(crate) async fn deny_with_redacted_target( + &self, + req: &L7Request, + policy_name: &str, + reason: &str, + client: &mut C, + redacted_target: Option<&str>, + ) -> Result<()> { + send_deny_response(req, policy_name, reason, client, redacted_target).await } } @@ -116,7 +136,7 @@ async fn parse_http_request(client: &mut C) -> Result(client: &mut C) -> Result Result<(String, HashMap>)> { + match target.split_once('?') { + Some((path, query)) => Ok((path.to_string(), parse_query_params(query)?)), + None => Ok((target.to_string(), HashMap::new())), + } +} + +fn parse_query_params(query: &str) -> Result>> { + let mut params: HashMap> = HashMap::new(); + if query.is_empty() { + return Ok(params); + } + + for pair in query.split('&') { + if pair.is_empty() { + continue; + } + + let (raw_key, raw_value) = match pair.split_once('=') { + Some((key, value)) => (key, value), + None => (pair, ""), + }; + let key = decode_query_component(raw_key)?; + let value = decode_query_component(raw_value)?; + params.entry(key).or_default().push(value); + } + + Ok(params) +} + +/// Decode a single query string component (key or value). +/// +/// Handles both RFC 3986 percent-encoding (`%20` → space) and the +/// `application/x-www-form-urlencoded` convention (`+` → space). +/// Decoding `+` as space matches the behavior of Python's `urllib.parse`, +/// JavaScript's `URLSearchParams`, Go's `url.ParseQuery`, and most HTTP +/// frameworks. Callers that need a literal `+` should send `%2B`. 
+fn decode_query_component(input: &str) -> Result { + let bytes = input.as_bytes(); + let mut decoded = Vec::with_capacity(bytes.len()); + let mut i = 0; + + while i < bytes.len() { + if bytes[i] == b'+' { + decoded.push(b' '); + i += 1; + continue; + } + + if bytes[i] != b'%' { + decoded.push(bytes[i]); + i += 1; + continue; + } + + if i + 2 >= bytes.len() { + return Err(miette!("Invalid percent-encoding in query component")); + } + + let hi = decode_hex_nibble(bytes[i + 1]) + .ok_or_else(|| miette!("Invalid percent-encoding in query component"))?; + let lo = decode_hex_nibble(bytes[i + 2]) + .ok_or_else(|| miette!("Invalid percent-encoding in query component"))?; + decoded.push((hi << 4) | lo); + i += 3; + } + + String::from_utf8(decoded).map_err(|_| miette!("Query component is not valid UTF-8")) +} + +fn decode_hex_nibble(byte: u8) -> Option { + match byte { + b'0'..=b'9' => Some(byte - b'0'), + b'a'..=b'f' => Some(byte - b'a' + 10), + b'A'..=b'F' => Some(byte - b'A' + 10), + _ => None, + } +} + /// Forward an allowed HTTP request to upstream and relay the response back. /// -/// Returns `true` if the upstream connection is reusable, `false` if consumed. -async fn relay_http_request(req: &L7Request, client: &mut C, upstream: &mut U) -> Result +/// Returns the relay outcome indicating whether the connection is reusable, +/// consumed, or has been upgraded (e.g. WebSocket via 101 Switching Protocols). 
+async fn relay_http_request( + req: &L7Request, + client: &mut C, + upstream: &mut U, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -154,7 +260,7 @@ pub(crate) async fn relay_http_request_with_resolver( client: &mut C, upstream: &mut U, resolver: Option<&crate::secrets::SecretResolver>, -) -> Result +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -165,10 +271,11 @@ where .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); - let rewritten_header = rewrite_http_header_block(&req.raw_header[..header_end], resolver); + let rewrite_result = rewrite_http_header_block(&req.raw_header[..header_end], resolver) + .map_err(|e| miette!("credential injection failed: {e}"))?; upstream - .write_all(&rewritten_header) + .write_all(&rewrite_result.rewritten) .await .into_diagnostic()?; @@ -191,21 +298,65 @@ where BodyLength::None => {} } upstream.flush().await.into_diagnostic()?; - let (reusable, _) = relay_response(&req.action, upstream, client).await?; - Ok(reusable) + + let outcome = relay_response(&req.action, upstream, client).await?; + + // Validate that the client actually requested an upgrade before accepting + // a 101 from upstream. Per RFC 9110 Section 7.8, the server MUST NOT send + // 101 unless the client sent Upgrade + Connection: Upgrade headers. A + // non-compliant or malicious upstream could send an unsolicited 101 to + // bypass L7 inspection. + if matches!(outcome, RelayOutcome::Upgraded { .. 
}) { + let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); + if !client_requested_upgrade(&header_str) { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Open) + .action(openshell_ocsf::ActionId::Denied) + .disposition(openshell_ocsf::DispositionId::Blocked) + .severity(openshell_ocsf::SeverityId::High) + .confidence(openshell_ocsf::ConfidenceId::High) + .is_alert(true) + .finding_info( + openshell_ocsf::FindingInfo::new( + "unsolicited-101-upgrade", + "Unsolicited 101 Switching Protocols", + ) + .with_desc(&format!( + "Upstream sent 101 without client Upgrade request for {} {} — \ + possible L7 inspection bypass. Connection closed.", + req.action, req.target, + )), + ) + .message(format!( + "Unsolicited 101 upgrade blocked: {} {}", + req.action, req.target, + )) + .build() + ); + return Ok(RelayOutcome::Consumed); + } + } + + Ok(outcome) } /// Send a 403 Forbidden JSON deny response. +/// +/// When `redacted_target` is provided, it is used instead of `req.target` +/// in the response body to avoid leaking resolved credential values. 
async fn send_deny_response( req: &L7Request, policy_name: &str, reason: &str, client: &mut C, + redacted_target: Option<&str>, ) -> Result<()> { + let target = redacted_target.unwrap_or(&req.target); let body = serde_json::json!({ "error": "policy_denied", "policy": policy_name, - "rule": format!("{} {}", req.action, req.target), + "rule": format!("{} {}", req.action, target), "detail": reason }); let body_bytes = body.to_string(); @@ -242,14 +393,22 @@ fn parse_body_length(headers: &str) -> Result { let lower = line.to_ascii_lowercase(); if lower.starts_with("transfer-encoding:") { let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - if val.contains("chunked") { + if val.split(',').any(|enc| enc.trim() == "chunked") { has_te_chunked = true; } } - if lower.starts_with("content-length:") - && let Some(val) = lower.split_once(':').map(|(_, v)| v.trim()) - && let Ok(len) = val.parse::() - { + if lower.starts_with("content-length:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + let len: u64 = val + .parse() + .map_err(|_| miette!("Request contains invalid Content-Length value"))?; + if let Some(prev) = cl_value { + if prev != len { + return Err(miette!( + "Request contains multiple Content-Length headers with differing values ({prev} vs {len})" + )); + } + } cl_value = Some(len); } } @@ -415,29 +574,28 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { /// Read and relay a full HTTP response (headers + body) from upstream to client. /// -/// Returns `true` if the upstream connection is reusable (keep-alive), -/// `false` if it was consumed (read-until-EOF or `Connection: close`). -/// Relay an HTTP response from upstream back to the client. +/// Returns a [`RelayOutcome`] indicating whether the connection is reusable, +/// consumed, or has been upgraded (101 Switching Protocols). /// -/// Returns `true` if the connection should stay alive for further requests. 
+/// Note: callers that receive `Upgraded` are responsible for switching to +/// raw bidirectional relay and forwarding the overflow bytes. pub(crate) async fn relay_response_to_client( upstream: &mut U, client: &mut C, request_method: &str, -) -> Result +) -> Result where U: AsyncRead + Unpin, C: AsyncWrite + Unpin, { - let (reusable, _status) = relay_response(request_method, upstream, client).await?; - Ok(reusable) + relay_response(request_method, upstream, client).await } async fn relay_response( request_method: &str, upstream: &mut U, client: &mut C, -) -> Result<(bool, u16)> +) -> Result where U: AsyncRead + Unpin, C: AsyncWrite + Unpin, @@ -458,7 +616,7 @@ where if !buf.is_empty() { client.write_all(&buf).await.into_diagnostic()?; } - return Ok((false, 0)); + return Ok(RelayOutcome::Consumed); } buf.extend_from_slice(&tmp[..n]); @@ -484,6 +642,26 @@ where "relay_response framing" ); + // 101 Switching Protocols: the connection has been upgraded (e.g. to + // WebSocket). Forward the 101 headers to the client and signal the + // caller to switch to raw bidirectional TCP relay. Any bytes read + // from upstream beyond the headers are overflow that belong to the + // upgraded protocol and must be forwarded before switching. 
+ if status_code == 101 { + client + .write_all(&buf[..header_end]) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + let overflow = buf[header_end..].to_vec(); + debug!( + request_method, + overflow_bytes = overflow.len(), + "101 Switching Protocols — signaling protocol upgrade" + ); + return Ok(RelayOutcome::Upgraded { overflow }); + } + // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body if is_bodiless_response(request_method, status_code) { client @@ -491,7 +669,11 @@ where .await .into_diagnostic()?; client.flush().await.into_diagnostic()?; - return Ok((!server_wants_close, status_code)); + return if server_wants_close { + Ok(RelayOutcome::Consumed) + } else { + Ok(RelayOutcome::Reusable) + }; } // No explicit framing (no Content-Length, no Transfer-Encoding). @@ -511,7 +693,7 @@ where } relay_until_eof(upstream, client).await?; client.flush().await.into_diagnostic()?; - return Ok((false, status_code)); + return Ok(RelayOutcome::Consumed); } // No Connection: close — an HTTP/1.1 keep-alive server that omits // framing headers has an empty body. Forward headers and continue @@ -522,7 +704,7 @@ where .await .into_diagnostic()?; client.flush().await.into_diagnostic()?; - return Ok((true, status_code)); + return Ok(RelayOutcome::Reusable); } // Forward response headers + any overflow body bytes @@ -555,7 +737,7 @@ where // loop will exit via the normal error path. Exiting early here would // tear down the CONNECT tunnel before the client can detect the close, // causing ~30 s retry delays in clients like `gh`. - Ok((true, status_code)) + Ok(RelayOutcome::Reusable) } /// Parse the HTTP status code from a response status line. @@ -579,6 +761,33 @@ fn parse_connection_close(headers: &str) -> bool { false } +/// Check if the client request headers contain both `Upgrade` and +/// `Connection: Upgrade` headers, indicating the client requested a +/// protocol upgrade (e.g. WebSocket). 
+/// +/// Per RFC 9110 Section 7.8, a server MUST NOT send 101 Switching Protocols +/// unless the client sent these headers. +fn client_requested_upgrade(headers: &str) -> bool { + let mut has_upgrade_header = false; + let mut connection_contains_upgrade = false; + + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("upgrade:") { + has_upgrade_header = true; + } + if lower.starts_with("connection:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + // Connection header can have comma-separated values + if val.split(',').any(|tok| tok.trim() == "upgrade") { + connection_contains_upgrade = true; + } + } + } + + has_upgrade_header && connection_contains_upgrade +} + /// Returns true for responses that MUST NOT contain a message body per RFC 7230 §3.3.3: /// HEAD responses, 1xx informational, 204 No Content, 304 Not Modified. fn is_bodiless_response(request_method: &str, status_code: u16) -> bool { @@ -652,6 +861,7 @@ fn is_benign_close(err: &std::io::Error) -> bool { mod tests { use super::*; use crate::secrets::SecretResolver; + use base64::Engine as _; #[test] fn parse_content_length() { @@ -681,6 +891,92 @@ mod tests { } } + #[test] + fn parse_target_query_parses_duplicate_values() { + let (path, query) = parse_target_query("/download?tag=a&tag=b").expect("parse"); + assert_eq!(path, "/download"); + assert_eq!( + query.get("tag").cloned(), + Some(vec!["a".into(), "b".into()]) + ); + } + + #[test] + fn parse_target_query_decodes_percent_and_plus() { + let (path, query) = parse_target_query("/download?slug=my%2Fskill&name=Foo+Bar").unwrap(); + assert_eq!(path, "/download"); + assert_eq!( + query.get("slug").cloned(), + Some(vec!["my/skill".to_string()]) + ); + // `+` is decoded as space per application/x-www-form-urlencoded. + // Literal `+` should be sent as `%2B`. 
+ assert_eq!( + query.get("name").cloned(), + Some(vec!["Foo Bar".to_string()]) + ); + } + + #[test] + fn parse_target_query_literal_plus_via_percent_encoding() { + let (_path, query) = parse_target_query("/search?q=a%2Bb").unwrap(); + assert_eq!( + query.get("q").cloned(), + Some(vec!["a+b".to_string()]), + "%2B should decode to literal +" + ); + } + + #[test] + fn parse_target_query_empty_value() { + let (_path, query) = parse_target_query("/api?tag=").unwrap(); + assert_eq!( + query.get("tag").cloned(), + Some(vec!["".to_string()]), + "key with empty value should produce empty string" + ); + } + + #[test] + fn parse_target_query_key_without_value() { + let (_path, query) = parse_target_query("/api?verbose").unwrap(); + assert_eq!( + query.get("verbose").cloned(), + Some(vec!["".to_string()]), + "key without = should produce empty string value" + ); + } + + #[test] + fn parse_target_query_unicode_after_decoding() { + // "café" = c a f %C3%A9 + let (_path, query) = parse_target_query("/search?q=caf%C3%A9").unwrap(); + assert_eq!( + query.get("q").cloned(), + Some(vec!["café".to_string()]), + "percent-encoded UTF-8 should decode correctly" + ); + } + + #[test] + fn parse_target_query_empty_query_string() { + let (path, query) = parse_target_query("/api?").unwrap(); + assert_eq!(path, "/api"); + assert!( + query.is_empty(), + "empty query after ? should produce empty map" + ); + } + + #[test] + fn parse_target_query_rejects_malformed_percent_encoding() { + let err = parse_target_query("/download?slug=bad%2").expect_err("expected parse error"); + assert!( + err.to_string().contains("percent-encoding"), + "unexpected error: {err}" + ); + } + /// SEC-009: Reject requests with both Content-Length and Transfer-Encoding /// to prevent CL/TE request smuggling (RFC 7230 Section 3.3.3). #[test] @@ -702,6 +998,59 @@ mod tests { ); } + /// SEC: Reject differing duplicate Content-Length headers. 
+ #[test] + fn reject_differing_duplicate_content_length() { + let headers = + "POST /api HTTP/1.1\r\nHost: x\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n"; + assert!( + parse_body_length(headers).is_err(), + "Must reject differing duplicate Content-Length" + ); + } + + /// SEC: Accept identical duplicate Content-Length headers. + #[test] + fn accept_identical_duplicate_content_length() { + let headers = + "POST /api HTTP/1.1\r\nHost: x\r\nContent-Length: 42\r\nContent-Length: 42\r\n\r\n"; + match parse_body_length(headers).unwrap() { + BodyLength::ContentLength(42) => {} + other => panic!("Expected ContentLength(42), got {other:?}"), + } + } + + /// SEC: Reject non-numeric Content-Length values. + #[test] + fn reject_non_numeric_content_length() { + let headers = "POST /api HTTP/1.1\r\nHost: x\r\nContent-Length: abc\r\n\r\n"; + assert!( + parse_body_length(headers).is_err(), + "Must reject non-numeric Content-Length" + ); + } + + /// SEC: Reject when second Content-Length is non-numeric (bypass test). + #[test] + fn reject_valid_then_invalid_content_length() { + let headers = + "POST /api HTTP/1.1\r\nHost: x\r\nContent-Length: 42\r\nContent-Length: abc\r\n\r\n"; + assert!( + parse_body_length(headers).is_err(), + "Must reject when any Content-Length is non-numeric" + ); + } + + /// SEC: Transfer-Encoding substring match must not match partial tokens. + #[test] + fn te_substring_not_chunked() { + let headers = "POST /api HTTP/1.1\r\nHost: x\r\nTransfer-Encoding: chunkedx\r\n\r\n"; + match parse_body_length(headers).unwrap() { + BodyLength::None => {} + other => panic!("Expected None for non-matching TE, got {other:?}"), + } + } + /// SEC-009: Bare LF in headers enables header injection. 
#[tokio::test] async fn reject_bare_lf_in_headers() { @@ -746,6 +1095,32 @@ mod tests { assert!(result.is_err(), "Must reject unsupported HTTP version"); } + #[tokio::test] + async fn parse_http_request_splits_path_and_query_params() { + let (mut client, mut writer) = tokio::io::duplex(4096); + tokio::spawn(async move { + writer + .write_all( + b"GET /download?slug=my%2Fskill&tag=foo&tag=bar HTTP/1.1\r\nHost: x\r\n\r\n", + ) + .await + .unwrap(); + }); + let req = parse_http_request(&mut client) + .await + .expect("request should parse") + .expect("request should exist"); + assert_eq!(req.target, "/download"); + assert_eq!( + req.query_params.get("slug").cloned(), + Some(vec!["my/skill".to_string()]) + ); + assert_eq!( + req.query_params.get("tag").cloned(), + Some(vec!["foo".to_string(), "bar".to_string()]) + ); + } + /// Regression test: two pipelined requests in a single write must be /// parsed independently. Before the fix, the 1024-byte `read()` buffer /// could capture bytes from the second request, which were forwarded @@ -770,6 +1145,7 @@ mod tests { .expect("expected first request"); assert_eq!(first.action, "GET"); assert_eq!(first.target, "/allowed"); + assert!(first.query_params.is_empty()); assert_eq!( first.raw_header, b"GET /allowed HTTP/1.1\r\nHost: example.com\r\n\r\n", "raw_header must contain only the first request's headers" @@ -781,6 +1157,7 @@ mod tests { .expect("expected second request"); assert_eq!(second.action, "POST"); assert_eq!(second.target, "/blocked"); + assert!(second.query_params.is_empty()); } #[test] @@ -858,8 +1235,11 @@ mod tests { .await .expect("relay_response should not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(!reusable, "connection consumed by read-until-EOF"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Consumed), + "connection consumed by read-until-EOF" + ); client_write.shutdown().await.unwrap(); 
let mut received = Vec::new(); @@ -896,8 +1276,11 @@ mod tests { .await .expect("must not block when no Connection: close"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "keep-alive implied, connection reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "keep-alive implied, connection reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -929,8 +1312,11 @@ mod tests { .await .expect("HEAD relay must not deadlock waiting for body"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "HEAD response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "HEAD response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -959,8 +1345,11 @@ mod tests { .await .expect("204 relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "204 response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "204 response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -991,8 +1380,11 @@ mod tests { .await .expect("must not block when chunked body is complete in overflow"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "connection should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "connection should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1027,8 +1419,11 @@ mod tests { .await .expect("must not block when chunked response has trailers"); - let 
(reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "chunked response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "chunked response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1062,8 +1457,11 @@ mod tests { .await .expect("normal relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); - assert!(reusable, "Content-Length response should be reusable"); + let outcome = result.expect("relay_response should succeed"); + assert!( + matches!(outcome, RelayOutcome::Reusable), + "Content-Length response should be reusable" + ); client_write.shutdown().await.unwrap(); let mut received = Vec::new(); @@ -1090,12 +1488,12 @@ mod tests { .await .expect("relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let outcome = result.expect("relay_response should succeed"); // With explicit framing, Connection: close is still reported as reusable // so the relay loop continues. The *next* upstream write will fail and // exit the loop via the normal error path. assert!( - reusable, + matches!(outcome, RelayOutcome::Reusable), "explicit framing keeps loop alive despite Connection: close" ); @@ -1105,6 +1503,224 @@ mod tests { assert!(String::from_utf8_lossy(&received).contains("hello")); } + #[tokio::test] + async fn relay_response_101_switching_protocols_returns_upgraded_with_overflow() { + // Build a 101 response followed by WebSocket frame data (overflow). 
+ let mut response = Vec::new(); + response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); + response.extend_from_slice(b"Upgrade: websocket\r\n"); + response.extend_from_slice(b"Connection: Upgrade\r\n"); + response.extend_from_slice(b"\r\n"); + response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(&response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response("GET", &mut upstream_read, &mut client_write), + ) + .await + .expect("relay_response should not deadlock"); + + let outcome = result.expect("relay_response should succeed"); + match outcome { + RelayOutcome::Upgraded { overflow } => { + assert_eq!( + &overflow, b"\x81\x05hello", + "overflow should contain WebSocket frame data" + ); + } + other => panic!("Expected Upgraded, got {other:?}"), + } + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("101 Switching Protocols"), + "client should receive the 101 response headers" + ); + } + + #[tokio::test] + async fn relay_response_101_no_overflow() { + // 101 response with no trailing bytes — overflow should be empty. 
+ let response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (_client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response("GET", &mut upstream_read, &mut client_write), + ) + .await + .expect("relay_response should not deadlock"); + + match result.expect("should succeed") { + RelayOutcome::Upgraded { overflow } => { + assert!(overflow.is_empty(), "no overflow expected"); + } + other => panic!("Expected Upgraded, got {other:?}"), + } + } + + #[tokio::test] + async fn relay_rejects_unsolicited_101_without_client_upgrade_header() { + // Client sends a normal GET without Upgrade headers. + // Upstream responds with 101 (non-compliant). The relay should + // reject the upgrade and return Consumed instead. 
+ let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/api".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + // Read the request + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + // Send unsolicited 101 + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay should succeed"); + assert!( + matches!(outcome, RelayOutcome::Consumed), + "unsolicited 101 should be rejected as Consumed, got {outcome:?}" + ); + + upstream_task.await.expect("upstream task should complete"); + } + + #[tokio::test] + async fn relay_accepts_101_with_client_upgrade_header() { + // Client sends a proper upgrade request with Upgrade + Connection headers. 
+ let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = tokio::time::timeout( + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay should succeed"); + assert!( + matches!(outcome, RelayOutcome::Upgraded { .. 
}), + "proper upgrade request should be accepted, got {outcome:?}" + ); + + upstream_task.await.expect("upstream task should complete"); + } + + #[test] + fn client_requested_upgrade_detects_websocket_headers() { + let headers = "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"; + assert!(client_requested_upgrade(headers)); + } + + #[test] + fn client_requested_upgrade_rejects_missing_upgrade_header() { + let headers = "GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert!(!client_requested_upgrade(headers)); + } + + #[test] + fn client_requested_upgrade_rejects_upgrade_without_connection() { + let headers = "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\n\r\n"; + assert!(!client_requested_upgrade(headers)); + } + + #[test] + fn client_requested_upgrade_handles_comma_separated_connection() { + let headers = "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\n\r\n"; + assert!(client_requested_upgrade(headers)); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env( @@ -1114,8 +1730,8 @@ mod tests { ); let raw = b"GET /v1/messages HTTP/1.1\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\nHost: example.com\r\n\r\n"; - let rewritten = rewrite_http_header_block(raw, resolver.as_ref()); - let rewritten = String::from_utf8(rewritten).expect("utf8"); + let result = rewrite_http_header_block(raw, resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); assert!(rewritten.contains("Authorization: Bearer sk-test\r\n")); assert!(!rewritten.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); @@ -1133,7 +1749,7 @@ mod tests { /// to the upstream API, causing 401 Unauthorized errors. 
#[tokio::test] async fn relay_request_with_resolver_rewrites_credential_placeholders() { - let provider_env: std::collections::HashMap = [( + let provider_env: HashMap = [( "NVIDIA_API_KEY".to_string(), "nvapi-real-secret-key".to_string(), )] @@ -1149,6 +1765,7 @@ mod tests { let req = L7Request { action: "POST".to_string(), target: "/v1/chat/completions".to_string(), + query_params: HashMap::new(), raw_header: format!( "POST /v1/chat/completions HTTP/1.1\r\n\ Host: integrate.api.nvidia.com\r\n\ @@ -1232,6 +1849,7 @@ mod tests { let req = L7Request { action: "POST".to_string(), target: "/v1/chat/completions".to_string(), + query_params: HashMap::new(), raw_header: format!( "POST /v1/chat/completions HTTP/1.1\r\n\ Host: integrate.api.nvidia.com\r\n\ @@ -1293,4 +1911,376 @@ mod tests { "Real secret should NOT appear without resolver, got: {forwarded}" ); } + + // ========================================================================= + // Credential injection integration tests + // + // Each test exercises a different injection location through the full + // relay_http_request_with_resolver pipeline: child builds an HTTP request + // with a placeholder, the relay rewrites it, and we verify what upstream + // receives. + // ========================================================================= + + /// Helper: run a request through the relay and capture what upstream receives. 
+ async fn relay_and_capture( + raw_header: Vec, + body_length: BodyLength, + resolver: Option<&SecretResolver>, + ) -> Result { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + // Parse the request line to extract action and target for L7Request + let header_str = String::from_utf8_lossy(&raw_header); + let first_line = header_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.splitn(3, ' ').collect(); + let action = parts.first().unwrap_or(&"GET").to_string(); + let target = parts.get(1).unwrap_or(&"/").to_string(); + + let req = L7Request { + action, + target, + query_params: HashMap::new(), + raw_header, + body_length, + }; + + let content_len = match body_length { + BodyLength::ContentLength(n) => n, + _ => 0, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if let Some(hdr_end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") { + if total >= hdr_end + 4 + content_len as usize { + break; + } + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + let relay = tokio::time::timeout( + std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + resolver, + ), + ) + .await + .map_err(|_| miette!("relay timed out"))?; + relay?; + + let forwarded = upstream_task + .await + .map_err(|e| miette!("upstream task failed: {e}"))?; + Ok(forwarded) + } + + #[tokio::test] + async fn relay_injects_bearer_header_credential() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("API_KEY".to_string(), "sk-real-secret-key".to_string())] + 
.into_iter() + .collect(), + ); + let placeholder = child_env.get("API_KEY").unwrap(); + + let raw = format!( + "POST /v1/chat HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {placeholder}\r\n\ + Content-Length: 2\r\n\r\n{{}}" + ); + + let forwarded = relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(2), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + assert!( + forwarded.contains("Authorization: Bearer sk-real-secret-key\r\n"), + "Upstream should see real Bearer token, got: {forwarded}" + ); + assert!( + !forwarded.contains("openshell:resolve:env:"), + "Placeholder leaked to upstream: {forwarded}" + ); + } + + #[tokio::test] + async fn relay_injects_exact_header_credential() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("CUSTOM_TOKEN".to_string(), "tok-12345".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("CUSTOM_TOKEN").unwrap(); + + let raw = format!( + "GET /api/data HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + x-api-key: {placeholder}\r\n\ + Content-Length: 0\r\n\r\n" + ); + + let forwarded = relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(0), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + assert!( + forwarded.contains("x-api-key: tok-12345\r\n"), + "Upstream should see real x-api-key, got: {forwarded}" + ); + assert!(!forwarded.contains("openshell:resolve:env:")); + } + + #[tokio::test] + async fn relay_injects_basic_auth_credential() { + let b64 = base64::engine::general_purpose::STANDARD; + + let (child_env, resolver) = SecretResolver::from_provider_env( + [("REGISTRY_PASS".to_string(), "hunter2".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("REGISTRY_PASS").unwrap(); + let encoded = b64.encode(format!("deploy:{placeholder}").as_bytes()); + + let raw = format!( + "GET /v2/_catalog HTTP/1.1\r\n\ + Host: registry.example.com\r\n\ + Authorization: Basic 
{encoded}\r\n\ + Content-Length: 0\r\n\r\n" + ); + + let forwarded = relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(0), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + // Extract and decode the Basic auth token from what upstream received + let auth_line = forwarded + .lines() + .find(|l| l.starts_with("Authorization: Basic")) + .expect("upstream should have Authorization header"); + let token = auth_line + .strip_prefix("Authorization: Basic ") + .unwrap() + .trim(); + let decoded = b64.decode(token).expect("valid base64"); + let decoded_str = std::str::from_utf8(&decoded).expect("valid utf8"); + + assert_eq!( + decoded_str, "deploy:hunter2", + "Decoded Basic auth should contain real password" + ); + assert!(!forwarded.contains("openshell:resolve:env:")); + } + + #[tokio::test] + async fn relay_injects_query_param_credential() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("YOUTUBE_KEY".to_string(), "AIzaSy-secret".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("YOUTUBE_KEY").unwrap(); + + let raw = format!( + "GET /v3/search?part=snippet&key={placeholder} HTTP/1.1\r\n\ + Host: www.googleapis.com\r\n\ + Content-Length: 0\r\n\r\n" + ); + + let forwarded = relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(0), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + assert!( + forwarded.contains("key=AIzaSy-secret"), + "Upstream should see real API key in query param, got: {forwarded}" + ); + assert!( + forwarded.contains("part=snippet"), + "Non-secret query params should be preserved, got: {forwarded}" + ); + assert!(!forwarded.contains("openshell:resolve:env:")); + } + + #[tokio::test] + async fn relay_injects_url_path_credential_telegram_style() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [( + "TELEGRAM_TOKEN".to_string(), + "123456:ABC-DEF1234ghIkl".to_string(), + )] + .into_iter() + .collect(), + ); + 
let placeholder = child_env.get("TELEGRAM_TOKEN").unwrap(); + + let raw = format!( + "POST /bot{placeholder}/sendMessage HTTP/1.1\r\n\ + Host: api.telegram.org\r\n\ + Content-Length: 2\r\n\r\n{{}}" + ); + + let forwarded = relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(2), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + assert!( + forwarded.contains("POST /bot123456:ABC-DEF1234ghIkl/sendMessage HTTP/1.1"), + "Upstream should see real token in URL path, got: {forwarded}" + ); + assert!(!forwarded.contains("openshell:resolve:env:")); + } + + #[tokio::test] + async fn relay_injects_url_path_credential_standalone_segment() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("ORG_TOKEN".to_string(), "org-abc-789".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("ORG_TOKEN").unwrap(); + + let raw = format!( + "GET /api/{placeholder}/resources HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Length: 0\r\n\r\n" + ); + + let forwarded = relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(0), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + assert!( + forwarded.contains("GET /api/org-abc-789/resources HTTP/1.1"), + "Upstream should see real token in path segment, got: {forwarded}" + ); + assert!(!forwarded.contains("openshell:resolve:env:")); + } + + #[tokio::test] + async fn relay_injects_combined_path_and_header_credentials() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("PATH_TOKEN".to_string(), "tok-path-123".to_string()), + ("HEADER_KEY".to_string(), "sk-header-456".to_string()), + ] + .into_iter() + .collect(), + ); + let path_ph = child_env.get("PATH_TOKEN").unwrap(); + let header_ph = child_env.get("HEADER_KEY").unwrap(); + + let raw = format!( + "POST /bot{path_ph}/send HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + x-api-key: {header_ph}\r\n\ + Content-Length: 2\r\n\r\n{{}}" + ); + + let forwarded = 
relay_and_capture( + raw.into_bytes(), + BodyLength::ContentLength(2), + resolver.as_ref(), + ) + .await + .expect("relay should succeed"); + + assert!( + forwarded.contains("/bottok-path-123/send"), + "Upstream should see real token in path, got: {forwarded}" + ); + assert!( + forwarded.contains("x-api-key: sk-header-456\r\n"), + "Upstream should see real token in header, got: {forwarded}" + ); + assert!(!forwarded.contains("openshell:resolve:env:")); + } + + #[tokio::test] + async fn relay_fail_closed_rejects_unresolved_placeholder() { + // Create a resolver that knows about KEY1 but not UNKNOWN_KEY + let (child_env, resolver) = SecretResolver::from_provider_env( + [("KEY1".to_string(), "secret1".to_string())] + .into_iter() + .collect(), + ); + let _ = child_env; + + // The request references a placeholder that the resolver doesn't know + let raw = b"GET /api HTTP/1.1\r\n\ + Host: example.com\r\n\ + x-api-key: openshell:resolve:env:UNKNOWN_KEY\r\n\ + Content-Length: 0\r\n\r\n" + .to_vec(); + + let result = relay_and_capture(raw, BodyLength::ContentLength(0), resolver.as_ref()).await; + + assert!( + result.is_err(), + "Relay should fail when placeholder cannot be resolved" + ); + } + + #[tokio::test] + async fn relay_fail_closed_rejects_unresolved_path_placeholder() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY1".to_string(), "secret1".to_string())] + .into_iter() + .collect(), + ); + + let raw = + b"GET /api/openshell:resolve:env:UNKNOWN_KEY/data HTTP/1.1\r\nHost: x\r\nContent-Length: 0\r\n\r\n" + .to_vec(); + + let result = relay_and_capture(raw, BodyLength::ContentLength(0), resolver.as_ref()).await; + + assert!( + result.is_err(), + "Relay should fail when path placeholder cannot be resolved" + ); + } } diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 493e4d237..f9e8fb4c5 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -27,12 +27,63 @@ use 
miette::{IntoDiagnostic, Result};
 use std::collections::HashSet;
 use std::net::SocketAddr;
 use std::sync::Arc;
+use std::sync::OnceLock;
 use std::sync::atomic::{AtomicU32, Ordering};
 #[cfg(target_os = "linux")]
 use std::sync::{LazyLock, Mutex};
 use std::time::Duration;
 use tokio::time::timeout;
-use tracing::{debug, error, info, trace, warn};
+use tracing::{debug, info, trace, warn};
+
+use openshell_ocsf::{
+    ActionId, ActivityId, AppLifecycleBuilder, ConfigStateChangeBuilder, DetectionFindingBuilder,
+    DispositionId, FindingInfo, LaunchTypeId, Process as OcsfProcess, ProcessActivityBuilder,
+    SandboxContext, SeverityId, StateId, StatusId, ocsf_emit,
+};
+
+// ---------------------------------------------------------------------------
+// OCSF Context
+// ---------------------------------------------------------------------------
+//
+// The following log sites intentionally remain as plain `tracing` macros
+// and are NOT migrated to OCSF builders:
+//
+// - DEBUG/TRACE events (zombie reaping, ip commands, gRPC connects, PTY state)
+// - Transient "about to do X" events where the result is logged separately
+//   (e.g., "Fetching sandbox policy via gRPC", "Creating OPA engine from proto")
+// - Internal SSH channel warnings (unknown channel, PTY resize failures)
+// - Denial flush telemetry (the individual denials are already OCSF events)
+// - Status reporting failures (sync to gateway, non-actionable)
+// - Route refresh interval validation warnings
+//
+// These are operational plumbing that don't represent security decisions,
+// policy changes, or observable sandbox behavior worth structuring.
+// ---------------------------------------------------------------------------
+
+/// Process-wide OCSF sandbox context. Initialized once during `run_sandbox()`
+/// startup and accessible from any module in the crate via [`ocsf_ctx()`].
+static OCSF_CTX: OnceLock<SandboxContext> = OnceLock::new();
+
+/// Fallback context used when `OCSF_CTX` has not been initialized (e.g.
in
+/// unit tests that exercise individual functions without calling `run_sandbox`).
+static OCSF_CTX_FALLBACK: std::sync::LazyLock<SandboxContext> =
+    std::sync::LazyLock::new(|| SandboxContext {
+        sandbox_id: String::new(),
+        sandbox_name: String::new(),
+        container_image: String::new(),
+        hostname: "test".to_string(),
+        product_version: openshell_core::VERSION.to_string(),
+        proxy_ip: std::net::IpAddr::from([127, 0, 0, 1]),
+        proxy_port: 3128,
+    });
+
+/// Return a reference to the process-wide [`SandboxContext`].
+///
+/// Falls back to a default context if `run_sandbox()` has not yet been called
+/// (e.g. during unit tests).
+pub(crate) fn ocsf_ctx() -> &'static SandboxContext {
+    OCSF_CTX.get().unwrap_or(&OCSF_CTX_FALLBACK)
+}
 use crate::identity::BinaryIdentityCache;
 use crate::l7::tls::{
@@ -162,11 +213,37 @@ pub async fn run_sandbox(
     _health_check: bool,
     _health_port: u16,
     inference_routes: Option,
+    ocsf_enabled: Arc<std::sync::atomic::AtomicBool>,
 ) -> Result<i32> {
     let (program, args) = command
         .split_first()
         .ok_or_else(|| miette::miette!("No command specified"))?;
+    // Initialize the process-wide OCSF context early so that events emitted
+    // during policy loading (filesystem config, validation) have a context.
+    // Proxy IP/port use defaults here; they are only significant for network
+    // events which happen after the netns is created.
+ { + let hostname = std::fs::read_to_string("/etc/hostname") + .map(|s| s.trim().to_string()) + .unwrap_or_else(|_| "openshell-sandbox".to_string()); + + if OCSF_CTX + .set(SandboxContext { + sandbox_id: sandbox_id.clone().unwrap_or_default(), + sandbox_name: sandbox.as_deref().unwrap_or_default().to_string(), + container_image: std::env::var("OPENSHELL_CONTAINER_IMAGE").unwrap_or_default(), + hostname, + product_version: openshell_core::VERSION.to_string(), + proxy_ip: std::net::IpAddr::from([127, 0, 0, 1]), + proxy_port: 3128, + }) + .is_err() + { + debug!("OCSF context already initialized, keeping existing"); + } + } + // Load policy and initialize OPA engine let openshell_endpoint_for_proxy = openshell_endpoint.clone(); let sandbox_name_for_agg = sandbox.clone(); @@ -190,11 +267,30 @@ pub async fn run_sandbox( let provider_env = if let (Some(id), Some(endpoint)) = (&sandbox_id, &openshell_endpoint) { match grpc_client::fetch_provider_environment(endpoint, id).await { Ok(env) => { - info!(env_count = env.len(), "Fetched provider environment"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .message(format!( + "Fetched provider environment [env_count:{}]", + env.len() + )) + .build() + ); env } Err(e) => { - warn!(error = %e, "Failed to fetch provider environment, continuing without"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "degraded") + .message(format!( + "Failed to fetch provider environment, continuing without: {e}" + )) + .build() + ); std::collections::HashMap::new() } } @@ -228,22 +324,41 @@ pub async fn run_sandbox( let upstream_config = build_upstream_client_config(); let cert_cache = CertCache::new(ca); let state = Arc::new(ProxyTlsState::new(cert_cache, upstream_config)); - info!("TLS termination enabled: ephemeral CA generated"); + 
ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "enabled") + .message("TLS termination enabled: ephemeral CA generated") + .build() + ); (Some(state), Some(paths)) } Err(e) => { - tracing::warn!( - error = %e, - "Failed to write CA files, TLS termination disabled" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Disabled, "disabled") + .message(format!( + "Failed to write CA files, TLS termination disabled: {e}" + )) + .build() ); (None, None) } } } Err(e) => { - tracing::warn!( - error = %e, - "Failed to generate ephemeral CA, TLS termination disabled" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Disabled, "disabled") + .message(format!( + "Failed to generate ephemeral CA, TLS termination disabled: {e}" + )) + .build() ); (None, None) } @@ -269,9 +384,15 @@ pub async fn run_sandbox( .and_then(|p| p.http_addr) .map_or(3128, |addr| addr.port()); if let Err(e) = ns.install_bypass_rules(proxy_port) { - warn!( - error = %e, - "Failed to install bypass detection rules (non-fatal)" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Disabled, "degraded") + .message(format!( + "Failed to install bypass detection rules (non-fatal): {e}" + )) + .build() ); } Some(ns) @@ -514,7 +635,14 @@ pub async fn run_sandbox( ) .await { - tracing::error!(error = %err, "SSH server failed"); + ocsf_emit!( + AppLifecycleBuilder::new(ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Critical) + .status(StatusId::Failure) + .message(format!("SSH server failed: {err}")) + .build() + ); } }); @@ -523,7 +651,14 @@ pub async fn run_sandbox( // SSH server startup when Kubernetes marks the pod Ready. 
match timeout(Duration::from_secs(10), ssh_ready_rx).await { Ok(Ok(Ok(()))) => { - info!("SSH server is ready to accept connections"); + ocsf_emit!( + AppLifecycleBuilder::new(ocsf_ctx()) + .activity(ActivityId::Open) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .message("SSH server is ready to accept connections") + .build() + ); } Ok(Ok(Err(err))) => { return Err(err.context("SSH server failed during startup")); @@ -566,7 +701,18 @@ pub async fn run_sandbox( // Store the entrypoint PID so the proxy can resolve TCP peer identity entrypoint_pid.store(handle.pid(), Ordering::Release); - info!(pid = handle.pid(), "Process started"); + ocsf_emit!( + ProcessActivityBuilder::new(ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .launch_type(LaunchTypeId::Spawn) + .process(OcsfProcess::new(program, i64::from(handle.pid()))) + .message(format!("Process started: pid={}", handle.pid())) + .build() + ); // Spawn background policy poll task (gRPC mode only). 
if let (Some(id), Some(endpoint), Some(engine)) = @@ -575,17 +721,30 @@ pub async fn run_sandbox( let poll_id = id.clone(); let poll_endpoint = endpoint.clone(); let poll_engine = engine.clone(); + let poll_ocsf_enabled = ocsf_enabled.clone(); let poll_interval_secs: u64 = std::env::var("OPENSHELL_POLICY_POLL_INTERVAL_SECS") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(10); tokio::spawn(async move { - if let Err(e) = - run_policy_poll_loop(&poll_endpoint, &poll_id, &poll_engine, poll_interval_secs) - .await + if let Err(e) = run_policy_poll_loop( + &poll_endpoint, + &poll_id, + &poll_engine, + poll_interval_secs, + &poll_ocsf_enabled, + ) + .await { - warn!(error = %e, "Policy poll loop exited with error"); + ocsf_emit!( + AppLifecycleBuilder::new(ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .message(format!("Policy poll loop exited with error: {e}")) + .build() + ); } }); @@ -625,7 +784,16 @@ pub async fn run_sandbox( if let Ok(result) = timeout(Duration::from_secs(timeout_secs), handle.wait()).await { result } else { - error!("Process timed out, killing"); + ocsf_emit!( + ProcessActivityBuilder::new(ocsf_ctx()) + .activity(ActivityId::Close) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Critical) + .status(StatusId::Failure) + .message("Process timed out, killing") + .build() + ); handle.kill()?; return Ok(124); // Standard timeout exit code } @@ -635,7 +803,17 @@ pub async fn run_sandbox( let status = result.into_diagnostic()?; - info!(exit_code = status.code(), "Process exited"); + ocsf_emit!( + ProcessActivityBuilder::new(ocsf_ctx()) + .activity(ActivityId::Close) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .exit_code(status.code()) + .message(format!("Process exited with code {}", status.code())) + .build() + ); Ok(status.code()) } @@ -672,12 +850,25 @@ async fn 
build_inference_context( // Standalone mode: load routes from file (fail-fast on errors) if sandbox_id.is_some() { - info!( - inference_routes = %path, - "Inference routes file takes precedence over cluster bundle" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped("inference_routes", serde_json::json!(path)) + .message(format!( + "Inference routes file takes precedence over cluster bundle [path:{path}]" + )) + .build()); } - info!(inference_routes = %path, "Loading inference routes from file"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "loading") + .unmapped("inference_routes", serde_json::json!(path)) + .message(format!("Loading inference routes from file [path:{path}]")) + .build() + ); let config = RouterConfig::load_from_file(std::path::Path::new(path)) .map_err(|e| miette::miette!("failed to load inference routes {path}: {e}"))?; config @@ -694,10 +885,19 @@ async fn build_inference_context( match grpc_client::fetch_inference_bundle(endpoint).await { Ok(bundle) => { initial_revision = Some(bundle.revision.clone()); - info!( - route_count = bundle.routes.len(), - revision = %bundle.revision, - "Loaded inference route bundle" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped("route_count", serde_json::json!(bundle.routes.len())) + .unmapped("revision", serde_json::json!(&bundle.revision)) + .message(format!( + "Loaded inference route bundle [route_count:{} revision:{}]", + bundle.routes.len(), + bundle.revision + )) + .build() ); bundle_to_resolved_routes(&bundle) } @@ -707,10 +907,28 @@ async fn build_inference_context( // for this sandbox — skip gracefully. Other errors are unexpected. 
let msg = e.to_string(); if msg.contains("permission denied") || msg.contains("not found") { - info!(error = %e, "Inference bundle unavailable, routing disabled"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Disabled, "disabled") + .unmapped("error", serde_json::json!(e.to_string())) + .message(format!( + "Inference bundle unavailable, routing disabled [error:{e}]" + )) + .build() + ); return Ok(None); } - warn!(error = %e, "Failed to fetch inference bundle, inference routing disabled"); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Disabled, "disabled") + .unmapped("error", serde_json::json!(e.to_string())) + .message(format!( + "Failed to fetch inference bundle, inference routing disabled [error:{e}]" + )) + .build()); return Ok(None); } } @@ -722,17 +940,37 @@ async fn build_inference_context( }; if routes.is_empty() && disable_inference_on_empty_routes(source) { - info!("No usable inference routes, inference routing disabled"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Disabled, "disabled") + .message("No usable inference routes, inference routing disabled") + .build() + ); return Ok(None); } if routes.is_empty() { - info!("Inference route bundle is empty; keeping routing enabled and waiting for refresh"); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "waiting") + .message("Inference route bundle is empty; keeping routing enabled and waiting for refresh") + .build()); } - info!( - route_count = routes.len(), - "Inference routing enabled with local execution" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + 
.status(StatusId::Success) + .state(StateId::Enabled, "enabled") + .unmapped("route_count", serde_json::json!(routes.len())) + .message(format!( + "Inference routing enabled with local execution [route_count:{}]", + routes.len() + )) + .build() ); // Partition routes by name into user-facing and system caches. @@ -801,6 +1039,11 @@ pub(crate) fn bundle_to_resolved_routes( .map(|r| { let (auth, default_headers) = openshell_core::inference::auth_for_provider_type(&r.provider_type); + let timeout = if r.timeout_secs == 0 { + openshell_router::config::DEFAULT_ROUTE_TIMEOUT + } else { + Duration::from_secs(r.timeout_secs) + }; openshell_router::config::ResolvedRoute { name: r.name.clone(), endpoint: r.base_url.clone(), @@ -809,6 +1052,7 @@ pub(crate) fn bundle_to_resolved_routes( protocols: r.protocols.clone(), auth, default_headers, + timeout, } }) .collect() @@ -847,18 +1091,34 @@ pub(crate) fn spawn_route_refresh( let routes = bundle_to_resolved_routes(&bundle); let (user_routes, system_routes) = partition_routes(routes); - info!( - user_route_count = user_routes.len(), - system_route_count = system_routes.len(), - revision = %bundle.revision, - "Inference routes updated" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "updated") + .unmapped("user_route_count", serde_json::json!(user_routes.len())) + .unmapped("system_route_count", serde_json::json!(system_routes.len())) + .unmapped("revision", serde_json::json!(&bundle.revision)) + .message(format!( + "Inference routes updated [user_route_count:{} system_route_count:{} revision:{}]", + user_routes.len(), + system_routes.len(), + bundle.revision + )) + .build()); current_revision = Some(bundle.revision); *user_cache.write().await = user_routes; *system_cache.write().await = system_routes; } Err(e) => { - warn!(error = %e, "Failed to refresh inference route cache, keeping stale routes"); + 
ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "stale") + .unmapped("error", serde_json::json!(e.to_string())) + .message(format!( + "Failed to refresh inference route cache, keeping stale routes [error:{e}]" + )) + .build()); } } } @@ -871,13 +1131,106 @@ pub(crate) fn spawn_route_refresh( /// Minimum read-only paths required for a proxy-mode sandbox child process to /// function: dynamic linker, shared libraries, DNS resolution, CA certs, -/// Python venv, and openshell logs. -const PROXY_BASELINE_READ_ONLY: &[&str] = &["/usr", "/lib", "/etc", "/app", "/var/log"]; +/// Python venv, openshell logs, process info, and random bytes. +/// +/// `/proc` and `/dev/urandom` are included here for the same reasons they +/// appear in `restrictive_default_policy()`: virtually every process needs +/// them. Before the Landlock per-path fix (#677) these were effectively free +/// because a missing path silently disabled the entire ruleset; now they must +/// be explicit. +const PROXY_BASELINE_READ_ONLY: &[&str] = &[ + "/usr", + "/lib", + "/etc", + "/app", + "/var/log", + "/proc", + "/dev/urandom", +]; /// Minimum read-write paths required for a proxy-mode sandbox child process: /// user working directory and temporary files. const PROXY_BASELINE_READ_WRITE: &[&str] = &["/sandbox", "/tmp"]; +/// GPU read-only paths. +/// +/// `/run/nvidia-persistenced`: NVML tries to connect to the persistenced +/// socket at init time. If the directory exists but Landlock denies traversal +/// (EACCES vs ECONNREFUSED), NVML returns `NVML_ERROR_INSUFFICIENT_PERMISSIONS` +/// even though the daemon is optional. Only read/traversal access is needed. +const GPU_BASELINE_READ_ONLY: &[&str] = &["/run/nvidia-persistenced"]; + +/// GPU read-write paths (static). +/// +/// `/dev/nvidiactl`, `/dev/nvidia-uvm`, `/dev/nvidia-uvm-tools`, +/// `/dev/nvidia-modeset`: control and UVM devices injected by CDI. 
+/// Landlock restricts `open(2)` on device files even when DAC allows it; +/// these need read-write because NVML/CUDA opens them with `O_RDWR`. +/// +/// `/proc`: CUDA writes to `/proc//task//comm` during `cuInit()` +/// to set thread names. Without write access, `cuInit()` returns error 304. +/// Must use `/proc` (not `/proc/self/task`) because Landlock rules bind to +/// inodes and child processes have different procfs inodes than the parent. +/// +/// Per-GPU device files (`/dev/nvidia0`, …) are enumerated at runtime by +/// `enumerate_gpu_device_nodes()` since the count varies. +const GPU_BASELINE_READ_WRITE: &[&str] = &[ + "/dev/nvidiactl", + "/dev/nvidia-uvm", + "/dev/nvidia-uvm-tools", + "/dev/nvidia-modeset", + "/proc", +]; + +/// Returns true if GPU devices are present in the container. +fn has_gpu_devices() -> bool { + std::path::Path::new("/dev/nvidiactl").exists() +} + +/// Enumerate per-GPU device nodes (`/dev/nvidia0`, `/dev/nvidia1`, …). +fn enumerate_gpu_device_nodes() -> Vec { + let mut paths = Vec::new(); + if let Ok(entries) = std::fs::read_dir("/dev") { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy(); + if let Some(suffix) = name.strip_prefix("nvidia") { + if suffix.is_empty() || !suffix.chars().all(|c| c.is_ascii_digit()) { + continue; + } + paths.push(entry.path().to_string_lossy().into_owned()); + } + } + } + paths +} + +/// Collect all baseline paths for enrichment: proxy defaults + GPU (if present). +/// Returns `(read_only, read_write)` as owned `String` vecs. 
+fn baseline_enrichment_paths() -> (Vec<String>, Vec<String>) {
+    let mut ro: Vec<String> = PROXY_BASELINE_READ_ONLY
+        .iter()
+        .map(|&s| s.to_string())
+        .collect();
+    let mut rw: Vec<String> = PROXY_BASELINE_READ_WRITE
+        .iter()
+        .map(|&s| s.to_string())
+        .collect();
+
+    if has_gpu_devices() {
+        ro.extend(GPU_BASELINE_READ_ONLY.iter().map(|&s| s.to_string()));
+        rw.extend(GPU_BASELINE_READ_WRITE.iter().map(|&s| s.to_string()));
+        rw.extend(enumerate_gpu_device_nodes());
+    }
+
+    // A path promoted to read_write (e.g. /proc for GPU) should not also
+    // appear in read_only — Landlock handles the overlap correctly but the
+    // duplicate is confusing when inspecting the effective policy.
+    ro.retain(|p| !rw.contains(p));
+
+    (ro, rw)
+}
+
 /// Ensure a proto `SandboxPolicy` includes the baseline filesystem paths
 /// required for proxy-mode sandboxes. Paths are only added if missing;
 /// user-specified paths are never removed.
@@ -896,22 +1249,50 @@ fn enrich_proto_baseline_paths(proto: &mut openshell_core::proto::SandboxPolicy)
             ..Default::default()
         });
+    let (ro, rw) = baseline_enrichment_paths();
+
+    // Baseline paths are system-injected, not user-specified. Skip paths
+    // that do not exist in this container image to avoid noisy warnings from
+    // Landlock and, more critically, to prevent a single missing baseline
+    // path from abandoning the entire Landlock ruleset under best-effort
+    // mode (see issue #664).
let mut modified = false; - for &path in PROXY_BASELINE_READ_ONLY { - if !fs.read_only.iter().any(|p| p.as_str() == path) { - fs.read_only.push(path.to_string()); + for path in &ro { + if !fs.read_only.iter().any(|p| p == path) && !fs.read_write.iter().any(|p| p == path) { + if !std::path::Path::new(path).exists() { + debug!( + path, + "Baseline read-only path does not exist, skipping enrichment" + ); + continue; + } + fs.read_only.push(path.clone()); modified = true; } } - for &path in PROXY_BASELINE_READ_WRITE { - if !fs.read_write.iter().any(|p| p.as_str() == path) { - fs.read_write.push(path.to_string()); + for path in &rw { + if !fs.read_write.iter().any(|p| p == path) { + if !std::path::Path::new(path).exists() { + debug!( + path, + "Baseline read-write path does not exist, skipping enrichment" + ); + continue; + } + fs.read_write.push(path.clone()); modified = true; } } if modified { - info!("Enriched policy with baseline filesystem paths for proxy mode"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "enriched") + .message("Enriched policy with baseline filesystem paths for proxy mode") + .build() + ); } modified @@ -925,24 +1306,116 @@ fn enrich_sandbox_baseline_paths(policy: &mut SandboxPolicy) { return; } + let (ro, rw) = baseline_enrichment_paths(); + let mut modified = false; - for &path in PROXY_BASELINE_READ_ONLY { + for path in &ro { let p = std::path::PathBuf::from(path); - if !policy.filesystem.read_only.contains(&p) { + if !policy.filesystem.read_only.contains(&p) && !policy.filesystem.read_write.contains(&p) { + if !p.exists() { + debug!( + path, + "Baseline read-only path does not exist, skipping enrichment" + ); + continue; + } policy.filesystem.read_only.push(p); modified = true; } } - for &path in PROXY_BASELINE_READ_WRITE { + for path in &rw { let p = std::path::PathBuf::from(path); if !policy.filesystem.read_write.contains(&p) { + if 
!p.exists() { + debug!( + path, + "Baseline read-write path does not exist, skipping enrichment" + ); + continue; + } policy.filesystem.read_write.push(p); modified = true; } } if modified { - info!("Enriched policy with baseline filesystem paths for proxy mode"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "enriched") + .message("Enriched policy with baseline filesystem paths for proxy mode") + .build() + ); + } +} + +#[cfg(test)] +mod baseline_tests { + use super::*; + + #[test] + fn proc_not_in_both_read_only_and_read_write_when_gpu_present() { + // When GPU devices are present, /proc is promoted to read_write + // (CUDA needs to write /proc//task//comm). It should + // NOT also appear in read_only. + if !has_gpu_devices() { + // Can't test GPU dedup without GPU devices; skip silently. + return; + } + let (ro, rw) = baseline_enrichment_paths(); + assert!( + rw.contains(&"/proc".to_string()), + "/proc should be in read_write when GPU is present" + ); + assert!( + !ro.contains(&"/proc".to_string()), + "/proc should NOT be in read_only when it is already in read_write" + ); + } + + #[test] + fn proc_in_read_only_without_gpu() { + if has_gpu_devices() { + // On a GPU host we can't test the non-GPU path; skip silently. + return; + } + let (ro, _rw) = baseline_enrichment_paths(); + assert!( + ro.contains(&"/proc".to_string()), + "/proc should be in read_only when GPU is not present" + ); + } + + #[test] + fn baseline_read_write_always_includes_sandbox_and_tmp() { + let (_ro, rw) = baseline_enrichment_paths(); + assert!(rw.contains(&"/sandbox".to_string())); + assert!(rw.contains(&"/tmp".to_string())); + } + + #[test] + fn enumerate_gpu_device_nodes_skips_bare_nvidia() { + // "nvidia" (without a trailing digit) is a valid /dev entry on some + // systems but is not a per-GPU device node. The enumerator must + // not match it. 
+ let nodes = enumerate_gpu_device_nodes(); + assert!( + !nodes.contains(&"/dev/nvidia".to_string()), + "bare /dev/nvidia should not be enumerated: {nodes:?}" + ); + } + + #[test] + fn no_duplicate_paths_in_baseline() { + let (ro, rw) = baseline_enrichment_paths(); + // No path should appear in both lists. + for path in &ro { + assert!( + !rw.contains(path), + "path {path} appears in both read_only and read_write" + ); + } } } @@ -962,11 +1435,16 @@ async fn load_policy( ) -> Result<(SandboxPolicy, Option>)> { // File mode: load OPA engine from rego rules + YAML data (dev override) if let (Some(policy_file), Some(data_file)) = (&policy_rules, &policy_data) { - info!( - policy_rules = %policy_file, - policy_data = %data_file, - "Loading OPA policy engine from local files" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "loading") + .unmapped("policy_rules", serde_json::json!(policy_file)) + .unmapped("policy_data", serde_json::json!(data_file)) + .message(format!( + "Loading OPA policy engine from local files [rules:{policy_file} data:{data_file}]" + )) + .build()); let engine = OpaEngine::from_files( std::path::Path::new(policy_file), std::path::Path::new(data_file), @@ -1001,7 +1479,14 @@ async fn load_policy( // No policy configured on the server. Discover from disk or // fall back to the restrictive default, then sync to the // gateway so it becomes the authoritative baseline. - info!("Server returned no policy; attempting local discovery"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "discovery") + .message("Server returned no policy; attempting local discovery") + .build() + ); let mut discovered = discover_policy_from_disk_or_default(); // Enrich before syncing so the gateway baseline includes // baseline paths from the start. 
@@ -1063,10 +1548,22 @@ fn discover_policy_from_disk_or_default() -> openshell_core::proto::SandboxPolic } let legacy = std::path::Path::new(openshell_policy::LEGACY_CONTAINER_POLICY_PATH); if legacy.exists() { - info!( - legacy_path = %legacy.display(), - new_path = %primary.display(), - "Policy found at legacy path; consider moving to the new path" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "legacy_path", + serde_json::json!(legacy.display().to_string()) + ) + .unmapped("new_path", serde_json::json!(primary.display().to_string())) + .message(format!( + "Policy found at legacy path; consider moving [legacy_path:{} new_path:{}]", + legacy.display(), + primary.display() + )) + .build() ); return discover_policy_from_path(legacy); } @@ -1082,9 +1579,16 @@ fn discover_policy_from_path(path: &std::path::Path) -> openshell_core::proto::S match std::fs::read_to_string(path) { Ok(yaml) => { - info!( - path = %path.display(), - "Loaded sandbox policy from container disk" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .message(format!( + "Loaded sandbox policy from container disk [path:{}]", + path.display() + )) + .build() ); match parse_sandbox_policy(&yaml) { Ok(policy) => { @@ -1092,29 +1596,56 @@ fn discover_policy_from_path(path: &std::path::Path) -> openshell_core::proto::S if let Err(violations) = validate_sandbox_policy(&policy) { let messages: Vec = violations.iter().map(ToString::to_string).collect(); - warn!( - path = %path.display(), - violations = %messages.join("; "), - "Disk policy contains unsafe content, using restrictive default" - ); + ocsf_emit!(DetectionFindingBuilder::new(ocsf_ctx()) + .activity(ActivityId::Open) + .severity(SeverityId::Medium) + .action(ActionId::Denied) + 
.disposition(DispositionId::Blocked) + .finding_info( + FindingInfo::new( + "unsafe-disk-policy", + "Unsafe Disk Policy Content", + ) + .with_desc(&format!( + "Disk policy at {} contains unsafe content: {}", + path.display(), + messages.join("; "), + )), + ) + .message(format!( + "Disk policy contains unsafe content, using restrictive default [path:{}]", + path.display() + )) + .build()); return restrictive_default_policy(); } policy } Err(e) => { - warn!( - path = %path.display(), - error = %e, - "Failed to parse disk policy, using restrictive default" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "fallback") + .message(format!( + "Failed to parse disk policy, using restrictive default [path:{} error:{e}]", + path.display() + )) + .build()); restrictive_default_policy() } } } Err(_) => { - info!( - path = %path.display(), - "No policy file on disk, using restrictive default" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "default") + .message(format!( + "No policy file on disk, using restrictive default [path:{}]", + path.display() + )) + .build() ); restrictive_default_policy() } @@ -1136,7 +1667,14 @@ fn validate_sandbox_user(policy: &SandboxPolicy) -> Result<()> { if user_name.is_empty() || user_name == "sandbox" { match User::from_name("sandbox") { Ok(Some(_)) => { - info!("Validated 'sandbox' user exists in image"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "validated") + .message("Validated 'sandbox' user exists in image") + .build() + ); } Ok(None) => { return Err(miette::miette!( @@ -1307,9 +1845,11 @@ async fn run_policy_poll_loop( sandbox_id: &str, opa_engine: &Arc, interval_secs: u64, + ocsf_enabled: &std::sync::atomic::AtomicBool, ) -> 
Result<()> { use crate::grpc_client::CachedOpenShellClient; use openshell_core::proto::PolicySource; + use std::sync::atomic::Ordering; let client = CachedOpenShellClient::connect(endpoint).await?; let mut current_config_revision: u64 = 0; @@ -1356,19 +1896,28 @@ async fn run_policy_poll_loop( // Log which settings changed. log_setting_changes(¤t_settings, &result.settings); - info!( - old_config_revision = current_config_revision, - new_config_revision = result.config_revision, - policy_changed, - "Settings poll: config change detected" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "detected") + .unmapped("old_config_revision", serde_json::json!(current_config_revision)) + .unmapped("new_config_revision", serde_json::json!(result.config_revision)) + .unmapped("policy_changed", serde_json::json!(policy_changed)) + .message(format!( + "Settings poll: config change detected [old_revision:{current_config_revision} new_revision:{} policy_changed:{policy_changed}]", + result.config_revision + )) + .build()); // Only reload OPA when the policy payload actually changed. 
if policy_changed { let Some(policy) = result.policy.as_ref() else { - warn!( - "Settings poll: policy hash changed but no policy payload present; skipping reload" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "skipped") + .message("Settings poll: policy hash changed but no policy payload present; skipping reload") + .build()); current_config_revision = result.config_revision; current_policy_hash = result.policy_hash; current_settings = result.settings; @@ -1378,15 +1927,30 @@ async fn run_policy_poll_loop( match opa_engine.reload_from_proto(policy) { Ok(()) => { if result.global_policy_version > 0 { - info!( - policy_hash = %result.policy_hash, - global_version = result.global_policy_version, - "Policy reloaded successfully (global)" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped("policy_hash", serde_json::json!(&result.policy_hash)) + .unmapped("global_version", serde_json::json!(result.global_policy_version)) + .message(format!( + "Policy reloaded successfully (global) [policy_hash:{} global_version:{}]", + result.policy_hash, + result.global_policy_version + )) + .build()); } else { - info!( - policy_hash = %result.policy_hash, - "Policy reloaded successfully" + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped("policy_hash", serde_json::json!(&result.policy_hash)) + .message(format!( + "Policy reloaded successfully [policy_hash:{}]", + result.policy_hash + )) + .build() ); } if result.version > 0 && result.policy_source == PolicySource::Sandbox { @@ -1399,11 +1963,17 @@ async fn run_policy_poll_loop( } } Err(e) => { - warn!( - version = result.version, - error = %e, - "Policy reload failed, keeping last-known-good 
policy" - ); + ocsf_emit!(ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .unmapped("version", serde_json::json!(result.version)) + .unmapped("error", serde_json::json!(e.to_string())) + .message(format!( + "Policy reload failed, keeping last-known-good policy [version:{} error:{e}]", + result.version + )) + .build()); if result.version > 0 && result.policy_source == PolicySource::Sandbox { if let Err(report_err) = client .report_policy_status(sandbox_id, result.version, false, &e.to_string()) @@ -1416,12 +1986,35 @@ async fn run_policy_poll_loop( } } + // Apply OCSF JSON toggle from the `ocsf_json_enabled` setting. + let new_ocsf = extract_bool_setting(&result.settings, "ocsf_json_enabled").unwrap_or(false); + let prev_ocsf = ocsf_enabled.swap(new_ocsf, Ordering::Relaxed); + if new_ocsf != prev_ocsf { + info!(ocsf_json_enabled = new_ocsf, "OCSF JSONL logging toggled"); + } + current_config_revision = result.config_revision; current_policy_hash = result.policy_hash; current_settings = result.settings; } } +/// Extract a bool value from an effective setting, if present. +fn extract_bool_setting( + settings: &std::collections::HashMap, + key: &str, +) -> Option { + use openshell_core::proto::setting_value; + settings + .get(key) + .and_then(|es| es.value.as_ref()) + .and_then(|sv| sv.value.as_ref()) + .and_then(|v| match v { + setting_value::Value::BoolValue(b) => Some(*b), + _ => None, + }) +} + /// Log individual setting changes between two snapshots. 
fn log_setting_changes( old: &std::collections::HashMap, @@ -1433,17 +2026,46 @@ fn log_setting_changes( Some(old_es) => { let old_val = format_setting_value(old_es); if old_val != new_val { - info!(key, old = %old_val, new = %new_val, "Setting changed"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "updated") + .unmapped("key", serde_json::json!(key)) + .unmapped("old", serde_json::json!(old_val.to_string())) + .unmapped("new", serde_json::json!(new_val.to_string())) + .message(format!( + "Setting changed [key:{key} old:{old_val} new:{new_val}]" + )) + .build() + ); } } None => { - info!(key, value = %new_val, "Setting added"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "enabled") + .unmapped("key", serde_json::json!(key)) + .unmapped("value", serde_json::json!(new_val.to_string())) + .message(format!("Setting added [key:{key} value:{new_val}]")) + .build() + ); } } } for key in old.keys() { if !new.contains_key(key) { - info!(key, "Setting removed"); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Disabled, "disabled") + .unmapped("key", serde_json::json!(key)) + .message(format!("Setting removed [key:{key}]")) + .build() + ); } } } @@ -1482,6 +2104,7 @@ mod tests { "openai_responses".to_string(), ], provider_type: "openai".to_string(), + timeout_secs: 0, }, openshell_core::proto::ResolvedRoute { name: "local".to_string(), @@ -1490,6 +2113,7 @@ mod tests { model_id: "llama-3".to_string(), protocols: vec!["openai_chat_completions".to_string()], provider_type: String::new(), + timeout_secs: 120, }, ], revision: "abc123".to_string(), @@ -1510,11 +2134,21 @@ mod tests { routes[0].protocols, vec!["openai_chat_completions", "openai_responses"] ); + assert_eq!( 
+ routes[0].timeout, + openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + "timeout_secs=0 should map to default" + ); assert_eq!(routes[1].endpoint, "http://vllm:8000/v1"); assert_eq!( routes[1].auth, openshell_core::inference::AuthHeader::Bearer ); + assert_eq!( + routes[1].timeout, + Duration::from_secs(120), + "timeout_secs=120 should map to 120s" + ); } #[test] @@ -1539,6 +2173,7 @@ mod tests { model_id: "model".to_string(), protocols: vec!["openai_chat_completions".to_string()], provider_type: "openai".to_string(), + timeout_secs: 0, }], revision: "rev".to_string(), generated_at_ms: 0, @@ -1559,6 +2194,7 @@ mod tests { protocols: vec!["openai_chat_completions".to_string()], auth: openshell_core::inference::AuthHeader::Bearer, default_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }, openshell_router::config::ResolvedRoute { name: "sandbox-system".to_string(), @@ -1568,6 +2204,7 @@ mod tests { protocols: vec!["anthropic_messages".to_string()], auth: openshell_core::inference::AuthHeader::Custom("x-api-key"), default_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }, ]; @@ -1856,6 +2493,7 @@ filesystem_policy: auth: openshell_core::inference::AuthHeader::Bearer, protocols: vec!["openai_chat_completions".to_string()], default_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; let cache = Arc::new(RwLock::new(routes)); diff --git a/crates/openshell-sandbox/src/log_push.rs b/crates/openshell-sandbox/src/log_push.rs index 21c272a24..17f9bcc3d 100644 --- a/crates/openshell-sandbox/src/log_push.rs +++ b/crates/openshell-sandbox/src/log_push.rs @@ -48,16 +48,37 @@ impl Layer for LogPushLayer { if *meta.level() > self.max_level { return; } - let mut visitor = LogVisitor::default(); - event.record(&mut visitor); - let (msg, fields) = visitor.into_parts(meta.name()); + // OCSF events carry their payload in a thread-local; extract the + // shorthand representation for the push message. 
Non-OCSF events + // use the original visitor-based extraction. + let (msg, fields) = if meta.target() == openshell_ocsf::OCSF_TARGET { + if let Some(ocsf_event) = openshell_ocsf::clone_current_event() { + ( + ocsf_event.format_shorthand(), + std::collections::HashMap::new(), + ) + } else { + return; + } + } else { + let mut visitor = LogVisitor::default(); + event.record(&mut visitor); + visitor.into_parts(meta.name()) + }; + let ts = current_time_ms().unwrap_or(0); + let is_ocsf = meta.target() == openshell_ocsf::OCSF_TARGET; + let log = SandboxLogLine { sandbox_id: self.sandbox_id.clone(), timestamp_ms: ts, - level: meta.level().to_string(), + level: if is_ocsf { + "OCSF".to_string() + } else { + meta.level().to_string() + }, target: meta.target().to_string(), message: msg, source: "sandbox".to_string(), diff --git a/crates/openshell-sandbox/src/main.rs b/crates/openshell-sandbox/src/main.rs index cdf5f6ff7..a37dce0e4 100644 --- a/crates/openshell-sandbox/src/main.rs +++ b/crates/openshell-sandbox/src/main.rs @@ -3,10 +3,15 @@ //! OpenShell Sandbox - process sandbox and monitor. +use std::sync::Arc; +use std::sync::atomic::AtomicBool; + use clap::Parser; use miette::Result; +use openshell_ocsf::{OcsfJsonlLayer, OcsfShorthandLayer}; use tracing::{info, warn}; use tracing_subscriber::EnvFilter; +use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::{Layer, layer::SubscriberExt, util::SubscriberInitExt}; use openshell_sandbox::run_sandbox; @@ -130,37 +135,63 @@ async fn main() -> Result<()> { let push_layer = log_push_state.as_ref().map(|(layer, _)| layer.clone()); let _log_push_handle = log_push_state.map(|(_, handle)| handle); - // Keep the file guard alive for the entire process. When the guard is - // dropped the non-blocking writer flushes remaining logs. - let _file_guard = if let Some((file_writer, file_guard)) = file_logging { + // Shared flag: the sandbox poll loop toggles this when the + // `ocsf_json_enabled` setting changes. 
The JSONL layer checks it + // on each event and short-circuits when false. + let ocsf_enabled = Arc::new(AtomicBool::new(false)); + + // Keep guards alive for the entire process. When a guard is dropped the + // non-blocking writer flushes remaining logs. + let (_file_guard, _jsonl_guard) = if let Some((file_writer, file_guard)) = file_logging { let file_filter = EnvFilter::new("info"); + + // OCSF JSONL file: rolling appender matching the main log file + // (daily rotation, 3 files max). Created eagerly but gated by the + // enabled flag — no JSONL is written until ocsf_json_enabled is set. + let jsonl_logging = tracing_appender::rolling::RollingFileAppender::builder() + .rotation(tracing_appender::rolling::Rotation::DAILY) + .filename_prefix("openshell-ocsf") + .filename_suffix("log") + .max_log_files(3) + .build("/var/log") + .ok() + .map(|roller| { + let (writer, guard) = tracing_appender::non_blocking(roller); + let layer = OcsfJsonlLayer::new(writer).with_enabled_flag(ocsf_enabled.clone()); + (layer, guard) + }); + let (jsonl_layer, jsonl_guard) = match jsonl_logging { + Some((layer, guard)) => (Some(layer), Some(guard)), + None => (None, None), + }; + tracing_subscriber::registry() .with( - tracing_subscriber::fmt::layer() - .with_writer(std::io::stdout) + OcsfShorthandLayer::new(std::io::stdout()) + .with_non_ocsf(true) .with_filter(stdout_filter), ) .with( - tracing_subscriber::fmt::layer() - .with_writer(file_writer) - .with_ansi(false) + OcsfShorthandLayer::new(file_writer) + .with_non_ocsf(true) .with_filter(file_filter), ) + .with(jsonl_layer.with_filter(LevelFilter::INFO)) .with(push_layer.clone()) .init(); - Some(file_guard) + (Some(file_guard), jsonl_guard) } else { tracing_subscriber::registry() .with( - tracing_subscriber::fmt::layer() - .with_writer(std::io::stdout) + OcsfShorthandLayer::new(std::io::stdout()) + .with_non_ocsf(true) .with_filter(stdout_filter), ) .with(push_layer) .init(); // Log the warning after the subscriber is initialized 
warn!("Could not open /var/log for log rotation; using stdout-only logging"); - None + (None, None) }; // Get command - either from CLI args, environment variable, or default to /bin/bash @@ -174,6 +205,9 @@ async fn main() -> Result<()> { }; info!(command = ?command, "Starting sandbox"); + // Note: "Starting sandbox" stays as plain info!() since the OCSF context + // is not yet initialized at this point (run_sandbox hasn't been called). + // The shorthand layer will render it in fallback format. let exit_code = run_sandbox( command, @@ -191,6 +225,7 @@ async fn main() -> Result<()> { args.health_check, args.health_port, args.inference_routes, + ocsf_enabled, ) .await?; diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index e5ae64977..95800854a 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -337,6 +337,7 @@ fn build_l7_rules(samples: &HashMap<(String, String), u32>) -> Vec { method: method.clone(), path: generalised, command: String::new(), + query: HashMap::new(), }), }); } @@ -448,13 +449,27 @@ async fn resolve_allowed_ips_if_private(host: &str, port: u32) -> Vec { let addrs = match tokio::net::lookup_host(&addr).await { Ok(addrs) => addrs.collect::>(), Err(e) => { - tracing::warn!(host, port, error = %e, "DNS resolution failed for allowed_ips check"); + let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Fail) + .severity(openshell_ocsf::SeverityId::Low) + .dst_endpoint(openshell_ocsf::Endpoint::from_domain(host, port as u16)) + .message(format!("DNS resolution failed for allowed_ips check: {e}")) + .build(); + openshell_ocsf::ocsf_emit!(event); return Vec::new(); } }; if addrs.is_empty() { - tracing::warn!(host, port, "DNS resolution returned no addresses"); + let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) + 
.activity(openshell_ocsf::ActivityId::Fail) + .severity(openshell_ocsf::SeverityId::Low) + .dst_endpoint(openshell_ocsf::Endpoint::from_domain(host, port as u16)) + .message(format!( + "DNS resolution returned no addresses for {host}:{port}" + )) + .build(); + openshell_ocsf::ocsf_emit!(event); return Vec::new(); } diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index cd2931b35..970c9226c 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -121,7 +121,15 @@ impl OpaEngine { // Validate BEFORE expanding presets let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { - tracing::warn!(warning = %w, "L7 policy validation warning"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "validated") + .unmapped("warning", serde_json::json!(w.to_string())) + .message(format!("L7 policy validation warning: {w}")) + .build() + ); } if !errors.is_empty() { return Err(miette::miette!( @@ -511,7 +519,7 @@ fn parse_process_policy(val: ®orus::Value) -> ProcessPolicy { /// Preprocess YAML policy data: parse, normalize, validate, expand access presets, return JSON. fn preprocess_yaml_data(yaml_str: &str) -> Result { - let mut data: serde_json::Value = serde_yaml::from_str(yaml_str) + let mut data: serde_json::Value = serde_yml::from_str(yaml_str) .map_err(|e| miette::miette!("failed to parse YAML data: {e}"))?; // Normalize port → ports for all endpoints so Rego always sees "ports" array. 
@@ -520,7 +528,15 @@ fn preprocess_yaml_data(yaml_str: &str) -> Result { // Validate BEFORE expanding presets (catches user errors like rules+access) let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { - tracing::warn!(warning = %w, "L7 policy validation warning"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "validated") + .unmapped("warning", serde_json::json!(w.to_string())) + .message(format!("L7 policy validation warning: {w}")) + .build() + ); } if !errors.is_empty() { return Err(miette::miette!( @@ -667,13 +683,35 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy) -> String { .iter() .map(|r| { let a = r.allow.as_ref(); - serde_json::json!({ - "allow": { - "method": a.map_or("", |a| &a.method), - "path": a.map_or("", |a| &a.path), - "command": a.map_or("", |a| &a.command), - } - }) + let mut allow = serde_json::json!({ + "method": a.map_or("", |a| &a.method), + "path": a.map_or("", |a| &a.path), + "command": a.map_or("", |a| &a.command), + }); + let query: serde_json::Map = a + .map(|allow| { + allow + .query + .iter() + .map(|(key, matcher)| { + let mut matcher_json = serde_json::json!({}); + if !matcher.glob.is_empty() { + matcher_json["glob"] = + matcher.glob.clone().into(); + } + if !matcher.any.is_empty() { + matcher_json["any"] = + matcher.any.clone().into(); + } + (key.clone(), matcher_json) + }) + .collect() + }) + .unwrap_or_default(); + if !query.is_empty() { + allow["query"] = query.into(); + } + serde_json::json!({ "allow": allow }) }) .collect(); ep["rules"] = rules.into(); @@ -714,8 +752,9 @@ mod tests { use super::*; use openshell_core::proto::{ - FilesystemPolicy as ProtoFs, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, - ProcessPolicy as ProtoProc, SandboxPolicy as ProtoSandboxPolicy, + FilesystemPolicy as 
ProtoFs, L7Allow, L7QueryMatcher, L7Rule, NetworkBinary, + NetworkEndpoint, NetworkPolicyRule, ProcessPolicy as ProtoProc, + SandboxPolicy as ProtoSandboxPolicy, }; const TEST_POLICY: &str = include_str!("../data/sandbox-policy.rego"); @@ -1337,6 +1376,27 @@ network_policies: access: full binaries: - { path: /usr/bin/curl } + query_api: + name: query_api + endpoints: + - host: api.query.com + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/download" + query: + tag: "foo-*" + - allow: + method: GET + path: "/search" + query: + tag: + any: ["foo-*", "bar-*"] + binaries: + - { path: /usr/bin/curl } l4_only: name: l4_only endpoints: @@ -1359,6 +1419,16 @@ process: } fn l7_input(host: &str, port: u16, method: &str, path: &str) -> serde_json::Value { + l7_input_with_query(host, port, method, path, serde_json::json!({})) + } + + fn l7_input_with_query( + host: &str, + port: u16, + method: &str, + path: &str, + query_params: serde_json::Value, + ) -> serde_json::Value { serde_json::json!({ "network": { "host": host, "port": port }, "exec": { @@ -1368,7 +1438,8 @@ process: }, "request": { "method": method, - "path": path + "path": path, + "query_params": query_params } }) } @@ -1472,6 +1543,140 @@ process: assert!(eval_l7(&engine, &input)); } + #[test] + fn l7_query_glob_allows_matching_duplicate_values() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/download", + serde_json::json!({ + "tag": ["foo-a", "foo-b"], + "extra": ["ignored"], + }), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_glob_denies_on_mismatched_duplicate_value() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/download", + serde_json::json!({ + "tag": ["foo-a", "evil"], + }), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_any_allows_if_every_value_matches_any_pattern() { + let engine = 
l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/search", + serde_json::json!({ + "tag": ["foo-a", "bar-b"], + }), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_missing_required_key_denied() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/download", + serde_json::json!({}), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_rules_from_proto_are_enforced() { + let mut query = std::collections::HashMap::new(); + query.insert( + "tag".to_string(), + L7QueryMatcher { + glob: "foo-*".to_string(), + any: vec![], + }, + ); + + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "query_proto".to_string(), + NetworkPolicyRule { + name: "query_proto".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.proto.com".to_string(), + port: 8080, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "GET".to_string(), + path: "/download".to_string(), + command: String::new(), + query, + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let allow_input = l7_input_with_query( + "api.proto.com", + 8080, + "GET", + "/download", + serde_json::json!({ "tag": ["foo-a"] }), + ); + assert!(eval_l7(&engine, &allow_input)); + + let deny_input = l7_input_with_query( + 
"api.proto.com", + 8080, + "GET", + "/download", + serde_json::json!({ "tag": ["evil"] }), + ); + assert!(!eval_l7(&engine, &deny_input)); + } + #[test] fn l7_no_request_on_l4_only_endpoint() { // L4-only endpoint should not match L7 allow_request diff --git a/crates/openshell-sandbox/src/process.rs b/crates/openshell-sandbox/src/process.rs index b93d125ab..b29682cf0 100644 --- a/crates/openshell-sandbox/src/process.rs +++ b/crates/openshell-sandbox/src/process.rs @@ -20,7 +20,7 @@ use std::os::unix::io::RawFd; use std::path::PathBuf; use std::process::Stdio; use tokio::process::{Child, Command}; -use tracing::{debug, warn}; +use tracing::debug; const SSH_HANDSHAKE_SECRET_ENV: &str = "OPENSHELL_SSH_HANDSHAKE_SECRET"; @@ -325,7 +325,14 @@ impl ProcessHandle { pub fn kill(&mut self) -> Result<()> { // First try SIGTERM if let Err(e) = self.signal(Signal::SIGTERM) { - warn!(error = %e, "Failed to send SIGTERM"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ProcessActivityBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Close) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Failure) + .message(format!("Failed to send SIGTERM: {e}")) + .build() + ); } // Give the process a moment to terminate gracefully diff --git a/crates/openshell-sandbox/src/procfs.rs b/crates/openshell-sandbox/src/procfs.rs index ece16c82a..785a9489e 100644 --- a/crates/openshell-sandbox/src/procfs.rs +++ b/crates/openshell-sandbox/src/procfs.rs @@ -6,10 +6,11 @@ //! Provides functions to resolve binary paths and compute file hashes //! for process-identity binding in the OPA proxy policy engine. -use miette::{IntoDiagnostic, Result}; +use miette::Result; use std::path::Path; #[cfg(target_os = "linux")] use std::path::PathBuf; +use tracing::debug; /// Read the binary path of a process via `/proc/{pid}/exe` symlink. 
/// @@ -229,8 +230,9 @@ fn parse_proc_net_tcp(pid: u32, peer_port: u16) -> Result { fn find_pid_by_socket_inode(inode: u64, entrypoint_pid: u32) -> Result { let target = format!("socket:[{inode}]"); - // First: scan descendants of the entrypoint process (targeted, most likely to succeed) + // First: scan descendants of the entrypoint process let descendants = collect_descendant_pids(entrypoint_pid); + for &pid in &descendants { if let Some(found) = check_pid_fds(pid, &target) { return Ok(found); @@ -238,7 +240,6 @@ fn find_pid_by_socket_inode(inode: u64, entrypoint_pid: u32) -> Result { } // Fallback: scan all of /proc in case the process isn't in the tree - // (e.g., if /proc//task//children wasn't available) if let Ok(proc_dir) = std::fs::read_dir("/proc") { for entry in proc_dir.flatten() { let name = entry.file_name(); @@ -318,9 +319,32 @@ fn collect_descendant_pids(root_pid: u32) -> Vec { /// same hash, or the request is denied. pub fn file_sha256(path: &Path) -> Result { use sha2::{Digest, Sha256}; + use std::io::Read; + + let start = std::time::Instant::now(); + let mut file = std::fs::File::open(path) + .map_err(|e| miette::miette!("Failed to open {}: {e}", path.display()))?; + let mut hasher = Sha256::new(); + let mut buf = [0u8; 65536]; + let mut total_read = 0u64; + loop { + let n = file + .read(&mut buf) + .map_err(|e| miette::miette!("Failed to read {}: {e}", path.display()))?; + if n == 0 { + break; + } + total_read += n as u64; + hasher.update(&buf[..n]); + } - let bytes = std::fs::read(path).into_diagnostic()?; - let hash = Sha256::digest(&bytes); + let hash = hasher.finalize(); + debug!( + " file_sha256: {}ms size={} path={}", + start.elapsed().as_millis(), + total_read, + path.display() + ); Ok(hex::encode(hash)) } diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 1f38a2cca..b52cc60b9 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -10,6 +10,10 @@ use 
crate::opa::{NetworkAction, OpaEngine}; use crate::policy::ProxyPolicy; use crate::secrets::{SecretResolver, rewrite_header_line}; use miette::{IntoDiagnostic, Result}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, + NetworkActivityBuilder, Process, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, +}; use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; @@ -18,11 +22,17 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio::task::JoinHandle; -use tracing::{debug, info, warn}; +use tracing::{debug, warn}; const MAX_HEADER_BYTES: usize = 8192; const INFERENCE_LOCAL_HOST: &str = "inference.local"; +/// Maximum total bytes for a streaming inference response body (32 MiB). +const MAX_STREAMING_BODY: usize = 32 * 1024 * 1024; + +/// Idle timeout per chunk when relaying streaming inference responses. +const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); + /// Result of a proxy CONNECT policy decision. 
struct ConnectDecision { action: NetworkAction, @@ -150,7 +160,16 @@ impl ProxyHandle { let listener = TcpListener::bind(http_addr).await.into_diagnostic()?; let local_addr = listener.local_addr().into_diagnostic()?; - info!(addr = %local_addr, "Proxy listening (tcp)"); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Listen) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_ip(local_addr.ip(), local_addr.port())) + .message(format!("Proxy listening on {local_addr}")) + .build(); + ocsf_emit!(event); + } let join = tokio::spawn(async move { loop { @@ -169,12 +188,24 @@ impl ProxyHandle { ) .await { - warn!(error = %err, "Proxy connection error"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("Proxy connection error: {err}")) + .build(); + ocsf_emit!(event); } }); } Err(err) => { - warn!(error = %err, "Proxy accept error"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("Proxy accept error: {err}")) + .build(); + ocsf_emit!(event); break; } } @@ -328,23 +359,43 @@ async fn handle_tcp_connection( ) .await?; if let InferenceOutcome::Denied { reason } = outcome { - info!(action = "deny", reason = %reason, host = INFERENCE_LOCAL_HOST, "Inference interception denied"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, port)) + .message(format!("Inference interception denied: {reason}")) + .status_detail(&reason) + .build(); + ocsf_emit!(event); } return Ok(()); } let peer_addr = client.peer_addr().into_diagnostic()?; - 
let local_addr = client.local_addr().into_diagnostic()?; - - // Evaluate OPA policy with process-identity binding - let decision = evaluate_opa_tcp( - peer_addr, - &opa_engine, - &identity_cache, - &entrypoint_pid, - &host_lc, - port, - ); + let _local_addr = client.local_addr().into_diagnostic()?; + + // Evaluate OPA policy with process-identity binding. + // Wrapped in spawn_blocking because identity resolution does heavy sync I/O: + // /proc scanning + SHA256 hashing of binaries (e.g. node at 124MB). + let opa_clone = opa_engine.clone(); + let cache_clone = identity_cache.clone(); + let pid_clone = entrypoint_pid.clone(); + let host_clone = host_lc.clone(); + let decision = tokio::task::spawn_blocking(move || { + evaluate_opa_tcp( + peer_addr, + &opa_clone, + &cache_clone, + &pid_clone, + &host_clone, + port, + ) + }) + .await + .map_err(|e| miette::miette!("identity resolution task panicked: {e}"))?; // Extract action string and matched policy for logging let (matched_policy, deny_reason) = match &decision.action { @@ -386,22 +437,23 @@ async fn handle_tcp_connection( // Allowed connections are logged after the L7 config check (below) // so we can distinguish CONNECT (L4-only) from CONNECT_L7 (L7 follows). if matches!(decision.action, NetworkAction::Deny { .. 
}) { - info!( - src_addr = %peer_addr.ip(), - src_port = peer_addr.port(), - proxy_addr = %local_addr, - dst_host = %host_lc, - dst_port = port, - binary = %binary_str, - binary_pid = %pid_str, - ancestors = %ancestors_str, - cmdline = %cmdline_str, - action = "deny", - engine = "opa", - policy = "-", - reason = %deny_reason, - "CONNECT", - ); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "opa") + .message(format!("CONNECT denied {host_lc}:{port}")) + .status_detail(&deny_reason) + .build(); + ocsf_emit!(event); emit_denial( &denial_tx, &host_lc, @@ -426,6 +478,7 @@ async fn handle_tcp_connection( } // Defense-in-depth: resolve DNS and reject connections to internal IPs. + let dns_connect_start = std::time::Instant::now(); let mut upstream = if !raw_allowed_ips.is_empty() { // allowed_ips mode: validate resolved IPs against CIDR allowlist. // Loopback and link-local are still always blocked. 
@@ -435,12 +488,27 @@ async fn handle_tcp_connection( .await .into_diagnostic()?, Err(reason) => { - warn!( - dst_host = %host_lc, - dst_port = port, - reason = %reason, - "CONNECT blocked: allowed_ips check failed" - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: allowed_ips check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } emit_denial( &denial_tx, &host_lc, @@ -455,12 +523,27 @@ async fn handle_tcp_connection( } }, Err(reason) => { - warn!( - dst_host = %host_lc, - dst_port = port, - reason = %reason, - "CONNECT blocked: invalid allowed_ips in policy" - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: invalid allowed_ips in policy for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } emit_denial( &denial_tx, &host_lc, @@ -481,12 +564,27 @@ async fn handle_tcp_connection( .await .into_diagnostic()?, Err(reason) => { - warn!( - dst_host = %host_lc, - dst_port = port, - reason = %reason, - "CONNECT blocked: internal address" - ); + { + let event = 
NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: internal address {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } emit_denial( &denial_tx, &host_lc, @@ -502,6 +600,11 @@ async fn handle_tcp_connection( } }; + debug!( + "handle_tcp_connection dns_resolve_and_tcp_connect: {}ms host={host_lc}", + dns_connect_start.elapsed().as_millis() + ); + respond(&mut client, b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; // Check if endpoint has L7 config for protocol-aware inspection @@ -514,22 +617,24 @@ async fn handle_tcp_connection( } else { "CONNECT" }; - info!( - src_addr = %peer_addr.ip(), - src_port = peer_addr.port(), - proxy_addr = %local_addr, - dst_host = %host_lc, - dst_port = port, - binary = %binary_str, - binary_pid = %pid_str, - ancestors = %ancestors_str, - cmdline = %cmdline_str, - action = "allow", - engine = "opa", - policy = %policy_str, - reason = "", - "{connect_msg}", - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "opa") + .message(format!("{connect_msg} allowed {host_lc}:{port}")) + .build(); + ocsf_emit!(event); + } // 
Determine effective TLS mode. Check the raw endpoint config for // `tls: skip` independently of L7 config (which requires `protocol`). @@ -594,11 +699,19 @@ async fn handle_tcp_connection( if let Some(ref l7_config) = l7_config { // L7 inspection on terminated TLS traffic. - let tunnel_engine = - opa_engine.clone_engine_for_tunnel().unwrap_or_else(|e| { - warn!(error = %e, "Failed to clone OPA engine for L7, falling back to relay-only"); - regorus::Engine::new() - }); + let tunnel_engine = opa_engine.clone_engine_for_tunnel().unwrap_or_else(|e| { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "Failed to clone OPA engine for L7, falling back to relay-only: {e}" + )) + .build(); + ocsf_emit!(event); + regorus::Engine::new() + }); crate::l7::relay::relay_with_inspection( l7_config, std::sync::Mutex::new(tunnel_engine), @@ -626,20 +739,29 @@ async fn handle_tcp_connection( "TLS connection closed" ); } else { - warn!( - host = %host_lc, - port = port, - error = %e, - "TLS relay error" - ); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("TLS relay error: {e}")) + .build(); + ocsf_emit!(event); } } } else { - warn!( - host = %host_lc, - port = port, - "TLS detected but TLS state not configured, falling back to raw tunnel" - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "TLS detected but TLS state not configured for {host_lc}:{port}, falling back to raw tunnel" + )) + .build(); + ocsf_emit!(event); + } let _ = 
tokio::io::copy_bidirectional(&mut client, &mut upstream) .await .into_diagnostic()?; @@ -648,7 +770,16 @@ async fn handle_tcp_connection( // Plaintext HTTP detected. if let Some(ref l7_config) = l7_config { let tunnel_engine = opa_engine.clone_engine_for_tunnel().unwrap_or_else(|e| { - warn!(error = %e, "Failed to clone OPA engine for L7, falling back to relay-only"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "Failed to clone OPA engine for L7, falling back to relay-only: {e}" + )) + .build(); + ocsf_emit!(event); regorus::Engine::new() }); if let Err(e) = crate::l7::relay::relay_with_inspection( @@ -663,7 +794,14 @@ async fn handle_tcp_connection( if is_benign_relay_error(&e) { debug!(host = %host_lc, port = port, error = %e, "L7 connection closed"); } else { - warn!(host = %host_lc, port = port, error = %e, "L7 relay error"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("L7 relay error: {e}")) + .build(); + ocsf_emit!(event); } } } else { @@ -678,7 +816,14 @@ async fn handle_tcp_connection( if is_benign_relay_error(&e) { debug!(host = %host_lc, port = port, error = %e, "HTTP relay closed"); } else { - warn!(host = %host_lc, port = port, error = %e, "HTTP relay error"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("HTTP relay error: {e}")) + .build(); + ocsf_emit!(event); } } } @@ -736,7 +881,9 @@ fn evaluate_opa_tcp( ); } + let total_start = std::time::Instant::now(); let peer_port = peer_addr.port(); + let (bin_path, binary_pid) = 
match crate::procfs::resolve_tcp_peer_identity(pid, peer_port) { Ok(r) => r, Err(e) => { @@ -767,7 +914,6 @@ fn evaluate_opa_tcp( // Walk the process tree upward to collect ancestor binaries let ancestors = crate::procfs::collect_ancestor_binaries(binary_pid, pid); - // TOFU verify each ancestor binary for ancestor in &ancestors { if let Err(e) = identity_cache.verify_or_cache(ancestor) { return deny( @@ -784,7 +930,6 @@ fn evaluate_opa_tcp( } // Collect cmdline paths for script-based binary detection. - // Excludes exe paths already captured in bin_path/ancestors to avoid duplicates. let mut exclude = ancestors.clone(); exclude.push(bin_path.clone()); let cmdline_paths = crate::procfs::collect_cmdline_paths(binary_pid, pid, &exclude); @@ -798,7 +943,7 @@ fn evaluate_opa_tcp( cmdline_paths: cmdline_paths.clone(), }; - match engine.evaluate_network_action(&input) { + let result = match engine.evaluate_network_action(&input) { Ok(action) => ConnectDecision { action, binary: Some(bin_path), @@ -813,7 +958,12 @@ fn evaluate_opa_tcp( ancestors, cmdline_paths, ), - } + }; + debug!( + "evaluate_opa_tcp TOTAL: {}ms host={host} port={port}", + total_start.elapsed().as_millis() + ); + result } /// Non-Linux stub: OPA identity binding requires /proc. 
@@ -854,7 +1004,7 @@ const INITIAL_INFERENCE_BUF: usize = 65536; async fn handle_inference_interception( client: TcpStream, host: &str, - _port: u16, + port: u16, tls_state: Option<&Arc>, inference_ctx: Option<&Arc>, ) -> Result { @@ -943,6 +1093,24 @@ async fn handle_inference_interception( buf.resize((buf.len() * 2).min(MAX_INFERENCE_BUF), 0); } } + ParseResult::Invalid(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Rejected) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, port)) + .message(format!("Rejecting malformed inference request: {reason}")) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + let response = format_http_response(400, &[], b"Bad Request"); + write_all(&mut tls_client, &response).await?; + return Ok(InferenceOutcome::Denied { reason }); + } } } @@ -965,13 +1133,21 @@ async fn route_inference_request( if let Some(pattern) = detect_inference_pattern(&request.method, &normalized_path, &ctx.patterns) { - info!( - method = %request.method, - path = %normalized_path, - protocol = %pattern.protocol, - kind = %pattern.kind, - "Intercepted inference request, routing locally" - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Allowed) + .disposition(DispositionId::Detected) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "Intercepted inference request, routing locally: {} {} (protocol={}, kind={})", + request.method, normalized_path, pattern.protocol, pattern.kind + )) + .build(); + ocsf_emit!(event); + } // Strip credential + framing/hop-by-hop headers. 
let filtered_headers = sanitize_inference_request_headers(&request.headers); @@ -1018,16 +1194,44 @@ async fn route_inference_request( let header_bytes = format_http_response_header(resp.status, &resp_headers); write_all(tls_client, &header_bytes).await?; - // Stream body chunks as they arrive from the upstream. + // Stream body chunks with byte cap and idle timeout. + let mut total_bytes: usize = 0; loop { - match resp.next_chunk().await { - Ok(Some(chunk)) => { + match tokio::time::timeout(CHUNK_IDLE_TIMEOUT, resp.next_chunk()).await { + Ok(Ok(Some(chunk))) => { + total_bytes += chunk.len(); + if total_bytes > MAX_STREAMING_BODY { + warn!( + total_bytes = total_bytes, + limit = MAX_STREAMING_BODY, + "streaming response exceeded byte limit, truncating" + ); + break; + } let encoded = format_chunk(&chunk); write_all(tls_client, &encoded).await?; } - Ok(None) => break, - Err(e) => { - warn!(error = %e, "error reading upstream response chunk"); + Ok(Ok(None)) => break, + Ok(Err(e)) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!("error reading upstream response chunk: {e}")) + .build(); + ocsf_emit!(event); + break; + } + Err(_) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message("streaming response chunk idle timeout, closing") + .build(); + ocsf_emit!(event); break; } } @@ -1037,7 +1241,18 @@ async fn route_inference_request( write_all(tls_client, format_chunk_terminator()).await?; } Err(e) => { - warn!(error = %e, "inference endpoint detected but upstream service failed"); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + 
.status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "inference endpoint detected but upstream service failed: {e}" + )) + .build(); + ocsf_emit!(event); + } let (status, msg) = router_error_to_http(&e); let body = serde_json::json!({"error": msg}); let body_bytes = body.to_string(); @@ -1052,11 +1267,21 @@ async fn route_inference_request( Ok(true) } else { // Not an inference request — deny - info!( - method = %request.method, - path = %normalized_path, - "connection not allowed by policy" - ); + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "connection not allowed by policy: {} {}", + request.method, normalized_path + )) + .build(); + ocsf_emit!(event); + } let body = serde_json::json!({"error": "connection not allowed by policy"}); let body_bytes = body.to_string(); let response = format_http_response( @@ -1172,7 +1397,14 @@ fn query_l7_config( Ok(Some(val)) => crate::l7::parse_l7_config(&val), Ok(None) => None, Err(e) => { - warn!(error = %e, "Failed to query L7 endpoint config"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!("Failed to query L7 endpoint config: {e}")) + .build(); + ocsf_emit!(event); None } } @@ -1407,12 +1639,16 @@ fn parse_allowed_ips(raw: &[String]) -> std::result::Result, S match parsed { Ok(n) => { if n.prefix_len() < MIN_SAFE_PREFIX_LEN { - warn!( - cidr = %n, - prefix_len = n.prefix_len(), - "allowed_ips entry has a very broad CIDR (< /{MIN_SAFE_PREFIX_LEN}); \ - this may expose control-plane services on the same network" - ); + let event = 
NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Medium) + .message(format!( + "allowed_ips entry has a very broad CIDR {n} (/{}) < /{MIN_SAFE_PREFIX_LEN}; \ + this may expose control-plane services on the same network", + n.prefix_len() + )) + .build(); + ocsf_emit!(event); } nets.push(n); } @@ -1455,7 +1691,16 @@ fn query_allowed_ips( match engine.query_allowed_ips(&input) { Ok(ips) => ips, Err(e) => { - warn!(error = %e, "Failed to query allowed_ips from endpoint config"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!( + "Failed to query allowed_ips from endpoint config: {e}" + )) + .build(); + ocsf_emit!(event); vec![] } } @@ -1590,23 +1835,14 @@ fn rewrite_forward_request( used: usize, path: &str, secret_resolver: Option<&SecretResolver>, -) -> Vec { +) -> Result, crate::secrets::UnresolvedPlaceholderError> { let header_end = raw[..used] .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); let header_str = String::from_utf8_lossy(&raw[..header_end]); - let mut lines = header_str.split("\r\n").collect::>(); - - // Rewrite request line: METHOD absolute-uri HTTP/1.1 → METHOD path HTTP/1.1 - if let Some(first_line) = lines.first_mut() { - let parts: Vec<&str> = first_line.splitn(3, ' ').collect(); - if parts.len() == 3 { - let new_line = format!("{} {} {}", parts[0], path, parts[2]); - *first_line = Box::leak(new_line.into_boxed_str()); // safe: short-lived - } - } + let lines = header_str.split("\r\n").collect::>(); // Rebuild headers, stripping hop-by-hop and adding proxy headers let mut output = Vec::with_capacity(header_end + 128); @@ -1615,8 +1851,17 @@ fn rewrite_forward_request( for (i, line) in lines.iter().enumerate() { if i == 0 { - // Request line — already rewritten - output.extend_from_slice(line.as_bytes()); + 
// Rewrite request line: METHOD absolute-uri HTTP/1.1 → METHOD path HTTP/1.1 + let parts: Vec<&str> = line.splitn(3, ' ').collect(); + if parts.len() == 3 { + output.extend_from_slice(parts[0].as_bytes()); + output.push(b' '); + output.extend_from_slice(path.as_bytes()); + output.push(b' '); + output.extend_from_slice(parts[2].as_bytes()); + } else { + output.extend_from_slice(line.as_bytes()); + } output.extend_from_slice(b"\r\n"); continue; } @@ -1671,7 +1916,15 @@ fn rewrite_forward_request( output.extend_from_slice(&raw[header_end..used]); } - output + // Fail-closed: scan for any remaining unresolved placeholders + if secret_resolver.is_some() { + let output_str = String::from_utf8_lossy(&output); + if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) { + return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); + } + } + + Ok(output) } /// Handle a plain HTTP forward proxy request (non-CONNECT). @@ -1696,7 +1949,13 @@ async fn handle_forward_proxy( let (scheme, host, port, path) = match parse_proxy_uri(target_uri) { Ok(parsed) => parsed, Err(e) => { - warn!(target_uri = %target_uri, error = %e, "FORWARD parse error"); + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("FORWARD parse error for {target_uri}: {e}")) + .build(); + ocsf_emit!(event); respond(client, b"HTTP/1.1 400 Bad Request\r\n\r\n").await?; return Ok(()); } @@ -1705,11 +1964,20 @@ async fn handle_forward_proxy( // 2. 
Reject HTTPS — must use CONNECT for TLS if scheme == "https" { - info!( - dst_host = %host_lc, - dst_port = port, - "FORWARD rejected: HTTPS requires CONNECT" - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Rejected) + .severity(SeverityId::Informational) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "FORWARD rejected: HTTPS requires CONNECT for {host_lc}:{port}" + )) + .build(); + ocsf_emit!(event); + } respond( client, b"HTTP/1.1 400 Bad Request\r\nContent-Length: 27\r\n\r\nUse CONNECT for HTTPS URLs", @@ -1720,16 +1988,24 @@ async fn handle_forward_proxy( // 3. Evaluate OPA policy (same identity binding as CONNECT) let peer_addr = client.peer_addr().into_diagnostic()?; - let local_addr = client.local_addr().into_diagnostic()?; - - let decision = evaluate_opa_tcp( - peer_addr, - &opa_engine, - &identity_cache, - &entrypoint_pid, - &host_lc, - port, - ); + let _local_addr = client.local_addr().into_diagnostic()?; + + let opa_clone = opa_engine.clone(); + let cache_clone = identity_cache.clone(); + let pid_clone = entrypoint_pid.clone(); + let host_clone = host_lc.clone(); + let decision = tokio::task::spawn_blocking(move || { + evaluate_opa_tcp( + peer_addr, + &opa_clone, + &cache_clone, + &pid_clone, + &host_clone, + port, + ) + }) + .await + .map_err(|e| miette::miette!("identity resolution task panicked: {e}"))?; // Build log context let binary_str = decision @@ -1764,24 +2040,28 @@ async fn handle_forward_proxy( let matched_policy = match &decision.action { NetworkAction::Allow { matched_policy } => matched_policy.clone(), NetworkAction::Deny { reason } => { - info!( - src_addr = %peer_addr.ip(), - src_port = peer_addr.port(), - proxy_addr = %local_addr, - dst_host = %host_lc, - dst_port = port, - method = %method, - path = %path, - binary = %binary_str, - binary_pid = %pid_str, - ancestors = 
%ancestors_str, - cmdline = %cmdline_str, - action = "deny", - engine = "opa", - policy = "-", - reason = %reason, - "FORWARD", - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "opa") + .message(format!("FORWARD denied {method} {host_lc}:{port}{path}")) + .build(); + ocsf_emit!(event); + } emit_denial_simple( denial_tx, &host_lc, @@ -1797,32 +2077,131 @@ async fn handle_forward_proxy( }; let policy_str = matched_policy.as_deref().unwrap_or("-"); - // 4b. Reject if the endpoint has L7 config — the forward proxy path does - // not perform per-request method/path inspection, so L7-configured - // endpoints must go through the CONNECT tunnel where inspection happens. - if query_l7_config(&opa_engine, &decision, &host_lc, port).is_some() { - info!( - dst_host = %host_lc, - dst_port = port, - method = %method, - path = %path, - binary = %binary_str, - policy = %policy_str, - action = "deny", - reason = "endpoint has L7 rules; use CONNECT", - "FORWARD", - ); - emit_denial_simple( - denial_tx, - &host_lc, + // 4b. If the endpoint has L7 config, evaluate the request against + // L7 policy. The forward proxy handles exactly one request per + // connection (Connection: close), so a single evaluation suffices. 
+ if let Some(l7_config) = query_l7_config(&opa_engine, &decision, &host_lc, port) { + let tunnel_engine = opa_engine.clone_engine_for_tunnel().unwrap_or_else(|e| { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("Failed to clone OPA engine for forward L7: {e}")) + .build(); + ocsf_emit!(event); + regorus::Engine::new() + }); + let engine_mutex = std::sync::Mutex::new(tunnel_engine); + + let l7_ctx = crate::l7::relay::L7EvalContext { + host: host_lc.clone(), port, - &binary_str, - &decision, - "endpoint has L7 rules configured; forward proxy bypasses L7 inspection — use CONNECT", - "forward-l7-bypass", - ); - respond(client, b"HTTP/1.1 403 Forbidden\r\n\r\n").await?; - return Ok(()); + policy_name: matched_policy.clone().unwrap_or_default(), + binary_path: decision + .binary + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + cmdline_paths: decision + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + secret_resolver: secret_resolver.clone(), + }; + + let (target_path, query_params) = crate::l7::rest::parse_target_query(&path) + .unwrap_or_else(|_| (path.clone(), std::collections::HashMap::new())); + let request_info = crate::l7::L7RequestInfo { + action: method.to_string(), + target: target_path, + query_params, + }; + + let (allowed, reason) = + crate::l7::relay::evaluate_l7_request(&engine_mutex, &l7_ctx, &request_info) + .unwrap_or_else(|e| { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("L7 eval failed, denying request: {e}")) + .build(); + 
ocsf_emit!(event); + (false, format!("L7 evaluation error: {e}")) + }); + + let decision_str = match (allowed, l7_config.enforcement) { + (true, _) => "allow", + (false, crate::l7::EnforcementMode::Audit) => "audit", + (false, crate::l7::EnforcementMode::Enforce) => "deny", + }; + + { + let (action_id, disposition_id, severity) = match decision_str { + "allow" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "l7") + .message(format!( + "FORWARD_L7 {decision_str} {method} {host_lc}:{port}{path} reason={reason}" + )) + .build(); + ocsf_emit!(event); + } + + let effectively_denied = + !allowed && l7_config.enforcement == crate::l7::EnforcementMode::Enforce; + + if effectively_denied { + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "forward-l7-deny", + ); + respond(client, b"HTTP/1.1 403 Forbidden\r\n\r\n").await?; + return Ok(()); + } } // 5. DNS resolution + SSRF defence (mirrors the CONNECT path logic). 
@@ -1842,12 +2221,30 @@ async fn handle_forward_proxy( Ok(nets) => match resolve_and_check_allowed_ips(&host, port, &nets).await { Ok(addrs) => addrs, Err(reason) => { - warn!( - dst_host = %host_lc, - dst_port = port, - reason = %reason, - "FORWARD blocked: allowed_ips check failed" - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: allowed_ips check failed for {host_lc}:{port}: {reason}" + )) + .build(); + ocsf_emit!(event); + } emit_denial_simple( denial_tx, &host_lc, @@ -1862,12 +2259,30 @@ async fn handle_forward_proxy( } }, Err(reason) => { - warn!( - dst_host = %host_lc, - dst_port = port, - reason = %reason, - "FORWARD blocked: invalid allowed_ips in policy" - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: invalid allowed_ips in policy for {host_lc}:{port}: {reason}" + )) + .build(); + 
ocsf_emit!(event); + } emit_denial_simple( denial_tx, &host_lc, @@ -1886,12 +2301,30 @@ async fn handle_forward_proxy( match resolve_and_reject_internal(&host, port).await { Ok(addrs) => addrs, Err(reason) => { - warn!( - dst_host = %host_lc, - dst_port = port, - reason = %reason, - "FORWARD blocked: internal IP without allowed_ips" - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: internal IP without allowed_ips for {host_lc}:{port}: {reason}" + )) + .build(); + ocsf_emit!(event); + } emit_denial_simple( denial_tx, &host_lc, @@ -1911,39 +2344,68 @@ async fn handle_forward_proxy( let mut upstream = match TcpStream::connect(addrs.as_slice()).await { Ok(s) => s, Err(e) => { - warn!( - dst_host = %host_lc, - dst_port = port, - error = %e, - "FORWARD upstream connect failed" - ); + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .message(format!( + "FORWARD upstream connect failed for {host_lc}:{port}: {e}" + )) + .build(); + ocsf_emit!(event); respond(client, 
b"HTTP/1.1 502 Bad Gateway\r\n\r\n").await?; return Ok(()); } }; // Log success - info!( - src_addr = %peer_addr.ip(), - src_port = peer_addr.port(), - proxy_addr = %local_addr, - dst_host = %host_lc, - dst_port = port, - method = %method, - path = %path, - binary = %binary_str, - binary_pid = %pid_str, - ancestors = %ancestors_str, - cmdline = %cmdline_str, - action = "allow", - engine = "opa", - policy = %policy_str, - reason = "", - "FORWARD", - ); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "opa") + .message(format!("FORWARD allowed {method} {host_lc}:{port}{path}")) + .build(); + ocsf_emit!(event); + } // 9. Rewrite request and forward to upstream - let rewritten = rewrite_forward_request(buf, used, &path, secret_resolver.as_deref()); + let rewritten = match rewrite_forward_request(buf, used, &path, secret_resolver.as_deref()) { + Ok(bytes) => bytes, + Err(e) => { + warn!( + dst_host = %host_lc, + dst_port = port, + error = %e, + "credential injection failed in forward proxy" + ); + respond(client, b"HTTP/1.1 500 Internal Server Error\r\n\r\n").await?; + return Ok(()); + } + }; upstream.write_all(&rewritten).await.into_diagnostic()?; // 8. 
Relay remaining traffic bidirectionally (supports streaming) @@ -2643,7 +3105,7 @@ mod tests { fn test_rewrite_get_request() { let raw = b"GET http://10.0.0.1:8000/api HTTP/1.1\r\nHost: 10.0.0.1:8000\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None); + let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.starts_with("GET /api HTTP/1.1\r\n")); assert!(result_str.contains("Host: 10.0.0.1:8000")); @@ -2654,7 +3116,7 @@ mod tests { #[test] fn test_rewrite_strips_proxy_headers() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nProxy-Authorization: Basic abc\r\nProxy-Connection: keep-alive\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None); + let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!( !result_str @@ -2668,7 +3130,7 @@ mod tests { #[test] fn test_rewrite_replaces_connection_header() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nConnection: keep-alive\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None); + let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Connection: close")); assert!(!result_str.contains("keep-alive")); @@ -2677,7 +3139,7 @@ mod tests { #[test] fn test_rewrite_preserves_body_overflow() { let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 13\r\n\r\n{\"key\":\"val\"}"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None); + let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("{\"key\":\"val\"}")); assert!(result_str.contains("POST /api 
HTTP/1.1")); @@ -2686,7 +3148,7 @@ mod tests { #[test] fn test_rewrite_preserves_existing_via() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nVia: 1.0 upstream\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None); + let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Via: 1.0 upstream")); // Should not add a second Via header @@ -2701,7 +3163,8 @@ mod tests { .collect(), ); let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref()); + let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref()) + .expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Authorization: Bearer sk-test")); assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); diff --git a/crates/openshell-sandbox/src/sandbox/linux/landlock.rs b/crates/openshell-sandbox/src/sandbox/linux/landlock.rs index 2b9873b50..4dcc55449 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/landlock.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/landlock.rs @@ -5,11 +5,11 @@ use crate::policy::{LandlockCompatibility, SandboxPolicy}; use landlock::{ - ABI, Access, AccessFs, CompatLevel, Compatible, PathBeneath, PathFd, Ruleset, RulesetAttr, - RulesetCreatedAttr, + ABI, Access, AccessFs, CompatLevel, Compatible, PathBeneath, PathFd, PathFdError, Ruleset, + RulesetAttr, RulesetCreatedAttr, }; use miette::{IntoDiagnostic, Result}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use tracing::debug; pub fn apply(policy: &SandboxPolicy, workdir: Option<&str>) -> Result<()> { @@ -29,49 +29,102 @@ pub fn apply(policy: &SandboxPolicy, workdir: Option<&str>) -> Result<()> { return Ok(()); } + let total_paths = read_only.len() 
+ read_write.len(); + let abi = ABI::V2; + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Informational) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "applying") + .message(format!( + "Applying Landlock filesystem sandbox [abi:{abi:?} compat:{:?} ro:{} rw:{}]", + policy.landlock.compatibility, + read_only.len(), + read_write.len(), + )) + .build() + ); + + let compatibility = &policy.landlock.compatibility; + let result: Result<()> = (|| { - let abi = ABI::V2; let access_all = AccessFs::from_all(abi); let access_read = AccessFs::from_read(abi); let mut ruleset = Ruleset::default(); ruleset = ruleset - .set_compatibility(compat_level(&policy.landlock.compatibility)) + .set_compatibility(compat_level(compatibility)) .handle_access(access_all) .into_diagnostic()?; let mut ruleset = ruleset.create().into_diagnostic()?; + let mut rules_applied: usize = 0; - for path in read_only { - debug!(path = %path.display(), "Landlock allow read-only"); - ruleset = ruleset - .add_rule(PathBeneath::new( - PathFd::new(path).into_diagnostic()?, - access_read, - )) - .into_diagnostic()?; + for path in &read_only { + if let Some(path_fd) = try_open_path(path, compatibility)? { + debug!(path = %path.display(), "Landlock allow read-only"); + ruleset = ruleset + .add_rule(PathBeneath::new(path_fd, access_read)) + .into_diagnostic()?; + rules_applied += 1; + } } - for path in read_write { - debug!(path = %path.display(), "Landlock allow read-write"); - ruleset = ruleset - .add_rule(PathBeneath::new( - PathFd::new(path).into_diagnostic()?, - access_all, - )) - .into_diagnostic()?; + for path in &read_write { + if let Some(path_fd) = try_open_path(path, compatibility)? 
{ + debug!(path = %path.display(), "Landlock allow read-write"); + ruleset = ruleset + .add_rule(PathBeneath::new(path_fd, access_all)) + .into_diagnostic()?; + rules_applied += 1; + } + } + + if rules_applied == 0 { + return Err(miette::miette!( + "Landlock ruleset has zero valid paths — all {} path(s) failed to open. \ + Refusing to apply an empty ruleset that would block all filesystem access.", + total_paths, + )); } + let skipped = total_paths - rules_applied; + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Informational) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "built") + .message(format!( + "Landlock ruleset built [rules_applied:{rules_applied} skipped:{skipped}]" + )) + .build() + ); + ruleset.restrict_self().into_diagnostic()?; Ok(()) })(); if let Err(err) = result { - if matches!( - policy.landlock.compatibility, - LandlockCompatibility::BestEffort - ) { - debug!(error = %err, "Landlock unavailable, continuing without filesystem sandbox"); + if matches!(compatibility, LandlockCompatibility::BestEffort) { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Open) + .severity(openshell_ocsf::SeverityId::High) + .confidence(openshell_ocsf::ConfidenceId::High) + .is_alert(true) + .finding_info( + openshell_ocsf::FindingInfo::new( + "landlock-unavailable", + "Landlock Filesystem Sandbox Unavailable", + ) + .with_desc(&format!( + "Running WITHOUT filesystem restrictions: {err}. \ + Set landlock.compatibility to 'hard_requirement' to make this fatal." + )), + ) + .message(format!("Landlock filesystem sandbox unavailable: {err}")) + .build() + ); return Ok(()); } return Err(err); @@ -80,9 +133,182 @@ pub fn apply(policy: &SandboxPolicy, workdir: Option<&str>) -> Result<()> { Ok(()) } +/// Attempt to open a path for Landlock rule creation. 
+///
+/// In `BestEffort` mode, inaccessible paths (missing, permission denied, symlink
+/// loops, etc.) are skipped with a warning and `Ok(None)` is returned so the
+/// caller can continue building the ruleset from the remaining valid paths.
+///
+/// In `HardRequirement` mode, any failure is fatal — the caller propagates the
+/// error, which ultimately aborts sandbox startup.
+fn try_open_path(path: &Path, compatibility: &LandlockCompatibility) -> Result<Option<PathFd>> {
+    match PathFd::new(path) {
+        Ok(fd) => Ok(Some(fd)),
+        Err(err) => {
+            let reason = classify_path_fd_error(&err);
+            let is_not_found = matches!(
+                &err,
+                PathFdError::OpenCall { source, .. }
+                    if source.kind() == std::io::ErrorKind::NotFound
+            );
+            match compatibility {
+                LandlockCompatibility::BestEffort => {
+                    // NotFound is expected for stale baseline paths (e.g.
+                    // /app baked into the server-stored policy but absent
+                    // in this container image). Log at debug! to avoid
+                    // polluting SSH exec stdout — the pre_exec hook
+                    // inherits the tracing subscriber whose writer targets
+                    // fd 1 (the pipe/PTY).
+                    //
+                    // Other errors (permission denied, symlink loops, etc.)
+                    // are genuinely unexpected and logged at warn!.
+                    if is_not_found {
+                        debug!(
+                            path = %path.display(),
+                            reason,
+                            "Skipping non-existent Landlock path (best-effort mode)"
+                        );
+                    } else {
+                        openshell_ocsf::ocsf_emit!(
+                            openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx())
+                                .severity(openshell_ocsf::SeverityId::Medium)
+                                .status(openshell_ocsf::StatusId::Failure)
+                                .state(openshell_ocsf::StateId::Other, "degraded")
+                                .message(format!(
+                                    "Skipping inaccessible Landlock path (best-effort) [path:{} error:{err}]",
+                                    path.display()
+                                ))
+                                .build()
+                        );
+                    }
+                    Ok(None)
+                }
+                LandlockCompatibility::HardRequirement => Err(miette::miette!(
+                    "Landlock path unavailable in hard_requirement mode: {} ({}): {}",
+                    path.display(),
+                    reason,
+                    err,
+                )),
+            }
+        }
+    }
+}
+
+/// Classify a [`PathFdError`] into a human-readable reason.
+/// +/// `PathFd::new()` wraps `open(path, O_PATH | O_CLOEXEC)` which can fail for +/// several reasons beyond simple non-existence. The `PathFdError::OpenCall` +/// variant wraps the underlying `std::io::Error`. +fn classify_path_fd_error(err: &PathFdError) -> &'static str { + match err { + PathFdError::OpenCall { source, .. } => classify_io_error(source), + // PathFdError is #[non_exhaustive], handle future variants gracefully. + _ => "unexpected error", + } +} + +/// Classify a `std::io::Error` into a human-readable reason string. +fn classify_io_error(err: &std::io::Error) -> &'static str { + match err.kind() { + std::io::ErrorKind::NotFound => "path does not exist", + std::io::ErrorKind::PermissionDenied => "permission denied", + _ => match err.raw_os_error() { + Some(40) => "too many symlink levels", // ELOOP + Some(36) => "path name too long", // ENAMETOOLONG + Some(20) => "path component is not a directory", // ENOTDIR + _ => "unexpected error", + }, + } +} + fn compat_level(level: &LandlockCompatibility) -> CompatLevel { match level { LandlockCompatibility::BestEffort => CompatLevel::BestEffort, LandlockCompatibility::HardRequirement => CompatLevel::HardRequirement, } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn try_open_path_best_effort_returns_none_for_missing_path() { + let result = try_open_path( + &PathBuf::from("/nonexistent/openshell/test/path"), + &LandlockCompatibility::BestEffort, + ); + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + } + + #[test] + fn try_open_path_hard_requirement_errors_for_missing_path() { + let result = try_open_path( + &PathBuf::from("/nonexistent/openshell/test/path"), + &LandlockCompatibility::HardRequirement, + ); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("hard_requirement"), + "error should mention hard_requirement mode: {err_msg}" + ); + assert!( + err_msg.contains("does not exist"), + "error should include the 
classified reason: {err_msg}" + ); + } + + #[test] + fn try_open_path_succeeds_for_existing_path() { + let dir = tempfile::tempdir().unwrap(); + let result = try_open_path(dir.path(), &LandlockCompatibility::BestEffort); + assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn classify_not_found() { + let err = std::io::Error::from_raw_os_error(libc::ENOENT); + assert_eq!(classify_io_error(&err), "path does not exist"); + } + + #[test] + fn classify_permission_denied() { + let err = std::io::Error::from_raw_os_error(libc::EACCES); + assert_eq!(classify_io_error(&err), "permission denied"); + } + + #[test] + fn classify_symlink_loop() { + let err = std::io::Error::from_raw_os_error(libc::ELOOP); + assert_eq!(classify_io_error(&err), "too many symlink levels"); + } + + #[test] + fn classify_name_too_long() { + let err = std::io::Error::from_raw_os_error(libc::ENAMETOOLONG); + assert_eq!(classify_io_error(&err), "path name too long"); + } + + #[test] + fn classify_not_a_directory() { + let err = std::io::Error::from_raw_os_error(libc::ENOTDIR); + assert_eq!(classify_io_error(&err), "path component is not a directory"); + } + + #[test] + fn classify_unknown_error() { + let err = std::io::Error::from_raw_os_error(libc::EIO); + assert_eq!(classify_io_error(&err), "unexpected error"); + } + + #[test] + fn classify_path_fd_error_extracts_io_error() { + // Use PathFd::new on a non-existent path to get a real PathFdError + // (the OpenCall variant is #[non_exhaustive] and can't be constructed directly). 
+ let err = PathFd::new("/nonexistent/openshell/classify/test").unwrap_err(); + assert_eq!(classify_path_fd_error(&err), "path does not exist"); + } +} diff --git a/crates/openshell-sandbox/src/sandbox/linux/netns.rs b/crates/openshell-sandbox/src/sandbox/linux/netns.rs index 5e6907c53..37d11f0c3 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/netns.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/netns.rs @@ -62,11 +62,15 @@ impl NetworkNamespace { .parse() .unwrap(); - info!( - namespace = %name, - host_veth = %veth_host, - sandbox_veth = %veth_sandbox, - "Creating network namespace" + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Informational) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "creating") + .message(format!( + "Creating network namespace [ns:{name} host_veth:{veth_host} sandbox_veth:{veth_sandbox}]" + )) + .build() ); // Create the namespace @@ -152,11 +156,15 @@ impl NetworkNamespace { } }; - info!( - namespace = %name, - host_ip = %host_ip, - sandbox_ip = %sandbox_ip, - "Network namespace created" + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Informational) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "created") + .message(format!( + "Network namespace created [ns:{name} host_ip:{host_ip} sandbox_ip:{sandbox_ip}]" + )) + .build() ); Ok(Self { @@ -246,12 +254,17 @@ impl NetworkNamespace { let iptables_path = match find_iptables() { Some(path) => path, None => { - warn!( - namespace = %self.name, - search_paths = ?IPTABLES_SEARCH_PATHS, - "iptables not found; bypass detection rules will not be installed. \ - Install the iptables package for proxy bypass diagnostics." 
- ); + openshell_ocsf::ocsf_emit!(openshell_ocsf::ConfigStateChangeBuilder::new( + crate::ocsf_ctx() + ) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Disabled, "degraded") + .message(format!( + "iptables not found; bypass detection rules will not be installed [ns:{}]", + self.name + )) + .build()); return Ok(()); } }; @@ -260,40 +273,58 @@ impl NetworkNamespace { let proxy_port_str = proxy_port.to_string(); let log_prefix = format!("openshell:bypass:{}:", &self.name); - info!( - namespace = %self.name, - iptables = iptables_path, - proxy_addr = %format!("{}:{}", host_ip_str, proxy_port), - "Installing bypass detection rules" - ); + // "Installing bypass detection rules" is a transient step — skip OCSF. + // The completion event below covers the outcome. // Install IPv4 rules - if let Err(e) = - self.install_bypass_rules_for(iptables_path, &host_ip_str, &proxy_port_str, &log_prefix) - { - warn!( - namespace = %self.name, - error = %e, - "Failed to install IPv4 bypass detection rules" + if let Err(e) = self.install_bypass_rules_for( + &iptables_path, + &host_ip_str, + &proxy_port_str, + &log_prefix, + ) { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Disabled, "failed") + .message(format!( + "Failed to install IPv4 bypass detection rules [ns:{}]: {e}", + self.name + )) + .build() ); return Err(e); } // Install IPv6 rules — best-effort. // Skip the proxy ACCEPT rule for IPv6 since the proxy address is IPv4. 
- if let Some(ip6_path) = find_ip6tables(iptables_path) { + if let Some(ip6_path) = find_ip6tables(&iptables_path) { if let Err(e) = self.install_bypass_rules_for_v6(&ip6_path, &log_prefix) { - warn!( - namespace = %self.name, - error = %e, - "Failed to install IPv6 bypass detection rules (non-fatal)" - ); + openshell_ocsf::ocsf_emit!(openshell_ocsf::ConfigStateChangeBuilder::new( + crate::ocsf_ctx() + ) + .severity(openshell_ocsf::SeverityId::Low) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Other, "degraded") + .message(format!( + "Failed to install IPv6 bypass detection rules (non-fatal) [ns:{}]: {e}", + self.name + )) + .build()); } } - info!( - namespace = %self.name, - "Bypass detection rules installed" + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Informational) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "installed") + .message(format!( + "Bypass detection rules installed [ns:{}]", + self.name + )) + .build() ); Ok(()) @@ -372,11 +403,17 @@ impl NetworkNamespace { "--log-uid", ], ) { - warn!( - error = %e, - "Failed to install LOG rule for TCP (xt_LOG module may not be loaded); \ - bypass REJECT rules will still be installed" - ); + openshell_ocsf::ocsf_emit!(openshell_ocsf::ConfigStateChangeBuilder::new( + crate::ocsf_ctx() + ) + .severity(openshell_ocsf::SeverityId::Low) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Other, "degraded") + .message(format!( + "Failed to install LOG rule for TCP (xt_LOG module may not be loaded) [ns:{}]: {e}", + self.name + )) + .build()); } // Rule 5: REJECT TCP bypass attempts (fast-fail) @@ -417,9 +454,16 @@ impl NetworkNamespace { "--log-uid", ], ) { - warn!( - error = %e, - "Failed to install LOG rule for UDP; bypass REJECT rules will still be installed" + openshell_ocsf::ocsf_emit!( + 
openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Low) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Other, "degraded") + .message(format!( + "Failed to install LOG rule for UDP [ns:{}]: {e}", + self.name + )) + .build() ); } @@ -494,7 +538,17 @@ impl NetworkNamespace { "--log-uid", ], ) { - warn!(error = %e, "Failed to install IPv6 LOG rule for TCP"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Low) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Other, "degraded") + .message(format!( + "Failed to install IPv6 LOG rule for TCP [ns:{}]: {e}", + self.name + )) + .build() + ); } // REJECT TCP bypass attempts @@ -535,7 +589,17 @@ impl NetworkNamespace { "--log-uid", ], ) { - warn!(error = %e, "Failed to install IPv6 LOG rule for UDP"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Low) + .status(openshell_ocsf::StatusId::Failure) + .state(openshell_ocsf::StateId::Other, "degraded") + .message(format!( + "Failed to install IPv6 LOG rule for UDP [ns:{}]: {e}", + self.name + )) + .build() + ); } // REJECT UDP bypass attempts @@ -585,7 +649,14 @@ impl Drop for NetworkNamespace { ); } - info!(namespace = %self.name, "Network namespace cleaned up"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Informational) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Disabled, "cleaned_up") + .message(format!("Network namespace cleaned up [ns:{}]", self.name)) + .build() + ); } } @@ -666,12 +737,92 @@ fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result< const IPTABLES_SEARCH_PATHS: &[&str] = &["/usr/sbin/iptables", "/sbin/iptables", 
"/usr/bin/iptables"]; +/// Returns true if xt extension modules (e.g. xt_comment) cannot be used +/// via the given iptables binary. +/// +/// Some kernels have nf_tables but lack the nft_compat bridge that allows +/// xt extension modules to be used through the nf_tables path (e.g. Jetson +/// Linux 5.15-tegra). This probe detects that condition by attempting to +/// insert a rule using the xt_comment extension. If it fails, xt extensions +/// are unavailable and the caller should fall back to iptables-legacy. +fn xt_extensions_unavailable(iptables_path: &str) -> bool { + // Create a temporary probe chain. If this fails (e.g. no CAP_NET_ADMIN), + // we can't determine availability — assume extensions are available. + let created = Command::new(iptables_path) + .args(["-t", "filter", "-N", "_xt_probe"]) + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + + if !created { + return false; + } + + // Attempt to insert a rule using xt_comment. Failure means nft_compat + // cannot bridge xt extension modules on this kernel. + let probe_ok = Command::new(iptables_path) + .args([ + "-t", + "filter", + "-A", + "_xt_probe", + "-m", + "comment", + "--comment", + "probe", + "-j", + "ACCEPT", + ]) + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + + // Clean up — best-effort, ignore failures. + let _ = Command::new(iptables_path) + .args([ + "-t", + "filter", + "-D", + "_xt_probe", + "-m", + "comment", + "--comment", + "probe", + "-j", + "ACCEPT", + ]) + .output(); + let _ = Command::new(iptables_path) + .args(["-t", "filter", "-X", "_xt_probe"]) + .output(); + + !probe_ok +} + /// Find the iptables binary path, checking well-known locations. -fn find_iptables() -> Option<&'static str> { - IPTABLES_SEARCH_PATHS +/// +/// If xt extension modules are unavailable via the standard binary and +/// `iptables-legacy` is available alongside it, the legacy binary is returned +/// instead. 
This ensures bypass-detection rules can be installed on kernels
+/// where `nft_compat` is unavailable (e.g. Jetson Linux 5.15-tegra).
+fn find_iptables() -> Option<String> {
+    let standard_path = IPTABLES_SEARCH_PATHS
         .iter()
         .find(|path| std::path::Path::new(path).exists())
-        .copied()
+        .copied()?;
+
+    if xt_extensions_unavailable(standard_path) {
+        let legacy_path = standard_path.replace("iptables", "iptables-legacy");
+        if std::path::Path::new(&legacy_path).exists() {
+            debug!(
+                legacy = legacy_path,
+                "xt extensions unavailable; using iptables-legacy"
+            );
+            return Some(legacy_path);
+        }
+    }
+
+    Some(standard_path.to_string())
 }
 
 /// Find the ip6tables binary path, deriving it from the iptables location.
diff --git a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs
index 6c9d8307b..e23447498 100644
--- a/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs
+++ b/crates/openshell-sandbox/src/sandbox/linux/seccomp.rs
@@ -2,6 +2,15 @@
 // SPDX-License-Identifier: Apache-2.0
 //! Seccomp syscall filtering.
+//!
+//! The filter uses a default-allow policy with targeted blocks:
+//!
+//! 1. **Socket domain blocks** -- prevent raw/kernel sockets that bypass the proxy
+//! 2. **Unconditional syscall blocks** -- block syscalls that enable sandbox escape
+//!    (fileless exec, ptrace, BPF, cross-process memory access, io_uring, mount)
+//! 3. **Conditional syscall blocks** -- block dangerous flag combinations on otherwise
+//!    needed syscalls (execveat+AT_EMPTY_PATH, unshare+CLONE_NEWUSER,
+//!    seccomp+SET_MODE_FILTER)
 
 use crate::policy::{NetworkMode, SandboxPolicy};
 use miette::{IntoDiagnostic, Result};
@@ -13,6 +22,9 @@ use std::collections::BTreeMap;
 use std::convert::TryInto;
 use tracing::debug;
 
+/// Value of `SECCOMP_SET_MODE_FILTER` (linux/seccomp.h).
+const SECCOMP_SET_MODE_FILTER: u64 = 1;
+
 pub fn apply(policy: &SandboxPolicy) -> Result<()> {
     if matches!(policy.network.mode, NetworkMode::Allow) {
         return Ok(());
@@ -37,6 +49,7 @@ pub fn apply(policy: &SandboxPolicy) -> Result<()> {
 
 fn build_filter(allow_inet: bool) -> Result<SeccompFilter> {
     let mut rules: BTreeMap<i64, Vec<SeccompRule>> = BTreeMap::new();
 
+    // --- Socket domain blocks ---
     let mut blocked_domains = vec![libc::AF_PACKET, libc::AF_BLUETOOTH, libc::AF_VSOCK];
     if !allow_inet {
         blocked_domains.push(libc::AF_INET);
@@ -49,6 +62,51 @@ fn build_filter(allow_inet: bool) -> Result<SeccompFilter> {
         add_socket_domain_rule(&mut rules, domain)?;
     }
 
+    // --- Unconditional syscall blocks ---
+    // These syscalls are blocked entirely (empty rule vec = unconditional EPERM).
+
+    // Fileless binary execution via memfd bypasses Landlock filesystem restrictions.
+    rules.entry(libc::SYS_memfd_create).or_default();
+    // Cross-process memory inspection and code injection.
+    rules.entry(libc::SYS_ptrace).or_default();
+    // Kernel BPF program loading.
+    rules.entry(libc::SYS_bpf).or_default();
+    // Cross-process memory read.
+    rules.entry(libc::SYS_process_vm_readv).or_default();
+    // Async I/O subsystem with extensive CVE history.
+    rules.entry(libc::SYS_io_uring_setup).or_default();
+    // Filesystem mount could subvert Landlock or overlay writable paths.
+    rules.entry(libc::SYS_mount).or_default();
+
+    // --- Conditional syscall blocks ---
+
+    // execveat with AT_EMPTY_PATH enables fileless execution from an anonymous fd.
+    add_masked_arg_rule(
+        &mut rules,
+        libc::SYS_execveat,
+        4, // flags argument
+        libc::AT_EMPTY_PATH as u64,
+    )?;
+
+    // unshare with CLONE_NEWUSER allows creating user namespaces to escalate privileges.
+    add_masked_arg_rule(
+        &mut rules,
+        libc::SYS_unshare,
+        0, // flags argument
+        libc::CLONE_NEWUSER as u64,
+    )?;
+
+    // seccomp(SECCOMP_SET_MODE_FILTER) would let sandboxed code replace the active filter.
+ let condition = SeccompCondition::new( + 0, // operation argument + SeccompCmpArgLen::Dword, + SeccompCmpOp::Eq, + SECCOMP_SET_MODE_FILTER, + ) + .into_diagnostic()?; + let rule = SeccompRule::new(vec![condition]).into_diagnostic()?; + rules.entry(libc::SYS_seccomp).or_default().push(rule); + let arch = std::env::consts::ARCH .try_into() .map_err(|_| miette::miette!("Unsupported architecture for seccomp"))?; @@ -74,3 +132,127 @@ fn add_socket_domain_rule(rules: &mut BTreeMap>, domain: i rules.entry(libc::SYS_socket).or_default().push(rule); Ok(()) } + +/// Block a syscall when a specific bit pattern is set in an argument. +/// +/// Uses `MaskedEq` to check `(arg & flag_bit) == flag_bit`, which triggers +/// EPERM when the flag is present regardless of other bits in the argument. +fn add_masked_arg_rule( + rules: &mut BTreeMap>, + syscall: i64, + arg_index: u8, + flag_bit: u64, +) -> Result<()> { + let condition = SeccompCondition::new( + arg_index, + SeccompCmpArgLen::Dword, + SeccompCmpOp::MaskedEq(flag_bit), + flag_bit, + ) + .into_diagnostic()?; + let rule = SeccompRule::new(vec![condition]).into_diagnostic()?; + rules.entry(syscall).or_default().push(rule); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_filter_proxy_mode_compiles() { + let filter = build_filter(true); + assert!(filter.is_ok(), "build_filter(true) should succeed"); + } + + #[test] + fn build_filter_block_mode_compiles() { + let filter = build_filter(false); + assert!(filter.is_ok(), "build_filter(false) should succeed"); + } + + #[test] + fn add_masked_arg_rule_creates_entry() { + let mut rules: BTreeMap> = BTreeMap::new(); + let result = add_masked_arg_rule(&mut rules, libc::SYS_execveat, 4, 0x1000); + assert!(result.is_ok()); + assert!( + rules.contains_key(&libc::SYS_execveat), + "should have an entry for SYS_execveat" + ); + assert_eq!( + rules[&libc::SYS_execveat].len(), + 1, + "should have exactly one rule" + ); + } + + #[test] + fn 
unconditional_blocks_present_in_filter() { + let mut rules: BTreeMap> = BTreeMap::new(); + + // Simulate what build_filter does for unconditional blocks + rules.entry(libc::SYS_memfd_create).or_default(); + rules.entry(libc::SYS_ptrace).or_default(); + rules.entry(libc::SYS_bpf).or_default(); + rules.entry(libc::SYS_process_vm_readv).or_default(); + rules.entry(libc::SYS_io_uring_setup).or_default(); + rules.entry(libc::SYS_mount).or_default(); + + // Unconditional blocks have an empty Vec (no conditions = always match) + for syscall in [ + libc::SYS_memfd_create, + libc::SYS_ptrace, + libc::SYS_bpf, + libc::SYS_process_vm_readv, + libc::SYS_io_uring_setup, + libc::SYS_mount, + ] { + assert!( + rules.contains_key(&syscall), + "syscall {syscall} should be in the rules map" + ); + assert!( + rules[&syscall].is_empty(), + "syscall {syscall} should have empty rules (unconditional block)" + ); + } + } + + #[test] + fn conditional_blocks_have_rules() { + // Build a real filter and verify the conditional syscalls have rule entries + // (non-empty Vec means conditional match) + let mut rules: BTreeMap> = BTreeMap::new(); + + add_masked_arg_rule( + &mut rules, + libc::SYS_execveat, + 4, + libc::AT_EMPTY_PATH as u64, + ) + .unwrap(); + add_masked_arg_rule(&mut rules, libc::SYS_unshare, 0, libc::CLONE_NEWUSER as u64).unwrap(); + + let condition = SeccompCondition::new( + 0, + SeccompCmpArgLen::Dword, + SeccompCmpOp::Eq, + SECCOMP_SET_MODE_FILTER, + ) + .unwrap(); + let rule = SeccompRule::new(vec![condition]).unwrap(); + rules.entry(libc::SYS_seccomp).or_default().push(rule); + + for syscall in [libc::SYS_execveat, libc::SYS_unshare, libc::SYS_seccomp] { + assert!( + rules.contains_key(&syscall), + "syscall {syscall} should be in the rules map" + ); + assert!( + !rules[&syscall].is_empty(), + "syscall {syscall} should have conditional rules" + ); + } + } +} diff --git a/crates/openshell-sandbox/src/sandbox/mod.rs b/crates/openshell-sandbox/src/sandbox/mod.rs index 
f512a8e33..f7b037338 100644 --- a/crates/openshell-sandbox/src/sandbox/mod.rs +++ b/crates/openshell-sandbox/src/sandbox/mod.rs @@ -5,8 +5,6 @@ use crate::policy::SandboxPolicy; use miette::Result; -#[cfg(not(target_os = "linux"))] -use tracing::warn; #[cfg(target_os = "linux")] pub mod linux; @@ -26,7 +24,17 @@ pub fn apply(policy: &SandboxPolicy, workdir: Option<&str>) -> Result<()> { #[cfg(not(target_os = "linux"))] { let _ = (policy, workdir); - warn!("Sandbox policy provided but platform sandboxing is not yet implemented"); + openshell_ocsf::ocsf_emit!( + openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Open) + .severity(openshell_ocsf::SeverityId::Medium) + .finding_info(openshell_ocsf::FindingInfo::new( + "platform-sandbox-unavailable", + "Platform Sandboxing Not Implemented", + ).with_desc("Sandbox policy provided but platform sandboxing is not yet implemented on this OS")) + .message("Platform sandboxing not yet implemented") + .build() + ); Ok(()) } } diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index 4ee1ee846..a27537c91 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -1,12 +1,68 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use base64::Engine as _; use std::collections::HashMap; +use std::fmt; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +/// Public access to the placeholder prefix for fail-closed scanning in other modules. +pub(crate) const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; + +/// Characters that are valid in an env var key name (used to extract +/// placeholder boundaries within concatenated strings like path segments). 
+fn is_env_key_char(b: u8) -> bool { + b.is_ascii_alphanumeric() || b == b'_' +} + +// --------------------------------------------------------------------------- +// Error and result types +// --------------------------------------------------------------------------- + +/// Error returned when a placeholder cannot be resolved or a resolved secret +/// contains prohibited characters. +#[derive(Debug)] +pub(crate) struct UnresolvedPlaceholderError { + pub location: &'static str, // "header", "query_param", "path" +} + +impl fmt::Display for UnresolvedPlaceholderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "unresolved credential placeholder in {}: detected openshell:resolve:env:* token that could not be resolved", + self.location + ) + } +} + +/// Result of rewriting an HTTP header block with credential resolution. +#[derive(Debug)] +pub(crate) struct RewriteResult { + /// The rewritten HTTP bytes (headers + body overflow). + pub rewritten: Vec<u8>, + /// A redacted version of the request target for logging. + /// Contains `[CREDENTIAL]` in place of resolved credential values. + /// `None` if the target was not modified. + pub redacted_target: Option<String>, +} + +/// Result of rewriting a request target for OPA evaluation. +#[derive(Debug)] +pub(crate) struct RewriteTargetResult { + /// The resolved target (real secrets) — for upstream forwarding only. + pub resolved: String, + /// The redacted target (`[CREDENTIAL]` in place of secrets) — for OPA + logs. + pub redacted: String, +} + +// --------------------------------------------------------------------------- +// SecretResolver +// --------------------------------------------------------------------------- + #[derive(Debug, Clone, Default)] -pub(crate) struct SecretResolver { +pub struct SecretResolver { by_placeholder: HashMap<String, String>, } @@ -30,45 +86,513 @@ impl SecretResolver { (child_env, Some(Self { by_placeholder })) } + /// Resolve a placeholder string to the real secret value. 
+ /// + /// Returns `None` if the placeholder is unknown or the resolved value + /// contains prohibited control characters (CRLF, null byte). pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - self.by_placeholder.get(value).map(String::as_str) + let secret = self.by_placeholder.get(value).map(String::as_str)?; + match validate_resolved_secret(secret) { + Ok(s) => Some(s), + Err(reason) => { + tracing::warn!( + location = "resolve_placeholder", + reason, + "credential resolution rejected: resolved value contains prohibited characters" + ); + None + } + } } pub(crate) fn rewrite_header_value(&self, value: &str) -> Option<String> { + // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { return Some(secret.to_string()); } let trimmed = value.trim(); + + // Basic auth decoding: `Basic <base64>` where the decoded content + // contains a placeholder (e.g. `user:openshell:resolve:env:PASS`). + if let Some(encoded) = trimmed + .strip_prefix("Basic ") + .or_else(|| trimmed.strip_prefix("basic ")) + .map(str::trim) + { + if let Some(rewritten) = self.rewrite_basic_auth_token(encoded) { + return Some(format!("Basic {rewritten}")); + } + } + + // Prefixed placeholder: `Bearer openshell:resolve:env:KEY` let split_at = trimmed.find(char::is_whitespace)?; let prefix = &trimmed[..split_at]; let candidate = trimmed[split_at..].trim(); let secret = self.resolve_placeholder(candidate)?; Some(format!("{prefix} {secret}")) } + + /// Decode a Base64-encoded Basic auth token, resolve any placeholders in + /// the decoded `username:password` string, and re-encode. + /// + /// Returns `None` if decoding fails or no placeholders are found. 
+ fn rewrite_basic_auth_token(&self, encoded: &str) -> Option<String> { + let b64 = base64::engine::general_purpose::STANDARD; + let decoded_bytes = b64.decode(encoded.trim()).ok()?; + let decoded = std::str::from_utf8(&decoded_bytes).ok()?; + + // Check if the decoded string contains any placeholder + if !decoded.contains(PLACEHOLDER_PREFIX) { + return None; + } + + // Rewrite all placeholder occurrences in the decoded string + let mut rewritten = decoded.to_string(); + for (placeholder, secret) in &self.by_placeholder { + if rewritten.contains(placeholder.as_str()) { + // Validate the resolved secret for control characters + if validate_resolved_secret(secret).is_err() { + tracing::warn!( + location = "basic_auth", + "credential resolution rejected: resolved value contains prohibited characters" + ); + return None; + } + rewritten = rewritten.replace(placeholder.as_str(), secret); + } + } + + // Only return if we actually changed something + if rewritten == decoded { + return None; + } + + Some(b64.encode(rewritten.as_bytes())) + } } pub(crate) fn placeholder_for_env_key(key: &str) -> String { format!("{PLACEHOLDER_PREFIX}{key}") } -pub(crate) fn rewrite_http_header_block(raw: &[u8], resolver: Option<&SecretResolver>) -> Vec<u8> { +// --------------------------------------------------------------------------- +// Secret validation (F1 — CWE-113) +// --------------------------------------------------------------------------- + +/// Validate that a resolved secret value does not contain characters that +/// could enable HTTP header injection or request splitting. 
+fn validate_resolved_secret(value: &str) -> Result<&str, &'static str> { + if value + .bytes() + .any(|b| b == b'\r' || b == b'\n' || b == b'\0') + { + return Err("resolved secret contains prohibited control characters (CR, LF, or NUL)"); + } + Ok(value) +} + +// --------------------------------------------------------------------------- +// Percent encoding/decoding (RFC 3986) +// --------------------------------------------------------------------------- + +/// Percent-encode a string for safe use in URL query parameter values. +/// +/// Encodes all characters except unreserved characters (RFC 3986 Section 2.3): +/// ALPHA / DIGIT / "-" / "." / "_" / "~" +fn percent_encode_query(input: &str) -> String { + let mut encoded = String::with_capacity(input.len()); + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => { + encoded.push(byte as char); + } + _ => { + use fmt::Write; + let _ = write!(encoded, "%{byte:02X}"); + } + } + } + encoded +} + +/// Percent-encode a string for safe use in URL path segments. +/// +/// RFC 3986 §3.3: pchar = unreserved / pct-encoded / sub-delims / ":" / "@" +/// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" +/// +/// Must encode: `/`, `?`, `#`, space, and other non-pchar characters. +fn percent_encode_path_segment(input: &str) -> String { + let mut encoded = String::with_capacity(input.len()); + for byte in input.bytes() { + match byte { + // unreserved + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => { + encoded.push(byte as char); + } + // sub-delims + ":" + "@" + b'!' | b'$' | b'&' | b'\'' | b'(' | b')' | b'*' | b'+' | b',' | b';' | b'=' | b':' + | b'@' => { + encoded.push(byte as char); + } + _ => { + use fmt::Write; + let _ = write!(encoded, "%{byte:02X}"); + } + } + } + encoded +} + +/// Percent-decode a URL-encoded string. 
+fn percent_decode(input: &str) -> String { + let mut decoded = Vec::with_capacity(input.len()); + let mut bytes = input.bytes(); + while let Some(b) = bytes.next() { + if b == b'%' { + let hi = bytes.next(); + let lo = bytes.next(); + if let (Some(h), Some(l)) = (hi, lo) { + let hex = [h, l]; + if let Ok(s) = std::str::from_utf8(&hex) { + if let Ok(val) = u8::from_str_radix(s, 16) { + decoded.push(val); + continue; + } + } + // Invalid percent encoding — preserve verbatim + decoded.push(b'%'); + decoded.push(h); + decoded.push(l); + } else { + decoded.push(b'%'); + if let Some(h) = hi { + decoded.push(h); + } + } + } else { + decoded.push(b); + } + } + String::from_utf8_lossy(&decoded).into_owned() +} + +// --------------------------------------------------------------------------- +// Path credential validation (F3 — CWE-22) +// --------------------------------------------------------------------------- + +/// Validate that a resolved credential value is safe for use in a URL path segment. +/// +/// Operates on the raw (decoded) credential value before percent-encoding. +/// Rejects values that could enable path traversal, request splitting, or +/// URI structure breakage. +fn validate_credential_for_path(value: &str) -> Result<(), String> { + if value.contains("../") || value.contains("..\\") || value == ".." 
{ + return Err("credential contains path traversal sequence".into()); + } + if value.contains('\0') || value.contains('\r') || value.contains('\n') { + return Err("credential contains control character".into()); + } + if value.contains('/') || value.contains('\\') { + return Err("credential contains path separator".into()); + } + if value.contains('?') || value.contains('#') { + return Err("credential contains URI delimiter".into()); + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// URI rewriting +// --------------------------------------------------------------------------- + +/// Result of rewriting the request line. +struct RewriteLineResult { + /// The rewritten request line. + line: String, + /// Redacted target for logging (if any rewriting occurred). + redacted_target: Option, +} + +/// Rewrite credential placeholders in the request line's URL. +/// +/// Given a request line like `GET /bot{TOKEN}/path?key={APIKEY} HTTP/1.1`, +/// resolves placeholders in both path segments and query parameter values. 
+fn rewrite_request_line( + line: &str, + resolver: &SecretResolver, +) -> Result { + // Request line format: METHOD SP REQUEST-URI SP HTTP-VERSION + let mut parts = line.splitn(3, ' '); + let method = match parts.next() { + Some(m) => m, + None => { + return Ok(RewriteLineResult { + line: line.to_string(), + redacted_target: None, + }); + } + }; + let uri = match parts.next() { + Some(u) => u, + None => { + return Ok(RewriteLineResult { + line: line.to_string(), + redacted_target: None, + }); + } + }; + let version = match parts.next() { + Some(v) => v, + None => { + return Ok(RewriteLineResult { + line: line.to_string(), + redacted_target: None, + }); + } + }; + + // Only rewrite if the URI contains a placeholder + if !uri.contains(PLACEHOLDER_PREFIX) { + return Ok(RewriteLineResult { + line: line.to_string(), + redacted_target: None, + }); + } + + // Split URI into path and query + let (path, query) = match uri.split_once('?') { + Some((p, q)) => (p, Some(q)), + None => (uri, None), + }; + + // Rewrite path segments + let (resolved_path, redacted_path) = match rewrite_uri_path(path, resolver)? { + Some((resolved, redacted)) => (resolved, redacted), + None => (path.to_string(), path.to_string()), + }; + + // Rewrite query params + let (resolved_query, redacted_query) = match query { + Some(q) => match rewrite_uri_query_params(q, resolver)? { + Some((resolved, redacted)) => (Some(resolved), Some(redacted)), + None => (Some(q.to_string()), Some(q.to_string())), + }, + None => (None, None), + }; + + // Reassemble + let resolved_uri = match &resolved_query { + Some(q) => format!("{resolved_path}?{q}"), + None => resolved_path.clone(), + }; + let redacted_uri = match &redacted_query { + Some(q) => format!("{redacted_path}?{q}"), + None => redacted_path, + }; + + Ok(RewriteLineResult { + line: format!("{method} {resolved_uri} {version}"), + redacted_target: Some(redacted_uri), + }) +} + +/// Rewrite placeholders in URL path segments. 
+/// +/// Handles substring matching for cases like Telegram's `/bot{TOKEN}/method` +/// where the placeholder is concatenated with literal text in a segment. +/// +/// Returns `Some((resolved_path, redacted_path))` if any placeholders were found, +/// `None` if no placeholders exist in the path. +fn rewrite_uri_path( + path: &str, + resolver: &SecretResolver, +) -> Result, UnresolvedPlaceholderError> { + if !path.contains(PLACEHOLDER_PREFIX) { + return Ok(None); + } + + let segments: Vec<&str> = path.split('/').collect(); + let mut resolved_segments = Vec::with_capacity(segments.len()); + let mut redacted_segments = Vec::with_capacity(segments.len()); + let mut any_rewritten = false; + + for segment in &segments { + let decoded = percent_decode(segment); + if !decoded.contains(PLACEHOLDER_PREFIX) { + resolved_segments.push(segment.to_string()); + redacted_segments.push(segment.to_string()); + continue; + } + + let (resolved, redacted) = rewrite_path_segment(&decoded, resolver)?; + // Percent-encode the resolved segment for path context + resolved_segments.push(percent_encode_path_segment(&resolved)); + redacted_segments.push(redacted); + any_rewritten = true; + } + + if !any_rewritten { + return Ok(None); + } + + Ok(Some(( + resolved_segments.join("/"), + redacted_segments.join("/"), + ))) +} + +/// Rewrite placeholders within a single path segment (already percent-decoded). +/// +/// Uses the placeholder grammar `openshell:resolve:env:[A-Za-z_][A-Za-z0-9_]*` +/// to determine placeholder boundaries within concatenated text. 
+fn rewrite_path_segment( + segment: &str, + resolver: &SecretResolver, +) -> Result<(String, String), UnresolvedPlaceholderError> { + let mut resolved = String::with_capacity(segment.len()); + let mut redacted = String::with_capacity(segment.len()); + let mut pos = 0; + let bytes = segment.as_bytes(); + + while pos < bytes.len() { + if let Some(start) = segment[pos..].find(PLACEHOLDER_PREFIX) { + let abs_start = pos + start; + // Copy literal prefix before the placeholder + resolved.push_str(&segment[pos..abs_start]); + redacted.push_str(&segment[pos..abs_start]); + + // Extract the key name using the env var grammar: [A-Za-z_][A-Za-z0-9_]* + let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + let key_end = segment[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(segment.len(), |p| key_start + p); + + if key_end == key_start { + // Empty key — not a valid placeholder, copy literally + resolved.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); + redacted.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); + pos = abs_start + PLACEHOLDER_PREFIX.len(); + continue; + } + + let full_placeholder = &segment[abs_start..key_end]; + if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { + validate_credential_for_path(secret).map_err(|reason| { + tracing::warn!( + location = "path", + %reason, + "credential resolution rejected: resolved value unsafe for path" + ); + UnresolvedPlaceholderError { location: "path" } + })?; + resolved.push_str(secret); + redacted.push_str("[CREDENTIAL]"); + } else { + return Err(UnresolvedPlaceholderError { location: "path" }); + } + pos = key_end; + } else { + // No more placeholders in remainder + resolved.push_str(&segment[pos..]); + redacted.push_str(&segment[pos..]); + break; + } + } + + Ok((resolved, redacted)) +} + +/// Rewrite placeholders in query parameter values. +/// +/// Returns `Some((resolved_query, redacted_query))` if any placeholders were found. 
+fn rewrite_uri_query_params( + query: &str, + resolver: &SecretResolver, +) -> Result, UnresolvedPlaceholderError> { + if !query.contains(PLACEHOLDER_PREFIX) { + return Ok(None); + } + + let mut resolved_params = Vec::new(); + let mut redacted_params = Vec::new(); + let mut any_rewritten = false; + + for param in query.split('&') { + if let Some((key, value)) = param.split_once('=') { + let decoded_value = percent_decode(value); + if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { + resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + redacted_params.push(format!("{key}=[CREDENTIAL]")); + any_rewritten = true; + } else if decoded_value.contains(PLACEHOLDER_PREFIX) { + // Placeholder detected but not resolved + return Err(UnresolvedPlaceholderError { + location: "query_param", + }); + } else { + resolved_params.push(param.to_string()); + redacted_params.push(param.to_string()); + } + } else { + resolved_params.push(param.to_string()); + redacted_params.push(param.to_string()); + } + } + + if !any_rewritten { + return Ok(None); + } + + Ok(Some((resolved_params.join("&"), redacted_params.join("&")))) +} + +// --------------------------------------------------------------------------- +// Public rewrite API +// --------------------------------------------------------------------------- + +/// Rewrite credential placeholders in an HTTP header block. +/// +/// Resolves `openshell:resolve:env:*` placeholders in the request line +/// (path segments and query parameter values), header values (including +/// Basic auth tokens), and returns a `RewriteResult` with the rewritten +/// bytes and a redacted target for logging. +/// +/// Returns `Err` if any placeholder is detected but cannot be resolved +/// (fail-closed behavior). 
+pub(crate) fn rewrite_http_header_block( + raw: &[u8], + resolver: Option<&SecretResolver>, +) -> Result<RewriteResult, UnresolvedPlaceholderError> { let Some(resolver) = resolver else { - return raw.to_vec(); + return Ok(RewriteResult { + rewritten: raw.to_vec(), + redacted_target: None, + }); }; let Some(header_end) = raw.windows(4).position(|w| w == b"\r\n\r\n").map(|p| p + 4) else { - return raw.to_vec(); + return Ok(RewriteResult { + rewritten: raw.to_vec(), + redacted_target: None, + }); }; let header_str = String::from_utf8_lossy(&raw[..header_end]); let mut lines = header_str.split("\r\n"); let Some(request_line) = lines.next() else { - return raw.to_vec(); + return Ok(RewriteResult { + rewritten: raw.to_vec(), + redacted_target: None, + }); }; + // Rewrite request line (path + query params) + let rl_result = rewrite_request_line(request_line, resolver)?; + let mut output = Vec::with_capacity(raw.len()); - output.extend_from_slice(request_line.as_bytes()); + output.extend_from_slice(rl_result.line.as_bytes()); output.extend_from_slice(b"\r\n"); for line in lines { @@ -82,7 +606,25 @@ pub(crate) fn rewrite_http_header_block(raw: &[u8], resolver: Option<&SecretReso output.extend_from_slice(b"\r\n"); output.extend_from_slice(&raw[header_end..]); - output + + // Fail-closed scan: check for any remaining unresolved placeholders + // in both raw form and percent-decoded form of the output header block. 
+ let output_header = String::from_utf8_lossy(&output[..output.len().min(header_end + 256)]); + if output_header.contains(PLACEHOLDER_PREFIX) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + + // Also check percent-decoded form of the request line (F5 — encoded placeholder bypass) + let rewritten_rl = output_header.split("\r\n").next().unwrap_or(""); + let decoded_rl = percent_decode(rewritten_rl); + if decoded_rl.contains(PLACEHOLDER_PREFIX) { + return Err(UnresolvedPlaceholderError { location: "path" }); + } + + Ok(RewriteResult { + rewritten: output, + redacted_target: rl_result.redacted_target, + }) } pub(crate) fn rewrite_header_line(line: &str, resolver: &SecretResolver) -> String { @@ -96,10 +638,68 @@ pub(crate) fn rewrite_header_line(line: &str, resolver: &SecretReso } } +/// Resolve placeholders in a request target (path + query) for OPA evaluation. +/// +/// Returns the resolved target (real secrets, for upstream) and a redacted +/// version (`[CREDENTIAL]` in place of secrets, for OPA input and logs). +pub(crate) fn rewrite_target_for_eval( + target: &str, + resolver: &SecretResolver, +) -> Result<RewriteTargetResult, UnresolvedPlaceholderError> { + if !target.contains(PLACEHOLDER_PREFIX) { + // Also check percent-decoded form + let decoded = percent_decode(target); + if decoded.contains(PLACEHOLDER_PREFIX) { + return Err(UnresolvedPlaceholderError { location: "path" }); + } + return Ok(RewriteTargetResult { + resolved: target.to_string(), + redacted: target.to_string(), + }); + } + + let (path, query) = match target.split_once('?') { + Some((p, q)) => (p, Some(q)), + None => (target, None), + }; + + // Rewrite path + let (resolved_path, redacted_path) = match rewrite_uri_path(path, resolver)? { + Some((resolved, redacted)) => (resolved, redacted), + None => (path.to_string(), path.to_string()), + }; + + // Rewrite query + let (resolved_query, redacted_query) = match query { + Some(q) => match rewrite_uri_query_params(q, resolver)? 
{ + Some((resolved, redacted)) => (Some(resolved), Some(redacted)), + None => (Some(q.to_string()), Some(q.to_string())), + }, + None => (None, None), + }; + + let resolved = match &resolved_query { + Some(q) => format!("{resolved_path}?{q}"), + None => resolved_path, + }; + let redacted = match &redacted_query { + Some(q) => format!("{redacted_path}?{q}"), + None => redacted_path, + }; + + Ok(RewriteTargetResult { resolved, redacted }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + #[cfg(test)] mod tests { use super::*; + // === Existing tests (preserved) === + #[test] fn provider_env_is_replaced_with_placeholders() { let (child_env, resolver) = SecretResolver::from_provider_env( @@ -163,17 +763,13 @@ mod tests { ); let raw = b"POST /v1 HTTP/1.1\r\nAuthorization: Bearer openshell:resolve:env:CUSTOM_TOKEN\r\nContent-Length: 5\r\n\r\nhello"; - let rewritten = rewrite_http_header_block(raw, resolver.as_ref()); - let rewritten = String::from_utf8(rewritten).expect("utf8"); + let result = rewrite_http_header_block(raw, resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); assert!(rewritten.contains("Authorization: Bearer secret-token\r\n")); assert!(rewritten.ends_with("\r\n\r\nhello")); } - /// Simulates the full round-trip: provider env → child placeholders → - /// HTTP headers → rewrite. This is the exact flow that occurs when a - /// sandbox child process reads placeholder env vars, constructs an HTTP - /// request, and the proxy rewrites headers before forwarding upstream. 
#[test] fn full_round_trip_child_env_to_rewritten_headers() { let provider_env: HashMap = [ @@ -191,13 +787,11 @@ mod tests { let (child_env, resolver) = SecretResolver::from_provider_env(provider_env); - // Child process reads placeholders from the environment let auth_value = child_env.get("ANTHROPIC_API_KEY").unwrap(); let token_value = child_env.get("CUSTOM_SERVICE_TOKEN").unwrap(); assert!(auth_value.starts_with(PLACEHOLDER_PREFIX)); assert!(token_value.starts_with(PLACEHOLDER_PREFIX)); - // Child constructs an HTTP request using those placeholders let raw = format!( "GET /v1/messages HTTP/1.1\r\n\ Host: api.example.com\r\n\ @@ -206,11 +800,10 @@ mod tests { Content-Length: 0\r\n\r\n" ); - // Proxy rewrites headers - let rewritten = rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()); - let rewritten = String::from_utf8(rewritten).expect("utf8"); + let result = + rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); - // Real secrets must appear in the rewritten headers assert!( rewritten.contains("Authorization: Bearer sk-real-key-12345\r\n"), "Expected rewritten Authorization header, got: {rewritten}" @@ -219,14 +812,10 @@ mod tests { rewritten.contains("x-api-key: tok-real-svc-67890\r\n"), "Expected rewritten x-api-key header, got: {rewritten}" ); - - // Placeholders must not appear assert!( !rewritten.contains("openshell:resolve:env:"), "Placeholder leaked into rewritten request: {rewritten}" ); - - // Request line and non-secret headers must be preserved assert!(rewritten.starts_with("GET /v1/messages HTTP/1.1\r\n")); assert!(rewritten.contains("Host: api.example.com\r\n")); assert!(rewritten.contains("Content-Length: 0\r\n")); @@ -241,9 +830,8 @@ mod tests { ); let raw = b"GET / HTTP/1.1\r\nHost: example.com\r\nAccept: application/json\r\nContent-Type: text/plain\r\n\r\n"; - let rewritten = rewrite_http_header_block(raw, resolver.as_ref()); - 
// The output should be byte-identical since no placeholders are present - assert_eq!(raw.as_slice(), rewritten.as_slice()); + let result = rewrite_http_header_block(raw, resolver.as_ref()).expect("should succeed"); + assert_eq!(raw.as_slice(), result.rewritten.as_slice()); } #[test] @@ -256,7 +844,633 @@ mod tests { #[test] fn rewrite_with_no_resolver_returns_original() { let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer my-token\r\n\r\n"; - let rewritten = rewrite_http_header_block(raw, None); - assert_eq!(raw.as_slice(), rewritten.as_slice()); + let result = rewrite_http_header_block(raw, None).expect("should succeed"); + assert_eq!(raw.as_slice(), result.rewritten.as_slice()); + } + + // === Secret validation tests (F1 — CWE-113) === + + #[test] + fn resolve_placeholder_rejects_crlf() { + let (_, resolver) = SecretResolver::from_provider_env( + [("BAD_KEY".to_string(), "value\r\nEvil: header".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + assert!( + resolver + .resolve_placeholder("openshell:resolve:env:BAD_KEY") + .is_none() + ); + } + + #[test] + fn resolve_placeholder_rejects_null() { + let (_, resolver) = SecretResolver::from_provider_env( + [("BAD_KEY".to_string(), "value\0rest".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + assert!( + resolver + .resolve_placeholder("openshell:resolve:env:BAD_KEY") + .is_none() + ); + } + + #[test] + fn resolve_placeholder_accepts_normal_values() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "sk-abc123_DEF.456~xyz".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:KEY"), + Some("sk-abc123_DEF.456~xyz") + ); + } + + // === Query parameter rewriting tests (absorbed from PR #631) === + + #[test] + fn rewrites_query_param_placeholder_in_request_line() { + let (child_env, resolver) = 
SecretResolver::from_provider_env( + [("YOUTUBE_API_KEY".to_string(), "AIzaSy-secret".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("YOUTUBE_API_KEY").unwrap(); + + let raw = format!( + "GET /youtube/v3/search?part=snippet&key={placeholder} HTTP/1.1\r\n\ + Host: www.googleapis.com\r\n\r\n" + ); + let result = + rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + rewritten + .starts_with("GET /youtube/v3/search?part=snippet&key=AIzaSy-secret HTTP/1.1\r\n"), + "Expected query param rewritten, got: {rewritten}" + ); + assert!(!rewritten.contains("openshell:resolve:env:")); + } + + #[test] + fn rewrites_query_param_with_special_chars_percent_encoded() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [( + "API_KEY".to_string(), + "key with spaces&symbols=yes".to_string(), + )] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("API_KEY").unwrap(); + + let raw = format!("GET /api?token={placeholder} HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = + rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + rewritten.contains("token=key%20with%20spaces%26symbols%3Dyes"), + "Expected percent-encoded secret, got: {rewritten}" + ); + } + + #[test] + fn rewrites_query_param_only_placeholder_first_param() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret123".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("KEY").unwrap(); + + let raw = format!("GET /api?key={placeholder}&format=json HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = + rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + 
rewritten.starts_with("GET /api?key=secret123&format=json HTTP/1.1"), + "Expected first param rewritten, got: {rewritten}" + ); + } + + #[test] + fn no_query_param_rewrite_without_placeholder() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + + let raw = b"GET /api?key=normalvalue HTTP/1.1\r\nHost: x\r\n\r\n"; + let result = rewrite_http_header_block(raw, resolver.as_ref()).expect("should succeed"); + assert_eq!(raw.as_slice(), result.rewritten.as_slice()); + } + + // === Basic Authorization header encoding tests (absorbed from PR #631) === + + #[test] + fn rewrites_basic_auth_placeholder_in_decoded_token() { + let b64 = base64::engine::general_purpose::STANDARD; + + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DB_PASSWORD".to_string(), "s3cret!".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("DB_PASSWORD").unwrap(); + + let credentials = format!("admin:{placeholder}"); + let encoded = b64.encode(credentials.as_bytes()); + + let header_line = format!("Authorization: Basic {encoded}"); + let rewritten = rewrite_header_line(&header_line, &resolver); + + let rewritten_token = rewritten.strip_prefix("Authorization: Basic ").unwrap(); + let decoded = b64.decode(rewritten_token).unwrap(); + let decoded_str = std::str::from_utf8(&decoded).unwrap(); + + assert_eq!(decoded_str, "admin:s3cret!"); + assert!(!rewritten.contains("openshell:resolve:env:")); + } + + #[test] + fn basic_auth_without_placeholder_unchanged() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + let b64 = base64::engine::general_purpose::STANDARD; + let encoded = b64.encode(b"user:password"); + let header_line = format!("Authorization: Basic {encoded}"); + + let rewritten = 
rewrite_header_line(&header_line, &resolver); + assert_eq!( + rewritten, header_line, + "Should not modify non-placeholder Basic auth" + ); + } + + #[test] + fn basic_auth_full_round_trip_header_block() { + let b64 = base64::engine::general_purpose::STANDARD; + + let (child_env, resolver) = SecretResolver::from_provider_env( + [("REGISTRY_PASS".to_string(), "hunter2".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("REGISTRY_PASS").unwrap(); + let credentials = format!("deploy:{placeholder}"); + let encoded = b64.encode(credentials.as_bytes()); + + let raw = format!( + "GET /v2/_catalog HTTP/1.1\r\n\ + Host: registry.example.com\r\n\ + Authorization: Basic {encoded}\r\n\ + Accept: application/json\r\n\r\n" + ); + + let result = + rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + let auth_line = rewritten + .lines() + .find(|l| l.starts_with("Authorization:")) + .unwrap(); + let token = auth_line.strip_prefix("Authorization: Basic ").unwrap(); + let decoded = b64.decode(token).unwrap(); + assert_eq!(std::str::from_utf8(&decoded).unwrap(), "deploy:hunter2"); + + assert!(rewritten.contains("Host: registry.example.com\r\n")); + assert!(rewritten.contains("Accept: application/json\r\n")); + assert!(!rewritten.contains("openshell:resolve:env:")); + } + + // === Percent encoding tests (absorbed from PR #631) === + + #[test] + fn percent_encode_preserves_unreserved() { + assert_eq!(percent_encode_query("abc123-._~"), "abc123-._~"); + } + + #[test] + fn percent_encode_encodes_special_chars() { + assert_eq!(percent_encode_query("a b"), "a%20b"); + assert_eq!(percent_encode_query("key=val&x"), "key%3Dval%26x"); + } + + #[test] + fn percent_decode_round_trips() { + let original = "hello world & more=stuff"; + let encoded = percent_encode_query(original); + let decoded = percent_decode(&encoded); + assert_eq!(decoded, original); + } + + 
// === URL path rewriting tests === + + #[test] + fn rewrite_path_single_segment_placeholder() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("TOKEN".to_string(), "abc123".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("TOKEN").unwrap(); + + let raw = format!("GET /api/{placeholder}/data HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + rewritten.starts_with("GET /api/abc123/data HTTP/1.1"), + "Expected path rewritten, got: {rewritten}" + ); + assert_eq!( + result.redacted_target.as_deref(), + Some("/api/[CREDENTIAL]/data") + ); + } + + #[test] + fn rewrite_path_telegram_style_concatenated() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [( + "TELEGRAM_TOKEN".to_string(), + "123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11".to_string(), + )] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("TELEGRAM_TOKEN").unwrap(); + + let raw = format!( + "POST /bot{placeholder}/sendMessage HTTP/1.1\r\nHost: api.telegram.org\r\n\r\n" + ); + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + rewritten.starts_with( + "POST /bot123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11/sendMessage HTTP/1.1" + ), + "Expected Telegram-style path rewritten, got: {rewritten}" + ); + assert_eq!( + result.redacted_target.as_deref(), + Some("/bot[CREDENTIAL]/sendMessage") + ); + } + + #[test] + fn rewrite_path_multiple_placeholders_in_separate_segments() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("ORG_ID".to_string(), "org-123".to_string()), + ("API_KEY".to_string(), "key-456".to_string()), + ] + 
.into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let org_ph = child_env.get("ORG_ID").unwrap(); + let key_ph = child_env.get("API_KEY").unwrap(); + + let raw = format!("GET /orgs/{org_ph}/keys/{key_ph} HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + rewritten.starts_with("GET /orgs/org-123/keys/key-456 HTTP/1.1"), + "Expected both path segments rewritten, got: {rewritten}" + ); + } + + #[test] + fn rewrite_path_no_placeholders_unchanged() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + + let raw = b"GET /v1/chat/completions HTTP/1.1\r\nHost: x\r\n\r\n"; + let result = rewrite_http_header_block(raw, resolver.as_ref()).expect("should succeed"); + assert_eq!(raw.as_slice(), result.rewritten.as_slice()); + assert!(result.redacted_target.is_none()); + } + + #[test] + fn rewrite_path_preserves_query_params() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("TOKEN".to_string(), "tok123".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("TOKEN").unwrap(); + + let raw = format!("GET /bot{placeholder}/method?format=json HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!( + rewritten.starts_with("GET /bottok123/method?format=json HTTP/1.1"), + "Expected path rewritten and query preserved, got: {rewritten}" + ); + } + + #[test] + fn rewrite_path_credential_traversal_rejected() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("BAD".to_string(), "../admin".to_string())] + .into_iter() + .collect(), + ); + let 
resolver = resolver.expect("resolver"); + let placeholder = child_env.get("BAD").unwrap(); + + let raw = format!("GET /api/{placeholder}/data HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = rewrite_http_header_block(raw.as_bytes(), Some(&resolver)); + assert!( + result.is_err(), + "Path traversal credential should be rejected" + ); + } + + #[test] + fn rewrite_path_credential_backslash_rejected() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("BAD".to_string(), "foo\\bar".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("BAD").unwrap(); + + let raw = format!("GET /api/{placeholder} HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = rewrite_http_header_block(raw.as_bytes(), Some(&resolver)); + assert!( + result.is_err(), + "Backslash in credential should be rejected" + ); + } + + #[test] + fn rewrite_path_credential_slash_rejected() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("BAD".to_string(), "foo/bar".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("BAD").unwrap(); + + let raw = format!("GET /api/{placeholder} HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = rewrite_http_header_block(raw.as_bytes(), Some(&resolver)); + assert!( + result.is_err(), + "Slash in path credential should be rejected" + ); + } + + #[test] + fn rewrite_path_credential_null_rejected() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("BAD".to_string(), "foo\0bar".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("BAD").unwrap(); + + let raw = format!("GET /api/{placeholder} HTTP/1.1\r\nHost: x\r\n\r\n"); + // The null byte in the credential is caught by resolve_placeholder's + // validate_resolved_secret, which returns None. 
This triggers the
+        // unresolved placeholder path in rewrite_path_segment → fail-closed.
+        let result = rewrite_http_header_block(raw.as_bytes(), Some(&resolver));
+        assert!(
+            result.is_err(),
+            "Null byte in credential should be rejected"
+        );
+    }
+
+    #[test]
+    fn rewrite_path_percent_encodes_special_chars() {
+        let (child_env, resolver) = SecretResolver::from_provider_env(
+            [("TOKEN".to_string(), "hello world".to_string())]
+                .into_iter()
+                .collect(),
+        );
+        let resolver = resolver.expect("resolver");
+        let placeholder = child_env.get("TOKEN").unwrap();
+
+        // A space is not a path-unsafe character, so the credential is
+        // accepted rather than rejected; the rewrite succeeds and the
+        // space is percent-encoded in the substituted path segment so
+        // the request target remains a valid URI.
+        let raw = format!("GET /api/{placeholder}/data HTTP/1.1\r\nHost: x\r\n\r\n");
+        let result =
+            rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should succeed");
+        let rewritten = String::from_utf8(result.rewritten).expect("utf8");
+
+        assert!(
+            rewritten.contains("/api/hello%20world/data"),
+            "Expected percent-encoded path segment, got: {rewritten}"
+        );
+    }
+
+    // === Fail-closed tests ===
+
+    #[test]
+    fn unresolved_header_placeholder_returns_error() {
+        let (_, resolver) = SecretResolver::from_provider_env(
+            [("KEY".to_string(), "secret".to_string())]
+                .into_iter()
+                .collect(),
+        );
+
+        let raw = b"GET / HTTP/1.1\r\nx-api-key: openshell:resolve:env:UNKNOWN_KEY\r\n\r\n";
+        let result = rewrite_http_header_block(raw, resolver.as_ref());
+        assert!(result.is_err(), "Unresolved header placeholder should fail");
+    }
+
+    #[test]
+    fn unresolved_query_param_returns_error() {
+        let (_, resolver) = SecretResolver::from_provider_env(
+            [("KEY".to_string(), "secret".to_string())]
+                .into_iter()
+                .collect(),
+        );
+
+        let raw = b"GET /api?token=openshell:resolve:env:UNKNOWN HTTP/1.1\r\nHost: x\r\n\r\n";
+        let result = 
rewrite_http_header_block(raw, resolver.as_ref()); + assert!( + result.is_err(), + "Unresolved query param placeholder should fail" + ); + } + + #[test] + fn unresolved_path_placeholder_returns_error() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + + let raw = b"GET /api/openshell:resolve:env:UNKNOWN/data HTTP/1.1\r\nHost: x\r\n\r\n"; + let result = rewrite_http_header_block(raw, resolver.as_ref()); + assert!(result.is_err(), "Unresolved path placeholder should fail"); + } + + #[test] + fn percent_encoded_placeholder_in_path_caught() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + + // Percent-encode "openshell:resolve:env:UNKNOWN" in the path + let encoded_placeholder = "openshell%3Aresolve%3Aenv%3AUNKNOWN"; + let raw = format!("GET /api/{encoded_placeholder}/data HTTP/1.1\r\nHost: x\r\n\r\n"); + let result = rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()); + assert!( + result.is_err(), + "Percent-encoded placeholder should be caught by fail-closed scan" + ); + } + + #[test] + fn all_resolved_succeeds() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("KEY1".to_string(), "secret1".to_string()), + ("KEY2".to_string(), "secret2".to_string()), + ] + .into_iter() + .collect(), + ); + let ph1 = child_env.get("KEY1").unwrap(); + let ph2 = child_env.get("KEY2").unwrap(); + + let raw = format!( + "GET /api/{ph1}?token={ph2} HTTP/1.1\r\n\ + x-auth: {ph1}\r\n\r\n" + ); + let result = + rewrite_http_header_block(raw.as_bytes(), resolver.as_ref()).expect("should succeed"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!(!rewritten.contains("openshell:resolve:env:")); + assert!(rewritten.contains("secret1")); + assert!(rewritten.contains("secret2")); + } + + #[test] + fn no_resolver_passes_through_without_scanning() { + // 
Even if placeholders are present, None resolver means no scanning + let raw = b"GET /api/openshell:resolve:env:KEY HTTP/1.1\r\nHost: x\r\n\r\n"; + let result = rewrite_http_header_block(raw, None).expect("should succeed"); + assert_eq!(raw.as_slice(), result.rewritten.as_slice()); + } + + // === Redaction tests === + + #[test] + fn redacted_target_replaces_path_secrets_with_credential_marker() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("TOKEN".to_string(), "real-secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("TOKEN").unwrap(); + + let result = rewrite_target_for_eval(&format!("/bot{placeholder}/sendMessage"), &resolver) + .expect("should succeed"); + + assert_eq!(result.redacted, "/bot[CREDENTIAL]/sendMessage"); + assert!(result.resolved.contains("real-secret")); + assert!(!result.redacted.contains("real-secret")); + } + + #[test] + fn redacted_target_replaces_query_secrets_with_credential_marker() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret123".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("KEY").unwrap(); + + let result = + rewrite_target_for_eval(&format!("/api?key={placeholder}&format=json"), &resolver) + .expect("should succeed"); + + assert_eq!(result.redacted, "/api?key=[CREDENTIAL]&format=json"); + assert!(result.resolved.contains("secret123")); + assert!(!result.redacted.contains("secret123")); + } + + #[test] + fn redacted_target_preserves_non_secret_segments() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + let result = rewrite_target_for_eval("/v1/chat/completions?format=json", &resolver) + .expect("should succeed"); + + assert_eq!(result.resolved, 
"/v1/chat/completions?format=json"); + assert_eq!(result.redacted, "/v1/chat/completions?format=json"); + } + + #[test] + fn rewrite_target_for_eval_roundtrip() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("TOKEN".to_string(), "tok123".to_string()), + ("KEY".to_string(), "key456".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let tok_ph = child_env.get("TOKEN").unwrap(); + let key_ph = child_env.get("KEY").unwrap(); + + let target = format!("/bot{tok_ph}/method?key={key_ph}"); + let result = rewrite_target_for_eval(&target, &resolver).expect("should succeed"); + + assert_eq!(result.resolved, "/bottok123/method?key=key456"); + assert_eq!(result.redacted, "/bot[CREDENTIAL]/method?key=[CREDENTIAL]"); } } diff --git a/crates/openshell-sandbox/src/ssh.rs b/crates/openshell-sandbox/src/ssh.rs index 10eab8c45..b9f947395 100644 --- a/crates/openshell-sandbox/src/ssh.rs +++ b/crates/openshell-sandbox/src/ssh.rs @@ -12,6 +12,10 @@ use crate::{register_managed_child, unregister_managed_child}; use miette::{IntoDiagnostic, Result}; use nix::pty::{Winsize, openpty}; use nix::unistd::setsid; +use openshell_ocsf::{ + ActionId, ActivityId, AuthTypeId, ConfidenceId, DetectionFindingBuilder, DispositionId, + FindingInfo, SeverityId, SshActivityBuilder, StatusId, ocsf_emit, +}; use rand_core::OsRng; use russh::keys::{Algorithm, PrivateKey}; use russh::server::{Auth, Handle, Session}; @@ -26,7 +30,7 @@ use std::sync::{Arc, Mutex, mpsc}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; -use tracing::{info, warn}; +use tracing::warn; const PREFACE_MAGIC: &str = "NSSH1"; #[cfg(test)] @@ -60,7 +64,15 @@ async fn ssh_server_init( let config = Arc::new(config); let ca_paths = ca_file_paths.as_ref().map(|p| Arc::new(p.clone())); let listener = TcpListener::bind(listen_addr).await.into_diagnostic()?; - info!(addr = 
%listen_addr, "SSH server listening"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Listen) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .src_endpoint_addr(listen_addr.ip(), listen_addr.port()) + .message(format!("SSH server listening on {listen_addr}")) + .build() + ); Ok((listener, config, ca_paths)) } @@ -139,7 +151,14 @@ pub async fn run_ssh_server( ) .await { - warn!(error = %err, "SSH connection failed"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("SSH connection failed: {err}")) + .build() + ); } }); } @@ -160,17 +179,40 @@ async fn handle_connection( provider_env: HashMap, nonce_cache: &NonceCache, ) -> Result<()> { - info!(peer = %peer, "SSH connection: reading handshake preface"); + tracing::debug!(peer = %peer, "SSH connection: reading handshake preface"); let mut line = String::new(); read_line(&mut stream, &mut line).await?; - info!(peer = %peer, preface_len = line.len(), "SSH connection: preface received, verifying"); + tracing::debug!(peer = %peer, preface_len = line.len(), "SSH connection: preface received, verifying"); if !verify_preface(&line, secret, handshake_skew_secs, nonce_cache)? 
{ - warn!(peer = %peer, "SSH connection: handshake verification failed"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .src_endpoint_addr(peer.ip(), peer.port()) + .message(format!( + "SSH connection: handshake verification failed from {peer}" + )) + .build() + ); let _ = stream.write_all(b"ERR\n").await; return Ok(()); } stream.write_all(b"OK\n").await.into_diagnostic()?; - info!(peer = %peer, "SSH handshake accepted"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .src_endpoint_addr(peer.ip(), peer.port()) + .auth_type(AuthTypeId::Other, "NSSH1") + .message(format!("SSH handshake accepted from {peer}")) + .build() + ); let handler = SshHandler::new( policy, @@ -245,7 +287,31 @@ fn verify_preface( .lock() .map_err(|_| miette::miette!("nonce cache lock poisoned"))?; if cache.contains_key(nonce) { - warn!(nonce = nonce, "NSSH1 nonce replay detected"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::High) + .auth_type(AuthTypeId::Other, "NSSH1") + .message(format!("NSSH1 nonce replay detected: {nonce}")) + .build() + ); + ocsf_emit!( + DetectionFindingBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::High) + .is_alert(true) + .confidence(ConfidenceId::High) + .finding_info(FindingInfo::new( + "nssh1-nonce-replay", + "NSSH1 Nonce Replay Attack" + )) + .evidence("nonce", nonce) + .build() + ); return Ok(false); } cache.insert(nonce.to_string(), Instant::now()); @@ -263,6 +329,19 @@ fn 
hmac_sha256(key: &[u8], data: &[u8]) -> String { hex::encode(result) } +/// Per-channel state for tracking PTY resources and I/O senders. +/// +/// Each SSH channel gets its own PTY master (if a PTY was requested) and input +/// sender. This allows `window_change_request` to resize the correct PTY when +/// multiple channels are open simultaneously (e.g. parallel shells, shell + +/// sftp, etc.). +#[derive(Default)] +struct ChannelState { + input_sender: Option>>, + pty_master: Option, + pty_request: Option, +} + struct SshHandler { policy: SandboxPolicy, workdir: Option, @@ -270,9 +349,7 @@ struct SshHandler { proxy_url: Option, ca_file_paths: Option>, provider_env: HashMap, - input_sender: Option>>, - pty_master: Option, - pty_request: Option, + channels: HashMap, } impl SshHandler { @@ -291,9 +368,7 @@ impl SshHandler { proxy_url, ca_file_paths, provider_env, - input_sender: None, - pty_master: None, - pty_request: None, + channels: HashMap::new(), } } } @@ -315,12 +390,27 @@ impl russh::server::Handler for SshHandler { async fn channel_open_session( &mut self, - _channel: russh::Channel, + channel: russh::Channel, _session: &mut Session, ) -> Result { + self.channels.insert(channel.id(), ChannelState::default()); Ok(true) } + /// Clean up per-channel state when the channel is closed. + /// + /// This is the final cleanup and subsumes `channel_eof` — if `channel_close` + /// fires without a preceding `channel_eof`, all resources (pty_master File, + /// input_sender) are dropped here. + async fn channel_close( + &mut self, + channel: ChannelId, + _session: &mut Session, + ) -> Result<(), Self::Error> { + self.channels.remove(&channel); + Ok(()) + } + async fn channel_open_direct_tcpip( &mut self, channel: russh::Channel, @@ -334,22 +424,30 @@ impl russh::server::Handler for SshHandler { // uses u32 for ports, but valid TCP ports are 0-65535. Without this // check, port 65537 truncates to port 1 (privileged). 
if port_to_connect > u32::from(u16::MAX) { - warn!( - host = host_to_connect, - port = port_to_connect, - "direct-tcpip rejected: port exceeds valid TCP range (0-65535)" - ); + ocsf_emit!(SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .message(format!( + "direct-tcpip rejected: port {port_to_connect} exceeds valid TCP range for host {host_to_connect}" + )) + .build()); return Ok(false); } // Only allow forwarding to loopback destinations to prevent the // sandbox SSH server from being used as a generic proxy. if !is_loopback_host(host_to_connect) { - warn!( - host = host_to_connect, - port = port_to_connect, - "direct-tcpip rejected: non-loopback destination" - ); + ocsf_emit!(SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .message(format!( + "direct-tcpip rejected: non-loopback destination {host_to_connect}:{port_to_connect}" + )) + .build()); return Ok(false); } @@ -362,7 +460,14 @@ impl russh::server::Handler for SshHandler { let tcp = match connect_in_netns(&addr, netns_fd).await { Ok(stream) => stream, Err(err) => { - warn!(addr = %addr, error = %err, "direct-tcpip: failed to connect"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("direct-tcpip: failed to connect to {addr}: {err}")) + .build() + ); let _ = channel.close().await; return; } @@ -388,7 +493,11 @@ impl russh::server::Handler for SshHandler { _modes: &[(russh::Pty, u32)], session: &mut Session, ) -> Result<(), Self::Error> { - self.pty_request = Some(PtyRequest { + let state = self + .channels + .get_mut(&channel) + .ok_or_else(|| anyhow::anyhow!("pty_request on unknown channel {channel:?}"))?; + state.pty_request = 
Some(PtyRequest { term: term.to_string(), col_width, row_height, @@ -401,21 +510,27 @@ impl russh::server::Handler for SshHandler { async fn window_change_request( &mut self, - _channel: ChannelId, + channel: ChannelId, col_width: u32, row_height: u32, pixel_width: u32, pixel_height: u32, _session: &mut Session, ) -> Result<(), Self::Error> { - if let Some(master) = self.pty_master.as_ref() { + let Some(state) = self.channels.get(&channel) else { + warn!("window_change_request on unknown channel {channel:?}"); + return Ok(()); + }; + if let Some(master) = state.pty_master.as_ref() { let winsize = Winsize { ws_row: to_u16(row_height.max(1)), ws_col: to_u16(col_width.max(1)), ws_xpixel: to_u16(pixel_width), ws_ypixel: to_u16(pixel_height), }; - let _ = unsafe_pty::set_winsize(master.as_raw_fd(), winsize); + if let Err(e) = unsafe_pty::set_winsize(master.as_raw_fd(), winsize) { + warn!("failed to resize PTY for channel {channel:?}: {e}"); + } } Ok(()) } @@ -474,9 +589,20 @@ impl russh::server::Handler for SshHandler { self.ca_file_paths.clone(), &self.provider_env, )?; - self.input_sender = Some(input_sender); + let state = self.channels.get_mut(&channel).ok_or_else(|| { + anyhow::anyhow!("subsystem_request on unknown channel {channel:?}") + })?; + state.input_sender = Some(input_sender); } else { - warn!(subsystem = name, "unsupported subsystem requested"); + ocsf_emit!( + SshActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Rejected) + .severity(SeverityId::Medium) + .message(format!("unsupported subsystem requested: {name}")) + .build() + ); session.channel_failure(channel)?; } Ok(()) @@ -499,11 +625,15 @@ impl russh::server::Handler for SshHandler { async fn data( &mut self, - _channel: ChannelId, + channel: ChannelId, data: &[u8], _session: &mut Session, ) -> Result<(), Self::Error> { - if let Some(sender) = self.input_sender.as_ref() { + let Some(state) = self.channels.get(&channel) 
else { + warn!("data on unknown channel {channel:?}"); + return Ok(()); + }; + if let Some(sender) = state.input_sender.as_ref() { let _ = sender.send(data.to_vec()); } Ok(()) @@ -511,14 +641,18 @@ impl russh::server::Handler for SshHandler { async fn channel_eof( &mut self, - _channel: ChannelId, + channel: ChannelId, _session: &mut Session, ) -> Result<(), Self::Error> { // Drop the input sender so the stdin writer thread sees a // disconnected channel and closes the child's stdin pipe. This // is essential for commands like `cat | tar xf -` which need // stdin EOF to know the input stream is complete. - self.input_sender.take(); + if let Some(state) = self.channels.get_mut(&channel) { + state.input_sender.take(); + } else { + warn!("channel_eof on unknown channel {channel:?}"); + } Ok(()) } } @@ -530,7 +664,11 @@ impl SshHandler { handle: Handle, command: Option, ) -> anyhow::Result<()> { - if let Some(pty) = self.pty_request.take() { + let state = self + .channels + .get_mut(&channel) + .ok_or_else(|| anyhow::anyhow!("start_shell on unknown channel {channel:?}"))?; + if let Some(pty) = state.pty_request.take() { // PTY was requested — allocate a real PTY (interactive shell or // exec that explicitly asked for a terminal). let (pty_master, input_sender) = spawn_pty_shell( @@ -545,8 +683,8 @@ impl SshHandler { self.ca_file_paths.clone(), &self.provider_env, )?; - self.pty_master = Some(pty_master); - self.input_sender = Some(input_sender); + state.pty_master = Some(pty_master); + state.input_sender = Some(input_sender); } else { // No PTY requested — use plain pipes so stdout/stderr are // separate and output has clean LF line endings. 
This is the @@ -562,7 +700,7 @@ impl SshHandler { self.ca_file_paths.clone(), &self.provider_env, )?; - self.input_sender = Some(input_sender); + state.input_sender = Some(input_sender); } Ok(()) } @@ -999,7 +1137,7 @@ mod unsafe_pty { #[allow(unsafe_code)] pub fn set_winsize(fd: RawFd, winsize: Winsize) -> std::io::Result<()> { - let rc = unsafe { libc::ioctl(fd, libc::TIOCSWINSZ, winsize) }; + let rc = unsafe { libc::ioctl(fd, libc::TIOCSWINSZ, &winsize) }; if rc != 0 { return Err(std::io::Error::last_os_error()); } @@ -1404,4 +1542,111 @@ mod tests { assert!(!is_loopback_host("not-an-ip")); assert!(!is_loopback_host("[]")); } + + // ----------------------------------------------------------------------- + // Per-channel PTY state tests (#543) + // ----------------------------------------------------------------------- + + #[test] + fn set_winsize_applies_to_correct_pty() { + // Verify that set_winsize applies to a specific PTY master FD, + // which is the mechanism that per-channel tracking relies on. + // With the old single-pty_master design, a window_change_request + // for channel N would resize whatever PTY was stored last — + // potentially belonging to a different channel. + let pty_a = openpty(None, None).expect("openpty a"); + let pty_b = openpty(None, None).expect("openpty b"); + let master_a = std::fs::File::from(pty_a.master); + let master_b = std::fs::File::from(pty_b.master); + let fd_a = master_a.as_raw_fd(); + let fd_b = master_b.as_raw_fd(); + assert_ne!(fd_a, fd_b, "two PTYs must have distinct FDs"); + + // Close the slave ends to avoid leaking FDs in the test. + drop(std::fs::File::from(pty_a.slave)); + drop(std::fs::File::from(pty_b.slave)); + + // Resize only PTY B. + let winsize_b = Winsize { + ws_row: 50, + ws_col: 120, + ws_xpixel: 0, + ws_ypixel: 0, + }; + unsafe_pty::set_winsize(fd_b, winsize_b).expect("set_winsize on PTY B"); + + // Resize PTY A to a different size. 
+ let winsize_a = Winsize { + ws_row: 24, + ws_col: 80, + ws_xpixel: 0, + ws_ypixel: 0, + }; + unsafe_pty::set_winsize(fd_a, winsize_a).expect("set_winsize on PTY A"); + + // Read back sizes via ioctl to verify independence. + let mut actual_a: libc::winsize = unsafe { std::mem::zeroed() }; + let mut actual_b: libc::winsize = unsafe { std::mem::zeroed() }; + #[allow(unsafe_code)] + unsafe { + libc::ioctl(fd_a, libc::TIOCGWINSZ, &mut actual_a); + libc::ioctl(fd_b, libc::TIOCGWINSZ, &mut actual_b); + } + + assert_eq!(actual_a.ws_row, 24, "PTY A should be 24 rows"); + assert_eq!(actual_a.ws_col, 80, "PTY A should be 80 cols"); + assert_eq!(actual_b.ws_row, 50, "PTY B should be 50 rows"); + assert_eq!(actual_b.ws_col, 120, "PTY B should be 120 cols"); + } + + #[test] + fn channel_state_independent_input_senders() { + // Verify that each channel gets its own input sender so that + // data() and channel_eof() affect only the targeted channel. + let (tx_a, rx_a) = mpsc::channel::<Vec<u8>>(); + let (tx_b, rx_b) = mpsc::channel::<Vec<u8>>(); + + let mut state_a = ChannelState { + input_sender: Some(tx_a), + ..Default::default() + }; + let state_b = ChannelState { + input_sender: Some(tx_b), + ..Default::default() + }; + + // Send data to channel A only. + state_a + .input_sender + .as_ref() + .unwrap() + .send(b"hello-a".to_vec()) + .unwrap(); + // Send data to channel B only. + state_b + .input_sender + .as_ref() + .unwrap() + .send(b"hello-b".to_vec()) + .unwrap(); + + assert_eq!(rx_a.recv().unwrap(), b"hello-a"); + assert_eq!(rx_b.recv().unwrap(), b"hello-b"); + + // EOF on channel A (drop sender) should not affect channel B. + state_a.input_sender.take(); + assert!( + rx_a.recv().is_err(), + "channel A sender dropped, recv should fail" + ); + + // Channel B should still be functional. 
+ state_b + .input_sender + .as_ref() + .unwrap() + .send(b"still-alive".to_vec()) + .unwrap(); + assert_eq!(rx_b.recv().unwrap(), b"still-alive"); + } } diff --git a/crates/openshell-sandbox/tests/system_inference.rs b/crates/openshell-sandbox/tests/system_inference.rs index 3f6a471e5..5d581fbe2 100644 --- a/crates/openshell-sandbox/tests/system_inference.rs +++ b/crates/openshell-sandbox/tests/system_inference.rs @@ -20,6 +20,7 @@ fn make_system_route() -> ResolvedRoute { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, } } @@ -32,6 +33,7 @@ fn make_user_route() -> ResolvedRoute { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, } } @@ -124,6 +126,7 @@ async fn system_inference_with_anthropic_protocol() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }; let ctx = InferenceContext::new(patterns, router, vec![], vec![system_route]); diff --git a/crates/openshell-sandbox/tests/websocket_upgrade.rs b/crates/openshell-sandbox/tests/websocket_upgrade.rs new file mode 100644 index 000000000..ec226c9cf --- /dev/null +++ b/crates/openshell-sandbox/tests/websocket_upgrade.rs @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Integration test: WebSocket upgrade through the L7 relay. +//! +//! Spins up a dummy WebSocket echo server, connects a client through the +//! `L7Provider::relay` pipeline, validates the 101 upgrade succeeds, and +//! exchanges a WebSocket text frame bidirectionally. +//! +//! 
This test exercises the full upgrade path described in issue #652: +//! 1. Client sends HTTP GET with `Upgrade: websocket` headers +//! 2. Relay forwards to upstream, upstream responds with 101 +//! 3. Relay detects 101, validates client Upgrade headers, returns `Upgraded` +//! 4. Caller forwards overflow + switches to `copy_bidirectional` +//! 5. Client and server exchange a WebSocket text message +//! +//! Reproduction scenario from #652: raw socket test sends upgrade request +//! through the proxy, receives 101, then verifies WebSocket frames flow. + +use futures::SinkExt; +use futures::stream::StreamExt; +use openshell_sandbox::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; +use openshell_sandbox::l7::rest::RestProvider; +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::tungstenite::Message; + +/// Start a minimal WebSocket echo server on an ephemeral port. +async fn start_ws_echo_server() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let ws_stream = accept_async(stream).await.unwrap(); + let (mut write, mut read) = ws_stream.split(); + + while let Some(msg) = read.next().await { + match msg { + Ok(Message::Text(text)) => { + write + .send(Message::Text(format!("echo: {text}").into())) + .await + .unwrap(); + } + Ok(Message::Close(_)) => break, + Ok(_) => {} + Err(_) => break, + } + } + }); + + addr +} + +/// Build raw HTTP upgrade request bytes (mimics the reproduction script from #652). 
+fn build_ws_upgrade_request(host: &str) -> Vec<u8> { + format!( + "GET / HTTP/1.1\r\n\ + Host: {host}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: RylUQAh3p5cysfOlexgubw==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + \r\n" + ) + .into_bytes() +} + +/// Build a masked WebSocket text frame (client -> server must be masked per RFC 6455). +fn build_ws_text_frame(payload: &[u8]) -> Vec<u8> { + let mask_key: [u8; 4] = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x81); // FIN + text opcode + frame.push(0x80 | payload.len() as u8); // masked + length + frame.extend_from_slice(&mask_key); + for (i, b) in payload.iter().enumerate() { + frame.push(b ^ mask_key[i % 4]); + } + frame +} + +/// Core test: WebSocket upgrade through `L7Provider::relay`, then exchange a message. +/// +/// This mirrors the reproduction steps from issue #652: +/// - Send WebSocket upgrade → receive 101 → verify frames flow bidirectionally +/// - Previously, 101 was treated as a generic 1xx and frames were dropped +#[tokio::test] +async fn websocket_upgrade_through_l7_relay_exchanges_message() { + let ws_addr = start_ws_echo_server().await; + + // Open a real TCP connection to the WebSocket server (simulates upstream) + let mut upstream = TcpStream::connect(ws_addr).await.unwrap(); + + // In-memory duplex for the client side of the relay + let (mut client_app, mut client_proxy) = tokio::io::duplex(8192); + + let host = format!("127.0.0.1:{}", ws_addr.port()); + let raw_header = build_ws_upgrade_request(&host); + + let req = L7Request { + action: "GET".to_string(), + target: "/".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::None, + }; + + // Run the relay in a background task (simulates what relay_rest does) + let relay_handle = tokio::spawn(async move { + let outcome = RestProvider + .relay(&req, &mut client_proxy, &mut upstream) + .await + .expect("relay should succeed"); + + match outcome { + RelayOutcome::Upgraded 
{ overflow } => { + // This is what handle_upgrade() does in relay.rs + if !overflow.is_empty() { + client_proxy.write_all(&overflow).await.unwrap(); + client_proxy.flush().await.unwrap(); + } + let _ = tokio::io::copy_bidirectional(&mut client_proxy, &mut upstream).await; + } + other => panic!("Expected Upgraded, got {other:?}"), + } + }); + + // Client side: read the 101 response headers byte-by-byte + // (mirrors the reproduction script's recv() after sending the upgrade) + let mut response_buf = Vec::new(); + let mut tmp = [0u8; 1]; + tokio::time::timeout(std::time::Duration::from_secs(5), async { + loop { + client_app.read_exact(&mut tmp).await.unwrap(); + response_buf.push(tmp[0]); + if response_buf.ends_with(b"\r\n\r\n") { + break; + } + } + }) + .await + .expect("should receive 101 headers within 5 seconds"); + + let response_str = String::from_utf8_lossy(&response_buf); + assert!( + response_str.contains("101 Switching Protocols"), + "should receive 101, got: {response_str}" + ); + + // ---- This is the part that was broken before the fix (issue #652) ---- + // Previously, after 101, the relay re-entered the HTTP parsing loop and + // all WebSocket frames were silently dropped. The reproduction script + // would see RECV2: TIMEOUT here. 
+ + // Send a WebSocket text frame + let frame = build_ws_text_frame(b"hello"); + client_app.write_all(&frame).await.unwrap(); + client_app.flush().await.unwrap(); + + // Read the echo response (unmasked server -> client frame) + tokio::time::timeout(std::time::Duration::from_secs(5), async { + let mut header = [0u8; 2]; + client_app.read_exact(&mut header).await.unwrap(); + + let fin_opcode = header[0]; + assert_eq!(fin_opcode & 0x0F, 1, "should be text frame"); + assert!(fin_opcode & 0x80 != 0, "FIN bit should be set"); + + let len = (header[1] & 0x7F) as usize; + let mut payload_buf = vec![0u8; len]; + client_app.read_exact(&mut payload_buf).await.unwrap(); + let text = String::from_utf8(payload_buf).unwrap(); + assert_eq!( + text, "echo: hello", + "server should echo our message back through the relay" + ); + }) + .await + .expect("should receive WebSocket echo within 5 seconds (previously timed out per #652)"); + + // Clean shutdown + let close_frame = [0x88, 0x82, 0x00, 0x00, 0x00, 0x00, 0x03, 0xe8]; + let _ = client_app.write_all(&close_frame).await; + drop(client_app); + + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay_handle).await; +} + +/// Test that a normal (non-upgrade) HTTP request still works correctly +/// after the relay_response changes. Ensures the 101 detection doesn't +/// break regular HTTP traffic. 
+#[tokio::test] +async fn normal_http_request_still_works_after_relay_changes() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + // Simple HTTP echo server + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = stream.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + stream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + stream.flush().await.unwrap(); + }); + + let mut upstream = TcpStream::connect(addr).await.unwrap(); + let (mut client_read, mut client_proxy) = tokio::io::duplex(8192); + + let raw_header = format!( + "GET /api HTTP/1.1\r\nHost: 127.0.0.1:{}\r\n\r\n", + addr.port() + ) + .into_bytes(); + + let req = L7Request { + action: "GET".to_string(), + target: "/api".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::None, + }; + + let outcome = tokio::time::timeout( + std::time::Duration::from_secs(5), + RestProvider.relay(&req, &mut client_proxy, &mut upstream), + ) + .await + .expect("should not deadlock") + .expect("relay should succeed"); + + assert!( + matches!(outcome, RelayOutcome::Reusable), + "normal 200 response should be Reusable, got {outcome:?}" + ); + + client_proxy.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let body = String::from_utf8_lossy(&received); + assert!(body.contains("200 OK"), "should forward 200 response"); + assert!(body.contains("ok"), "should forward response body"); +} diff --git a/crates/openshell-server/src/auth.rs b/crates/openshell-server/src/auth.rs index 5a3229ffa..b896d062c 100644 --- a/crates/openshell-server/src/auth.rs +++ b/crates/openshell-server/src/auth.rs @@ -22,11 +22,28 @@ use axum::{ 
response::{Html, IntoResponse}, routing::get, }; +use http::header; use serde::Deserialize; use std::sync::Arc; use crate::ServerState; +/// Validate that a confirmation code matches the CLI-generated format. +/// +/// Codes are 3 alphanumeric characters, a dash, then 4 alphanumeric characters +/// (e.g., "AB7-X9KM"). The CLI generates these from the charset `[A-Z2-9]`. +fn is_valid_code(code: &str) -> bool { + let bytes = code.as_bytes(); + bytes.len() == 8 + && bytes[3] == b'-' + && bytes[..3] + .iter() + .all(|b| b.is_ascii_uppercase() || b.is_ascii_digit()) + && bytes[4..] + .iter() + .all(|b| b.is_ascii_uppercase() || b.is_ascii_digit()) +} + #[derive(Deserialize)] struct ConnectParams { callback_port: u16, @@ -54,6 +71,15 @@ async fn auth_connect( Query(params): Query<ConnectParams>, headers: HeaderMap, ) -> impl IntoResponse { + // Reject codes that don't match the CLI-generated format to prevent + // reflected XSS via crafted URLs. + if !is_valid_code(&params.code) { + return Html( + "

<html><body><h1>Invalid confirmation code format.</h1></body></html>

".to_string(), + ) + .into_response(); + } + let cf_token = headers .get("cookie") .and_then(|v| v.to_str().ok()) @@ -68,14 +94,34 @@ async fn auth_connect( .and_then(|v| v.to_str().ok()) .map_or_else(|| state.config.bind_address.to_string(), String::from); + let safe_gateway = html_escape(&gateway_display); + match cf_token { - Some(token) => Html(render_connect_page( - &gateway_display, - params.callback_port, - &token, - ¶ms.code, - )), - None => Html(render_waiting_page(params.callback_port, ¶ms.code)), + Some(token) => { + let nonce = uuid::Uuid::new_v4().to_string(); + let csp = format!( + "default-src 'none'; script-src 'nonce-{nonce}'; style-src 'unsafe-inline'; connect-src http://127.0.0.1:*" + ); + ( + [(header::CONTENT_SECURITY_POLICY, csp)], + Html(render_connect_page( + &safe_gateway, + params.callback_port, + &token, + ¶ms.code, + &nonce, + )), + ) + .into_response() + } + None => { + let csp = "default-src 'none'; style-src 'unsafe-inline'".to_string(); + ( + [(header::CONTENT_SECURITY_POLICY, csp)], + Html(render_waiting_page(params.callback_port, ¶ms.code)), + ) + .into_response() + } } } @@ -104,22 +150,27 @@ fn render_connect_page( callback_port: u16, cf_token: &str, code: &str, + nonce: &str, ) -> String { - // Escape the token for safe embedding in a JS string literal. - let escaped_token = cf_token - .replace('\\', "\\\\") - .replace('\'', "\\'") - .replace('"', "\\\"") - .replace('<', "\\x3c") - .replace('>', "\\x3e"); + // Use JSON serialization for JS-safe string embedding — handles all + // edge cases including \n, \r, U+2028, U+2029 that break JS string + // literals. serde_json::to_string produces a quoted JSON string + // (e.g., "value") which is a valid JS string literal. + // + // We additionally escape < and > to \u003c / \u003e because while + // they're valid in JSON, they're dangerous inside an HTML before the JS parser runs). 
+ let json_token = serde_json::to_string(cf_token) + .unwrap_or_else(|_| "\"\"".to_string()) + .replace('<', "\\u003c") + .replace('>', "\\u003e"); + let json_code = serde_json::to_string(code) + .unwrap_or_else(|_| "\"\"".to_string()) + .replace('<', "\\u003c") + .replace('>', "\\u003e"); - // Escape the code the same way (it's alphanumeric + dash, but be safe). - let escaped_code = code - .replace('\\', "\\\\") - .replace('\'', "\\'") - .replace('"', "\\\"") - .replace('<', "\\x3c") - .replace('>', "\\x3e"); + // HTML-safe version of the code for display in the page body. + let html_code = html_escape(code); let version = openshell_core::VERSION; @@ -250,7 +301,7 @@ fn render_connect_page(
Connect to Gateway
Confirmation Code
-
{escaped_code}
+
{html_code}
Verify this matches the code shown in your terminal
@@ -271,9 +322,9 @@ fn render_connect_page(
- ", "ABC-1234"); - // < and > should be escaped + let html = render_connect_page( + "gw", + 1234, + "token", + "ABC-1234", + "nonce", + ); + // < and > should be escaped via JSON encoding (\u003c) assert!(!html.contains("