13 changes: 7 additions & 6 deletions architecture/inference-routing.md
@@ -171,16 +171,17 @@ Files:
`prepare_backend_request()` (shared by both buffered and streaming paths) rewrites outgoing requests:

1. **Auth injection**: Uses the route's `AuthHeader` -- either `Authorization: Bearer <key>` or a custom header (e.g. `x-api-key: <key>` for Anthropic).
2. **Header stripping**: Removes `authorization`, `x-api-key`, `host`, and any header names that will be set from route defaults.
3. **Default headers**: Applies route-level default headers (e.g. `anthropic-version: 2023-06-01`) unless the client already sent them.
4. **Model rewrite**: Parses the request body as JSON and replaces the `model` field with the route's configured model. Non-JSON bodies are forwarded unchanged.
5. **URL construction**: `build_backend_url()` appends the request path to the route endpoint. If the endpoint already ends with `/v1` and the request path starts with `/v1/`, the duplicate prefix is deduplicated.
2. **Header allowlist**: Keeps only explicitly approved request headers: common inference headers (`content-type`, `accept`, `accept-encoding`, `user-agent`), route-specific passthrough headers (for example `openai-organization`, `x-model-id`, `anthropic-version`, `anthropic-beta`), and any route default header names.
3. **Header stripping**: Removes `authorization`, `x-api-key`, `host`, `content-length`, hop-by-hop headers, and any non-allowlisted request headers.
4. **Default headers**: Applies route-level default headers (e.g. `anthropic-version: 2023-06-01`) unless the client already sent them.
5. **Model rewrite**: Parses the request body as JSON and replaces the `model` field with the route's configured model. Non-JSON bodies are forwarded unchanged.
6. **URL construction**: `build_backend_url()` appends the request path to the route endpoint. If the endpoint already ends with `/v1` and the request path starts with `/v1/`, the duplicate prefix is deduplicated.
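
The URL construction step can be sketched as follows. This is only an illustration of the behavior described in the list above, not the router's actual `build_backend_url()`: the name `build_backend_url_sketch` and the trailing-slash trimming are assumptions.

```rust
/// Hypothetical sketch of joining a route endpoint and a request path while
/// collapsing a duplicated `/v1` prefix. Not the router's real implementation.
fn build_backend_url_sketch(endpoint: &str, path: &str) -> String {
    // Assumption: drop trailing slashes so the join never produces `//`.
    let base = endpoint.trim_end_matches('/');

    // If the endpoint already ends with `/v1` and the request path starts
    // with `/v1/`, keep only one copy of the prefix.
    if base.ends_with("/v1") {
        if let Some(rest) = path.strip_prefix("/v1/") {
            return format!("{base}/{rest}");
        }
    }

    // Otherwise append the request path unchanged.
    format!("{base}{path}")
}
```

Under these assumptions, an endpoint of `https://api.example.com/v1` and a request path of `/v1/chat/completions` yield `https://api.example.com/v1/chat/completions`, the same URL produced when the endpoint is configured without the `/v1` suffix.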

### Header sanitization

Before forwarding inference requests, the proxy strips sensitive and hop-by-hop headers from both requests and responses:
Before forwarding inference requests, the router enforces a route-aware request allowlist and strips sensitive/framing headers. Response sanitization remains framing-only:

- **Request**: `authorization`, `x-api-key`, `host`, `content-length`, and hop-by-hop headers (`connection`, `keep-alive`, `proxy-authenticate`, `proxy-authorization`, `proxy-connection`, `te`, `trailer`, `transfer-encoding`, `upgrade`).
- **Request**: forwards only common inference headers plus route-specific passthrough headers and route default header names. Always strips `authorization`, `x-api-key`, `host`, `content-length`, unknown headers such as `cookie`, and hop-by-hop headers (`connection`, `keep-alive`, `proxy-authenticate`, `proxy-authorization`, `proxy-connection`, `te`, `trailer`, `transfer-encoding`, `upgrade`).
- **Response**: `content-length` and hop-by-hop headers.
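
Response sanitization, as described, only removes framing headers. A minimal sketch of that rule is shown here; the helper name `sanitize_response_headers_sketch` and the `HOP_BY_HOP` constant are hypothetical, and the header list is copied from the request bullet above.

```rust
/// Hop-by-hop header names from the list above (lowercase).
const HOP_BY_HOP: [&str; 9] = [
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];

/// Hypothetical helper: framing-only sanitization for upstream responses.
/// Drops `content-length` and hop-by-hop headers, keeps everything else.
fn sanitize_response_headers_sketch(headers: &[(String, String)]) -> Vec<(String, String)> {
    headers
        .iter()
        .filter(|(name, _)| {
            // Header names are case-insensitive, so compare in lowercase.
            let name_lc = name.to_ascii_lowercase();
            name_lc != "content-length" && !HOP_BY_HOP.contains(&name_lc.as_str())
        })
        .cloned()
        .collect()
}
```

Unlike the request path, no allowlist is applied here: any response header that is not a framing header is returned to the client unchanged.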

### Response streaming
4 changes: 2 additions & 2 deletions architecture/sandbox.md
@@ -839,9 +839,9 @@ The interception steps:

Pattern matching strips query strings. Exact path comparison is used for most patterns; the `/v1/models/*` pattern matches `/v1/models` itself or any path under `/v1/models/` (e.g., `/v1/models/gpt-4.1`).

4. **Header sanitization**: For matched inference requests, the proxy strips credential headers (`Authorization`, `x-api-key`) and framing/hop-by-hop headers (`host`, `content-length`, `transfer-encoding`, `connection`, etc.). The router rebuilds correct framing for the forwarded body.
4. **Header sanitization**: For matched inference requests, the proxy passes the parsed headers to the router. The router applies a route-aware allowlist before forwarding: common inference headers (`content-type`, `accept`, `accept-encoding`, `user-agent`), provider-specific passthrough headers (for example `openai-organization`, `x-model-id`, `anthropic-version`, `anthropic-beta`), and any route default header names. It always strips client-supplied credential headers (`Authorization`, `x-api-key`) and framing/hop-by-hop headers (`host`, `content-length`, `transfer-encoding`, `connection`, etc.). The router rebuilds correct framing for the forwarded body.

5. **Local routing**: Matched requests are executed by calling `Router::proxy_with_candidates_streaming()`, passing the detected protocol, HTTP method, path, sanitized headers, body, and the cached `ResolvedRoute` list from `InferenceContext`. The router selects the first route whose `protocols` list contains the source protocol (see [Inference Routing -- Response streaming](inference-routing.md#response-streaming) for details). When forwarding to the backend, the router rewrites the request: the route's `api_key` replaces the `Authorization` header, the `Host` header is set to the backend endpoint, and the `"model"` field in the JSON request body is replaced with the route's configured `model` value. If the request body is not valid JSON or does not contain a `"model"` key, the body is forwarded unchanged.
5. **Local routing**: Matched requests are executed by calling `Router::proxy_with_candidates_streaming()`, passing the detected protocol, HTTP method, path, original parsed headers, body, and the cached `ResolvedRoute` list from `InferenceContext`. The router selects the first route whose `protocols` list contains the source protocol (see [Inference Routing -- Response streaming](inference-routing.md#response-streaming) for details). When forwarding to the backend, the router rewrites the request: the route's `api_key` replaces the client auth header, the `Host` header is set to the backend endpoint, only allowlisted request headers survive, and the `"model"` field in the JSON request body is replaced with the route's configured `model` value. If the request body is not valid JSON or does not contain a `"model"` key, the body is forwarded unchanged.

6. **Response handling (streaming)**:
- On success: response headers are sent back to the client immediately as an HTTP/1.1 response with `Transfer-Encoding: chunked`, using `format_http_response_header()`. Framing/hop-by-hop headers are stripped from the upstream response. Body chunks are then forwarded incrementally as they arrive from the backend via `StreamingProxyResponse::next_chunk()`, each wrapped in HTTP chunked encoding by `format_chunk()`. The stream is terminated with a `0\r\n\r\n` chunk terminator. This ensures time-to-first-byte reflects the backend's first token latency rather than the full generation time.
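
The chunk framing in the success path can be illustrated with a short sketch. This is not the project's `format_chunk()`: the name `format_chunk_sketch` and the `CHUNK_TERMINATOR` constant are hypothetical, and the code only shows the standard HTTP/1.1 chunked layout (`<hex length>\r\n<data>\r\n`) together with the terminator quoted above.

```rust
/// Hypothetical sketch: wrap one body chunk in HTTP/1.1 chunked framing.
fn format_chunk_sketch(data: &[u8]) -> Vec<u8> {
    // The chunk size is written in hexadecimal, followed by CRLF.
    let mut out = format!("{:x}\r\n", data.len()).into_bytes();
    out.extend_from_slice(data);
    // Each chunk's payload is terminated by CRLF.
    out.extend_from_slice(b"\r\n");
    out
}

/// Zero-length chunk that ends a chunked body.
const CHUNK_TERMINATOR: &[u8] = b"0\r\n\r\n";
```

Because every chunk carries its own length, the proxy can forward each backend chunk as soon as it arrives without knowing the total body size, which is why `content-length` is dropped from the upstream response.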
64 changes: 61 additions & 3 deletions crates/openshell-core/src/inference.rs
@@ -28,7 +28,8 @@ pub enum AuthHeader {
///
/// This is the single source of truth for provider-specific inference knowledge:
/// default endpoint, supported protocols, credential key lookup order, auth
/// header style, and default headers.
/// header style, default headers, and allowed client-supplied passthrough
/// headers.
///
/// This is separate from [`openshell_providers::ProviderPlugin`] which handles
/// credential *discovery* (scanning env vars). `InferenceProviderProfile` handles
@@ -45,6 +46,10 @@ pub struct InferenceProviderProfile {
pub auth: AuthHeader,
/// Default headers injected on every outgoing request.
pub default_headers: &'static [(&'static str, &'static str)],
/// Client-supplied headers that may be forwarded to the upstream backend.
///
/// Header names must be lowercase and must not include auth headers.
pub passthrough_headers: &'static [&'static str],
}

const OPENAI_PROTOCOLS: &[&str] = &[
@@ -64,6 +69,7 @@ static OPENAI_PROFILE: InferenceProviderProfile = InferenceProviderProfile {
base_url_config_keys: &["OPENAI_BASE_URL"],
auth: AuthHeader::Bearer,
default_headers: &[],
passthrough_headers: &["openai-organization", "x-model-id"],
};

static ANTHROPIC_PROFILE: InferenceProviderProfile = InferenceProviderProfile {
@@ -74,6 +80,7 @@ static ANTHROPIC_PROFILE: InferenceProviderProfile = InferenceProviderProfile {
base_url_config_keys: &["ANTHROPIC_BASE_URL"],
auth: AuthHeader::Custom("x-api-key"),
default_headers: &[("anthropic-version", "2023-06-01")],
passthrough_headers: &["anthropic-version", "anthropic-beta"],
};

static NVIDIA_PROFILE: InferenceProviderProfile = InferenceProviderProfile {
@@ -84,6 +91,7 @@ static NVIDIA_PROFILE: InferenceProviderProfile = InferenceProviderProfile {
base_url_config_keys: &["NVIDIA_BASE_URL"],
auth: AuthHeader::Bearer,
default_headers: &[],
passthrough_headers: &["x-model-id"],
};

/// Look up the inference provider profile for a given provider type.
@@ -105,16 +113,32 @@ pub fn profile_for(provider_type: &str) -> Option<&'static InferenceProviderProf
/// need the auth/header information (e.g. the sandbox bundle-to-route
/// conversion).
pub fn auth_for_provider_type(provider_type: &str) -> (AuthHeader, Vec<(String, String)>) {
let (auth, headers, _) = route_headers_for_provider_type(provider_type);
(auth, headers)
}

/// Derive routing header policy for a provider type string.
///
/// Returns the auth injection mode, route-level default headers, and the
/// allowed client-supplied passthrough headers for `inference.local`.
pub fn route_headers_for_provider_type(
provider_type: &str,
) -> (AuthHeader, Vec<(String, String)>, Vec<String>) {
match profile_for(provider_type) {
Some(profile) => {
let headers = profile
.default_headers
.iter()
.map(|(k, v)| ((*k).to_string(), (*v).to_string()))
.collect();
(profile.auth.clone(), headers)
let passthrough_headers = profile
.passthrough_headers
.iter()
.map(|name| (*name).to_string())
.collect();
(profile.auth.clone(), headers, passthrough_headers)
}
None => (AuthHeader::Bearer, Vec::new()),
None => (AuthHeader::Bearer, Vec::new(), Vec::new()),
}
}

@@ -193,6 +217,32 @@ mod tests {
assert!(headers.iter().any(|(k, _)| k == "anthropic-version"));
}

#[test]
fn route_headers_for_openai_include_passthrough_headers() {
let (_, _, passthrough_headers) = route_headers_for_provider_type("openai");
assert!(
passthrough_headers
.iter()
.any(|name| name == "openai-organization")
);
assert!(passthrough_headers.iter().any(|name| name == "x-model-id"));
}

#[test]
fn route_headers_for_anthropic_include_passthrough_headers() {
let (_, _, passthrough_headers) = route_headers_for_provider_type("anthropic");
assert!(
passthrough_headers
.iter()
.any(|name| name == "anthropic-version")
);
assert!(
passthrough_headers
.iter()
.any(|name| name == "anthropic-beta")
);
}

#[test]
fn auth_for_openai_uses_bearer() {
let (auth, headers) = auth_for_provider_type("openai");
@@ -206,4 +256,12 @@ mod tests {
assert_eq!(auth, AuthHeader::Bearer);
assert!(headers.is_empty());
}

#[test]
fn route_headers_for_unknown_are_empty() {
let (auth, headers, passthrough_headers) = route_headers_for_provider_type("unknown");
assert_eq!(auth, AuthHeader::Bearer);
assert!(headers.is_empty());
assert!(passthrough_headers.is_empty());
}
}
161 changes: 149 additions & 12 deletions crates/openshell-router/src/backend.rs
@@ -4,6 +4,7 @@
use crate::RouterError;
use crate::config::{AuthHeader, ResolvedRoute};
use crate::mock;
use std::collections::HashSet;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedEndpoint {
@@ -62,6 +63,9 @@ enum StreamingBody {
Buffered(Option<bytes::Bytes>),
}

const COMMON_INFERENCE_REQUEST_HEADERS: [&str; 4] =
["content-type", "accept", "accept-encoding", "user-agent"];

impl StreamingProxyResponse {
/// Create from a fully-buffered [`ProxyResponse`] (for mock routes).
pub fn from_buffered(resp: ProxyResponse) -> Self {
@@ -83,7 +87,64 @@ impl StreamingProxyResponse {
}
}

/// Build an HTTP request to the backend configured in `route`.
fn sanitize_request_headers(
    route: &ResolvedRoute,
    headers: &[(String, String)],
) -> Vec<(String, String)> {
    let mut allowed = HashSet::new();
    allowed.extend(
        COMMON_INFERENCE_REQUEST_HEADERS
            .iter()
            .map(|name| (*name).to_string()),
    );
    allowed.extend(
        route
            .passthrough_headers
            .iter()
            .map(|name| name.to_ascii_lowercase()),
    );
    allowed.extend(
        route
            .default_headers
            .iter()
            .map(|(name, _)| name.to_ascii_lowercase()),
    );

    headers
        .iter()
        .filter_map(|(name, value)| {
            let name_lc = name.to_ascii_lowercase();
            if should_strip_request_header(&name_lc) || !allowed.contains(&name_lc) {
                return None;
            }
            Some((name.clone(), value.clone()))
        })
        .collect()
}

fn should_strip_request_header(name: &str) -> bool {
    matches!(
        name,
        "authorization" | "x-api-key" | "host" | "content-length"
    ) || is_hop_by_hop_header(name)
}

fn is_hop_by_hop_header(name: &str) -> bool {
    matches!(
        name,
        "connection"
            | "keep-alive"
            | "proxy-authenticate"
            | "proxy-authorization"
            | "proxy-connection"
            | "te"
            | "trailer"
            | "transfer-encoding"
            | "upgrade"
    )
}

/// Prepare an HTTP request to the backend configured in `route`.
///
/// Returns the prepared [`reqwest::RequestBuilder`] with auth, headers, model
/// rewrite, and body applied. The caller decides whether to apply a total
@@ -97,6 +158,7 @@ fn prepare_backend_request(
body: bytes::Bytes,
) -> Result<(reqwest::RequestBuilder, String), RouterError> {
let url = build_backend_url(&route.endpoint, path);
let headers = sanitize_request_headers(route, &headers);

let reqwest_method: reqwest::Method = method
.parse()
@@ -113,17 +175,7 @@ fn prepare_backend_request(
builder = builder.header(*header_name, &route.api_key);
}
}

// Strip auth and host headers — auth is re-injected above from the route
// config, and host must match the upstream.
let strip_headers: [&str; 3] = ["authorization", "x-api-key", "host"];

// Forward non-sensitive headers.
for (name, value) in headers {
let name_lc = name.to_ascii_lowercase();
if strip_headers.contains(&name_lc.as_str()) {
continue;
}
for (name, value) in &headers {
builder = builder.header(name.as_str(), value.as_str());
}

@@ -510,10 +562,95 @@ mod tests {
protocols: protocols.iter().map(|p| (*p).to_string()).collect(),
auth,
default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())],
passthrough_headers: vec![
"anthropic-version".to_string(),
"anthropic-beta".to_string(),
],
timeout: crate::config::DEFAULT_ROUTE_TIMEOUT,
}
}

#[test]
fn sanitize_request_headers_drops_unknown_sensitive_headers() {
let route = ResolvedRoute {
name: "inference.local".to_string(),
endpoint: "https://api.example.com/v1".to_string(),
model: "test-model".to_string(),
api_key: "sk-test".to_string(),
protocols: vec!["openai_chat_completions".to_string()],
auth: AuthHeader::Bearer,
default_headers: Vec::new(),
passthrough_headers: vec!["openai-organization".to_string()],
timeout: crate::config::DEFAULT_ROUTE_TIMEOUT,
};

let kept = super::sanitize_request_headers(
&route,
&[
("content-type".to_string(), "application/json".to_string()),
("authorization".to_string(), "Bearer client".to_string()),
("cookie".to_string(), "session=1".to_string()),
("x-amz-security-token".to_string(), "token".to_string()),
("openai-organization".to_string(), "org_123".to_string()),
],
);

assert!(
kept.iter()
.any(|(name, _)| name.eq_ignore_ascii_case("content-type"))
);
assert!(
kept.iter()
.any(|(name, _)| name.eq_ignore_ascii_case("openai-organization"))
);
assert!(
kept.iter()
.all(|(name, _)| !name.eq_ignore_ascii_case("authorization"))
);
assert!(
kept.iter()
.all(|(name, _)| !name.eq_ignore_ascii_case("cookie"))
);
assert!(
kept.iter()
.all(|(name, _)| !name.eq_ignore_ascii_case("x-amz-security-token"))
);
}

#[test]
fn sanitize_request_headers_preserves_allowed_provider_headers() {
let route = test_route(
"https://api.anthropic.com/v1",
&["anthropic_messages"],
AuthHeader::Custom("x-api-key"),
);

let kept = super::sanitize_request_headers(
&route,
&[
("anthropic-version".to_string(), "2024-10-22".to_string()),
(
"anthropic-beta".to_string(),
"tool-use-2024-10-22".to_string(),
),
("x-api-key".to_string(), "client-key".to_string()),
],
);

assert!(kept.iter().any(
|(name, value)| name.eq_ignore_ascii_case("anthropic-version") && value == "2024-10-22"
));
assert!(
kept.iter()
.any(|(name, value)| name.eq_ignore_ascii_case("anthropic-beta")
&& value == "tool-use-2024-10-22")
);
assert!(
kept.iter()
.all(|(name, _)| !name.eq_ignore_ascii_case("x-api-key"))
);
}

#[tokio::test]
async fn verify_backend_endpoint_uses_route_auth_and_shape() {
let mock_server = MockServer::start().await;