diff --git a/architecture/inference-routing.md b/architecture/inference-routing.md index e2fd67178..c0c42f4f6 100644 --- a/architecture/inference-routing.md +++ b/architecture/inference-routing.md @@ -171,16 +171,17 @@ Files: `prepare_backend_request()` (shared by both buffered and streaming paths) rewrites outgoing requests: 1. **Auth injection**: Uses the route's `AuthHeader` -- either `Authorization: Bearer <key>` or a custom header (e.g. `x-api-key: <key>` for Anthropic). -2. **Header stripping**: Removes `authorization`, `x-api-key`, `host`, and any header names that will be set from route defaults. -3. **Default headers**: Applies route-level default headers (e.g. `anthropic-version: 2023-06-01`) unless the client already sent them. -4. **Model rewrite**: Parses the request body as JSON and replaces the `model` field with the route's configured model. Non-JSON bodies are forwarded unchanged. -5. **URL construction**: `build_backend_url()` appends the request path to the route endpoint. If the endpoint already ends with `/v1` and the request path starts with `/v1/`, the duplicate prefix is deduplicated. +2. **Header allowlist**: Keeps only explicitly approved request headers: common inference headers (`content-type`, `accept`, `accept-encoding`, `user-agent`), route-specific passthrough headers (for example `openai-organization`, `x-model-id`, `anthropic-version`, `anthropic-beta`), and any route default header names. +3. **Header stripping**: Removes `authorization`, `x-api-key`, `host`, `content-length`, hop-by-hop headers, and any non-allowlisted request headers. +4. **Default headers**: Applies route-level default headers (e.g. `anthropic-version: 2023-06-01`) unless the client already sent them. +5. **Model rewrite**: Parses the request body as JSON and replaces the `model` field with the route's configured model. Non-JSON bodies are forwarded unchanged. +6. **URL construction**: `build_backend_url()` appends the request path to the route endpoint.
If the endpoint already ends with `/v1` and the request path starts with `/v1/`, the duplicate `/v1` segment is dropped. ### Header sanitization -Before forwarding inference requests, the proxy strips sensitive and hop-by-hop headers from both requests and responses: +Before forwarding inference requests, the router enforces a route-aware request allowlist and strips sensitive/framing headers. Response sanitization remains framing-only: -- **Request**: `authorization`, `x-api-key`, `host`, `content-length`, and hop-by-hop headers (`connection`, `keep-alive`, `proxy-authenticate`, `proxy-authorization`, `proxy-connection`, `te`, `trailer`, `transfer-encoding`, `upgrade`). +- **Request**: forwards only common inference headers plus route-specific passthrough headers and route default header names. Always strips `authorization`, `x-api-key`, `host`, `content-length`, non-allowlisted headers such as `cookie`, and hop-by-hop headers (`connection`, `keep-alive`, `proxy-authenticate`, `proxy-authorization`, `proxy-connection`, `te`, `trailer`, `transfer-encoding`, `upgrade`). - **Response**: `content-length` and hop-by-hop headers. ### Response streaming diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 656f42138..e83681466 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -839,9 +839,9 @@ The interception steps: Pattern matching strips query strings. Exact path comparison is used for most patterns; the `/v1/models/*` pattern matches `/v1/models` itself or any path under `/v1/models/` (e.g., `/v1/models/gpt-4.1`). -4. **Header sanitization**: For matched inference requests, the proxy strips credential headers (`Authorization`, `x-api-key`) and framing/hop-by-hop headers (`host`, `content-length`, `transfer-encoding`, `connection`, etc.). The router rebuilds correct framing for the forwarded body. +4. **Header sanitization**: For matched inference requests, the proxy passes the parsed headers to the router.
The router applies a route-aware allowlist before forwarding: common inference headers (`content-type`, `accept`, `accept-encoding`, `user-agent`), provider-specific passthrough headers (for example `openai-organization`, `x-model-id`, `anthropic-version`, `anthropic-beta`), and any route default header names. It always strips client-supplied credential headers (`Authorization`, `x-api-key`) and framing/hop-by-hop headers (`host`, `content-length`, `transfer-encoding`, `connection`, etc.). The router rebuilds correct framing for the forwarded body. -5. **Local routing**: Matched requests are executed by calling `Router::proxy_with_candidates_streaming()`, passing the detected protocol, HTTP method, path, sanitized headers, body, and the cached `ResolvedRoute` list from `InferenceContext`. The router selects the first route whose `protocols` list contains the source protocol (see [Inference Routing -- Response streaming](inference-routing.md#response-streaming) for details). When forwarding to the backend, the router rewrites the request: the route's `api_key` replaces the `Authorization` header, the `Host` header is set to the backend endpoint, and the `"model"` field in the JSON request body is replaced with the route's configured `model` value. If the request body is not valid JSON or does not contain a `"model"` key, the body is forwarded unchanged. +5. **Local routing**: Matched requests are executed by calling `Router::proxy_with_candidates_streaming()`, passing the detected protocol, HTTP method, path, original parsed headers, body, and the cached `ResolvedRoute` list from `InferenceContext`. The router selects the first route whose `protocols` list contains the source protocol (see [Inference Routing -- Response streaming](inference-routing.md#response-streaming) for details). 
When forwarding to the backend, the router rewrites the request: the route's `api_key` replaces the client auth header, the `Host` header is set to the backend endpoint, only allowlisted request headers survive, and the `"model"` field in the JSON request body is replaced with the route's configured `model` value. If the request body is not valid JSON or does not contain a `"model"` key, the body is forwarded unchanged. 6. **Response handling (streaming)**: - On success: response headers are sent back to the client immediately as an HTTP/1.1 response with `Transfer-Encoding: chunked`, using `format_http_response_header()`. Framing/hop-by-hop headers are stripped from the upstream response. Body chunks are then forwarded incrementally as they arrive from the backend via `StreamingProxyResponse::next_chunk()`, each wrapped in HTTP chunked encoding by `format_chunk()`. The stream is terminated with a `0\r\n\r\n` chunk terminator. This ensures time-to-first-byte reflects the backend's first token latency rather than the full generation time. diff --git a/crates/openshell-core/src/inference.rs b/crates/openshell-core/src/inference.rs index a06c427f8..d2581f7eb 100644 --- a/crates/openshell-core/src/inference.rs +++ b/crates/openshell-core/src/inference.rs @@ -28,7 +28,8 @@ pub enum AuthHeader { /// /// This is the single source of truth for provider-specific inference knowledge: /// default endpoint, supported protocols, credential key lookup order, auth -/// header style, and default headers. +/// header style, default headers, and allowed client-supplied passthrough +/// headers. /// /// This is separate from [`openshell_providers::ProviderPlugin`] which handles /// credential *discovery* (scanning env vars). `InferenceProviderProfile` handles @@ -45,6 +46,10 @@ pub struct InferenceProviderProfile { pub auth: AuthHeader, /// Default headers injected on every outgoing request. 
pub default_headers: &'static [(&'static str, &'static str)], + /// Client-supplied headers that may be forwarded to the upstream backend. + /// + /// Header names must be lowercase and must not include auth headers. + pub passthrough_headers: &'static [&'static str], } const OPENAI_PROTOCOLS: &[&str] = &[ @@ -64,6 +69,7 @@ static OPENAI_PROFILE: InferenceProviderProfile = InferenceProviderProfile { base_url_config_keys: &["OPENAI_BASE_URL"], auth: AuthHeader::Bearer, default_headers: &[], + passthrough_headers: &["openai-organization", "x-model-id"], }; static ANTHROPIC_PROFILE: InferenceProviderProfile = InferenceProviderProfile { @@ -74,6 +80,7 @@ static ANTHROPIC_PROFILE: InferenceProviderProfile = InferenceProviderProfile { base_url_config_keys: &["ANTHROPIC_BASE_URL"], auth: AuthHeader::Custom("x-api-key"), default_headers: &[("anthropic-version", "2023-06-01")], + passthrough_headers: &["anthropic-version", "anthropic-beta"], }; static NVIDIA_PROFILE: InferenceProviderProfile = InferenceProviderProfile { @@ -84,6 +91,7 @@ static NVIDIA_PROFILE: InferenceProviderProfile = InferenceProviderProfile { base_url_config_keys: &["NVIDIA_BASE_URL"], auth: AuthHeader::Bearer, default_headers: &[], + passthrough_headers: &["x-model-id"], }; /// Look up the inference provider profile for a given provider type. @@ -105,6 +113,17 @@ pub fn profile_for(provider_type: &str) -> Option<&'static InferenceProviderProf /// need the auth/header information (e.g. the sandbox bundle-to-route /// conversion). pub fn auth_for_provider_type(provider_type: &str) -> (AuthHeader, Vec<(String, String)>) { + let (auth, headers, _) = route_headers_for_provider_type(provider_type); + (auth, headers) +} + +/// Derive routing header policy for a provider type string. +/// +/// Returns the auth injection mode, route-level default headers, and the +/// allowed client-supplied passthrough headers for `inference.local`. 
+pub fn route_headers_for_provider_type( + provider_type: &str, +) -> (AuthHeader, Vec<(String, String)>, Vec<String>) { match profile_for(provider_type) { Some(profile) => { let headers = profile @@ -112,9 +131,14 @@ pub fn auth_for_provider_type(provider_type: &str) -> (AuthHeader, Vec<(String, .iter() .map(|(k, v)| ((*k).to_string(), (*v).to_string())) .collect(); - (profile.auth.clone(), headers) + let passthrough_headers = profile + .passthrough_headers + .iter() + .map(|name| (*name).to_string()) + .collect(); + (profile.auth.clone(), headers, passthrough_headers) } - None => (AuthHeader::Bearer, Vec::new()), + None => (AuthHeader::Bearer, Vec::new(), Vec::new()), } } @@ -193,6 +217,32 @@ mod tests { assert!(headers.iter().any(|(k, _)| k == "anthropic-version")); } + #[test] + fn route_headers_for_openai_include_passthrough_headers() { + let (_, _, passthrough_headers) = route_headers_for_provider_type("openai"); + assert!( + passthrough_headers + .iter() + .any(|name| name == "openai-organization") + ); + assert!(passthrough_headers.iter().any(|name| name == "x-model-id")); + } + + #[test] + fn route_headers_for_anthropic_include_passthrough_headers() { + let (_, _, passthrough_headers) = route_headers_for_provider_type("anthropic"); + assert!( + passthrough_headers + .iter() + .any(|name| name == "anthropic-version") + ); + assert!( + passthrough_headers + .iter() + .any(|name| name == "anthropic-beta") + ); + } + #[test] fn auth_for_openai_uses_bearer() { let (auth, headers) = auth_for_provider_type("openai"); @@ -206,4 +256,12 @@ mod tests { assert_eq!(auth, AuthHeader::Bearer); assert!(headers.is_empty()); } + + #[test] + fn route_headers_for_unknown_are_empty() { + let (auth, headers, passthrough_headers) = route_headers_for_provider_type("unknown"); + assert_eq!(auth, AuthHeader::Bearer); + assert!(headers.is_empty()); + assert!(passthrough_headers.is_empty()); + } } diff --git a/crates/openshell-router/src/backend.rs
index 8dbae6502..fbca70ae1 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -4,6 +4,7 @@ use crate::RouterError; use crate::config::{AuthHeader, ResolvedRoute}; use crate::mock; +use std::collections::HashSet; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedEndpoint { @@ -62,6 +63,9 @@ enum StreamingBody { Buffered(Option), } +const COMMON_INFERENCE_REQUEST_HEADERS: [&str; 4] = + ["content-type", "accept", "accept-encoding", "user-agent"]; + impl StreamingProxyResponse { /// Create from a fully-buffered [`ProxyResponse`] (for mock routes). pub fn from_buffered(resp: ProxyResponse) -> Self { @@ -83,7 +87,64 @@ impl StreamingProxyResponse { } } -/// Build an HTTP request to the backend configured in `route`. +fn sanitize_request_headers( + route: &ResolvedRoute, + headers: &[(String, String)], +) -> Vec<(String, String)> { + let mut allowed = HashSet::new(); + allowed.extend( + COMMON_INFERENCE_REQUEST_HEADERS + .iter() + .map(|name| (*name).to_string()), + ); + allowed.extend( + route + .passthrough_headers + .iter() + .map(|name| name.to_ascii_lowercase()), + ); + allowed.extend( + route + .default_headers + .iter() + .map(|(name, _)| name.to_ascii_lowercase()), + ); + + headers + .iter() + .filter_map(|(name, value)| { + let name_lc = name.to_ascii_lowercase(); + if should_strip_request_header(&name_lc) || !allowed.contains(&name_lc) { + return None; + } + Some((name.clone(), value.clone())) + }) + .collect() +} + +fn should_strip_request_header(name: &str) -> bool { + matches!( + name, + "authorization" | "x-api-key" | "host" | "content-length" + ) || is_hop_by_hop_header(name) +} + +fn is_hop_by_hop_header(name: &str) -> bool { + matches!( + name, + "connection" + | "keep-alive" + | "proxy-authenticate" + | "proxy-authorization" + | "proxy-connection" + | "te" + | "trailer" + | "transfer-encoding" + | "upgrade" + ) +} + +/// Build an HTTP request to the backend configured in `route`, applying the request-header allowlist.
/// /// Returns the prepared [`reqwest::RequestBuilder`] with auth, headers, model /// rewrite, and body applied. The caller decides whether to apply a total @@ -97,6 +158,7 @@ fn prepare_backend_request( body: bytes::Bytes, ) -> Result<(reqwest::RequestBuilder, String), RouterError> { let url = build_backend_url(&route.endpoint, path); + let headers = sanitize_request_headers(route, &headers); let reqwest_method: reqwest::Method = method .parse() @@ -113,17 +175,7 @@ fn prepare_backend_request( builder = builder.header(*header_name, &route.api_key); } } - - // Strip auth and host headers — auth is re-injected above from the route - // config, and host must match the upstream. - let strip_headers: [&str; 3] = ["authorization", "x-api-key", "host"]; - - // Forward non-sensitive headers. - for (name, value) in headers { - let name_lc = name.to_ascii_lowercase(); - if strip_headers.contains(&name_lc.as_str()) { - continue; - } + for (name, value) in &headers { builder = builder.header(name.as_str(), value.as_str()); } @@ -510,10 +562,95 @@ mod tests { protocols: protocols.iter().map(|p| (*p).to_string()).collect(), auth, default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + passthrough_headers: vec![ + "anthropic-version".to_string(), + "anthropic-beta".to_string(), + ], timeout: crate::config::DEFAULT_ROUTE_TIMEOUT, } } + #[test] + fn sanitize_request_headers_drops_unknown_sensitive_headers() { + let route = ResolvedRoute { + name: "inference.local".to_string(), + endpoint: "https://api.example.com/v1".to_string(), + model: "test-model".to_string(), + api_key: "sk-test".to_string(), + protocols: vec!["openai_chat_completions".to_string()], + auth: AuthHeader::Bearer, + default_headers: Vec::new(), + passthrough_headers: vec!["openai-organization".to_string()], + timeout: crate::config::DEFAULT_ROUTE_TIMEOUT, + }; + + let kept = super::sanitize_request_headers( + &route, + &[ + ("content-type".to_string(), 
"application/json".to_string()), + ("authorization".to_string(), "Bearer client".to_string()), + ("cookie".to_string(), "session=1".to_string()), + ("x-amz-security-token".to_string(), "token".to_string()), + ("openai-organization".to_string(), "org_123".to_string()), + ], + ); + + assert!( + kept.iter() + .any(|(name, _)| name.eq_ignore_ascii_case("content-type")) + ); + assert!( + kept.iter() + .any(|(name, _)| name.eq_ignore_ascii_case("openai-organization")) + ); + assert!( + kept.iter() + .all(|(name, _)| !name.eq_ignore_ascii_case("authorization")) + ); + assert!( + kept.iter() + .all(|(name, _)| !name.eq_ignore_ascii_case("cookie")) + ); + assert!( + kept.iter() + .all(|(name, _)| !name.eq_ignore_ascii_case("x-amz-security-token")) + ); + } + + #[test] + fn sanitize_request_headers_preserves_allowed_provider_headers() { + let route = test_route( + "https://api.anthropic.com/v1", + &["anthropic_messages"], + AuthHeader::Custom("x-api-key"), + ); + + let kept = super::sanitize_request_headers( + &route, + &[ + ("anthropic-version".to_string(), "2024-10-22".to_string()), + ( + "anthropic-beta".to_string(), + "tool-use-2024-10-22".to_string(), + ), + ("x-api-key".to_string(), "client-key".to_string()), + ], + ); + + assert!(kept.iter().any( + |(name, value)| name.eq_ignore_ascii_case("anthropic-version") && value == "2024-10-22" + )); + assert!( + kept.iter() + .any(|(name, value)| name.eq_ignore_ascii_case("anthropic-beta") + && value == "tool-use-2024-10-22") + ); + assert!( + kept.iter() + .all(|(name, _)| !name.eq_ignore_ascii_case("x-api-key")) + ); + } + #[tokio::test] async fn verify_backend_endpoint_uses_route_auth_and_shape() { let mock_server = MockServer::start().await; diff --git a/crates/openshell-router/src/config.rs b/crates/openshell-router/src/config.rs index b531e091d..660509d9e 100644 --- a/crates/openshell-router/src/config.rs +++ b/crates/openshell-router/src/config.rs @@ -34,7 +34,7 @@ pub struct RouteConfig { /// A fully-resolved route 
ready for the router to forward requests. /// /// The router is provider-agnostic — all provider-specific decisions -/// (auth header style, default headers, base URL) are made by the +/// (auth header style, default headers, passthrough headers, base URL) are made by the /// caller during resolution. #[derive(Clone)] pub struct ResolvedRoute { @@ -48,6 +48,8 @@ pub struct ResolvedRoute { pub auth: AuthHeader, /// Extra headers injected on every request (e.g. `anthropic-version`). pub default_headers: Vec<(String, String)>, + /// Client-supplied headers that may be forwarded to the upstream backend. + pub passthrough_headers: Vec<String>, /// Per-request timeout for proxied inference calls. pub timeout: Duration, } @@ -62,6 +64,7 @@ impl std::fmt::Debug for ResolvedRoute { .field("protocols", &self.protocols) .field("auth", &self.auth) .field("default_headers", &self.default_headers) + .field("passthrough_headers", &self.passthrough_headers) .field("timeout", &self.timeout) .finish() } @@ -125,7 +128,8 @@ impl RouteConfig { ))); } - let (auth, default_headers) = auth_from_provider_type(self.provider_type.as_deref()); + let (auth, default_headers, passthrough_headers) = + route_headers_from_provider_type(self.provider_type.as_deref()); Ok(ResolvedRoute { name: self.name.clone(), @@ -135,17 +139,21 @@ impl RouteConfig { protocols, auth, default_headers, + passthrough_headers, timeout: DEFAULT_ROUTE_TIMEOUT, }) } } -/// Derive auth header style and default headers from a provider type string. +/// Derive auth header style, default headers, and passthrough headers from a +/// provider type string. /// -/// Delegates to [`openshell_core::inference::auth_for_provider_type`] which -/// uses the centralized `InferenceProviderProfile` registry.
-fn auth_from_provider_type(provider_type: Option<&str>) -> (AuthHeader, Vec<(String, String)>) { - openshell_core::inference::auth_for_provider_type(provider_type.unwrap_or("")) +/// Delegates to [`openshell_core::inference::route_headers_for_provider_type`] +/// which uses the centralized `InferenceProviderProfile` registry. +fn route_headers_from_provider_type( + provider_type: Option<&str>, +) -> (AuthHeader, Vec<(String, String)>, Vec<String>) { + openshell_core::inference::route_headers_for_provider_type(provider_type.unwrap_or("")) } #[cfg(test)] @@ -263,6 +271,7 @@ routes: protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: DEFAULT_ROUTE_TIMEOUT, }; let debug_output = format!("{route:?}"); @@ -278,22 +287,34 @@ routes: #[test] fn auth_from_anthropic_provider_uses_custom_header() { - let (auth, headers) = auth_from_provider_type(Some("anthropic")); + let (auth, headers, passthrough_headers) = + route_headers_from_provider_type(Some("anthropic")); assert_eq!(auth, AuthHeader::Custom("x-api-key")); assert!(headers.iter().any(|(k, _)| k == "anthropic-version")); + assert!( + passthrough_headers + .iter() + .any(|name| name == "anthropic-beta") + ); } #[test] fn auth_from_openai_provider_uses_bearer() { - let (auth, headers) = auth_from_provider_type(Some("openai")); + let (auth, headers, passthrough_headers) = route_headers_from_provider_type(Some("openai")); assert_eq!(auth, AuthHeader::Bearer); assert!(headers.is_empty()); + assert!( + passthrough_headers + .iter() + .any(|name| name == "openai-organization") + ); } #[test] fn auth_from_none_defaults_to_bearer() { - let (auth, headers) = auth_from_provider_type(None); + let (auth, headers, passthrough_headers) = route_headers_from_provider_type(None); assert_eq!(auth, AuthHeader::Bearer); assert!(headers.is_empty()); + assert!(passthrough_headers.is_empty()); } } diff --git a/crates/openshell-router/src/mock.rs
b/crates/openshell-router/src/mock.rs index a17ce486f..66fc80414 100644 --- a/crates/openshell-router/src/mock.rs +++ b/crates/openshell-router/src/mock.rs @@ -131,6 +131,7 @@ mod tests { protocols: protocols.iter().map(ToString::to_string).collect(), auth: crate::config::AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: crate::config::DEFAULT_ROUTE_TIMEOUT, } } diff --git a/crates/openshell-router/tests/backend_integration.rs b/crates/openshell-router/tests/backend_integration.rs index d9aecb0e3..6b21de94d 100644 --- a/crates/openshell-router/tests/backend_integration.rs +++ b/crates/openshell-router/tests/backend_integration.rs @@ -15,6 +15,7 @@ fn mock_candidates(base_url: &str) -> Vec<ResolvedRoute> { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: vec!["openai-organization".to_string(), "x-model-id".to_string()], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }] } @@ -118,6 +119,7 @@ async fn proxy_no_compatible_route_returns_error() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; @@ -169,6 +171,39 @@ async fn proxy_strips_auth_header() { assert_eq!(response.status, 200); } +#[tokio::test] +async fn proxy_forwards_openai_organization_header() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(bearer_token("test-api-key")) + .and(header("openai-organization", "org_123")) + .respond_with(ResponseTemplate::new(200).set_body_string("{}")) + .mount(&mock_server) + .await; + + let router = Router::new().unwrap(); + let candidates = mock_candidates(&mock_server.uri()); + + let response = router + .proxy_with_candidates( + "openai_chat_completions", + "POST", + "/v1/chat/completions", + vec![ +
("openai-organization".to_string(), "org_123".to_string()), + ("cookie".to_string(), "session=abc".to_string()), + ], + bytes::Bytes::new(), + &candidates, + ) + .await + .unwrap(); + + assert_eq!(response.status, 200); +} + #[tokio::test] async fn proxy_mock_route_returns_canned_response() { let router = Router::new().unwrap(); @@ -180,6 +215,7 @@ async fn proxy_mock_route_returns_canned_response() { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; @@ -315,6 +351,10 @@ async fn proxy_uses_x_api_key_for_anthropic_route() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + passthrough_headers: vec![ + "anthropic-version".to_string(), + "anthropic-beta".to_string(), + ], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; @@ -374,6 +414,10 @@ async fn proxy_anthropic_does_not_send_bearer_auth() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + passthrough_headers: vec![ + "anthropic-version".to_string(), + "anthropic-beta".to_string(), + ], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; @@ -419,6 +463,10 @@ async fn proxy_forwards_client_anthropic_version_header() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + passthrough_headers: vec![ + "anthropic-version".to_string(), + "anthropic-beta".to_string(), + ], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; @@ -509,6 +557,7 @@ async fn streaming_proxy_completes_despite_exceeding_route_timeout() { protocols: 
vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), // Route timeout shorter than the backend delay — streaming must // NOT be constrained by this. timeout: Duration::from_secs(1), @@ -572,6 +621,7 @@ async fn buffered_proxy_enforces_route_timeout() { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: Duration::from_secs(1), }]; diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 5d7bda98f..b754f564c 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1099,8 +1099,8 @@ pub(crate) fn bundle_to_resolved_routes( .routes .iter() .map(|r| { - let (auth, default_headers) = - openshell_core::inference::auth_for_provider_type(&r.provider_type); + let (auth, default_headers, passthrough_headers) = + openshell_core::inference::route_headers_for_provider_type(&r.provider_type); let timeout = if r.timeout_secs == 0 { openshell_router::config::DEFAULT_ROUTE_TIMEOUT } else { @@ -1114,6 +1114,7 @@ pub(crate) fn bundle_to_resolved_routes( protocols: r.protocols.clone(), auth, default_headers, + passthrough_headers, timeout, } }) @@ -2272,6 +2273,7 @@ mod tests { protocols: vec!["openai_chat_completions".to_string()], auth: openshell_core::inference::AuthHeader::Bearer, default_headers: vec![], + passthrough_headers: vec![], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }, openshell_router::config::ResolvedRoute { @@ -2282,6 +2284,7 @@ mod tests { protocols: vec!["anthropic_messages".to_string()], auth: openshell_core::inference::AuthHeader::Custom("x-api-key"), default_headers: vec![], + passthrough_headers: vec![], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }, ]; @@ -2571,6 +2574,7 @@ filesystem_policy: auth: openshell_core::inference::AuthHeader::Bearer, protocols: 
vec!["openai_chat_completions".to_string()], default_headers: vec![], + passthrough_headers: vec![], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }]; diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 1b87a0c7f..6f85e848e 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -1205,9 +1205,6 @@ async fn route_inference_request( ocsf_emit!(event); } - // Strip credential + framing/hop-by-hop headers. - let filtered_headers = sanitize_inference_request_headers(&request.headers); - let routes = ctx.routes.read().await; if routes.is_empty() { @@ -1231,7 +1228,7 @@ async fn route_inference_request( &pattern.protocol, &request.method, &normalized_path, - filtered_headers, + request.headers.clone(), bytes::Bytes::from(request.body.clone()), &routes, ) @@ -1393,14 +1390,6 @@ fn router_error_to_http(err: &openshell_router::RouterError) -> (u16, String) { } } -fn sanitize_inference_request_headers(headers: &[(String, String)]) -> Vec<(String, String)> { - headers - .iter() - .filter(|(name, _)| !should_strip_request_header(name)) - .cloned() - .collect() -} - fn sanitize_inference_response_headers(headers: Vec<(String, String)>) -> Vec<(String, String)> { headers .into_iter() @@ -1408,14 +1397,6 @@ fn sanitize_inference_response_headers(headers: Vec<(String, String)>) -> Vec<(S .collect() } -fn should_strip_request_header(name: &str) -> bool { - let name_lc = name.to_ascii_lowercase(); - matches!( - name_lc.as_str(), - "authorization" | "x-api-key" | "host" | "content-length" - ) || is_hop_by_hop_header(&name_lc) -} - fn should_strip_response_header(name: &str) -> bool { let name_lc = name.to_ascii_lowercase(); matches!(name_lc.as_str(), "content-length") || is_hop_by_hop_header(&name_lc) @@ -2792,48 +2773,102 @@ mod tests { ); } - #[test] - fn sanitize_request_headers_strips_auth_and_framing() { - let headers = vec![ - ("authorization".to_string(), "Bearer 
test".to_string()), - ("x-api-key".to_string(), "secret".to_string()), - ("transfer-encoding".to_string(), "chunked".to_string()), - ("content-length".to_string(), "42".to_string()), - ("content-type".to_string(), "application/json".to_string()), - ("accept".to_string(), "text/event-stream".to_string()), - ]; + #[tokio::test] + async fn inference_interception_applies_router_header_allowlist() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; - let kept = sanitize_inference_request_headers(&headers); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_task = tokio::spawn(async move { + use crate::l7::inference::{ParseResult, try_parse_http_request}; - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("authorization")), - "authorization should be stripped" - ); - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("x-api-key")), - "x-api-key should be stripped" - ); - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("transfer-encoding")), - "transfer-encoding should be stripped" - ); - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("content-length")), - "content-length should be stripped" + let (mut upstream, _) = listener.accept().await.unwrap(); + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + + loop { + let n = upstream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before request completed"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(_, consumed) => { + upstream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + return String::from_utf8_lossy(&buf[..consumed]).to_string(); + } + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + }); + + let router = 
openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint: format!("http://{upstream_addr}"), + model: "meta/llama-3.1-8b-instruct".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_chat_completions".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![ + "openai-organization".to_string(), + "x-model-id".to_string(), + ], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + }], + vec![], ); - assert!( - kept.iter() - .any(|(k, _)| k.eq_ignore_ascii_case("content-type")), - "content-type should be preserved" + + let body = r#"{"model":"ignored","messages":[{"role":"user","content":"hi"}]}"#; + let request = format!( + "POST /v1/chat/completions HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + OpenAI-Organization: org_123\r\n\ + Authorization: Bearer client-key\r\n\ + Cookie: session=abc\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response_text = String::from_utf8_lossy(&response); + assert!(response_text.starts_with("HTTP/1.1 200")); + + let outcome = server_task.await.unwrap().unwrap(); assert!( - kept.iter().any(|(k, _)| k.eq_ignore_ascii_case("accept")), - "accept should be preserved" + matches!(outcome, InferenceOutcome::Routed), + "expected Routed outcome, got: {outcome:?}" 
); + + let forwarded = upstream_task.await.unwrap(); + let forwarded_lc = forwarded.to_ascii_lowercase(); + assert!(forwarded_lc.contains("openai-organization: org_123")); + assert!(forwarded_lc.contains("authorization: bearer test-api-key")); + assert!(!forwarded_lc.contains("authorization: bearer client-key")); + assert!(!forwarded_lc.contains("cookie:")); } // -- router_error_to_http -- diff --git a/crates/openshell-sandbox/tests/system_inference.rs b/crates/openshell-sandbox/tests/system_inference.rs index 5d581fbe2..20c39f3b6 100644 --- a/crates/openshell-sandbox/tests/system_inference.rs +++ b/crates/openshell-sandbox/tests/system_inference.rs @@ -20,6 +20,7 @@ fn make_system_route() -> ResolvedRoute { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, } } @@ -33,6 +34,7 @@ fn make_user_route() -> ResolvedRoute { protocols: vec!["openai_chat_completions".to_string()], auth: AuthHeader::Bearer, default_headers: Vec::new(), + passthrough_headers: Vec::new(), timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, } } @@ -126,6 +128,10 @@ async fn system_inference_with_anthropic_protocol() { protocols: vec!["anthropic_messages".to_string()], auth: AuthHeader::Custom("x-api-key"), default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + passthrough_headers: vec![ + "anthropic-version".to_string(), + "anthropic-beta".to_string(), + ], timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, }; diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 0fb29bde5..79f303aeb 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -276,6 +276,11 @@ fn resolve_provider_route(provider: &Provider) -> Result