Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions crates/test-programs/src/bin/api_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ impl bindings::exports::wasi::http::incoming_handler::Guest for T {
let req_hdrs = request.headers();

assert!(
!req_hdrs.get(&header).is_empty(),
"missing `custom-forbidden-header` from request"
req_hdrs.get(&header).is_empty(),
"forbidden `custom-forbidden-header` found in request"
);

assert!(req_hdrs.delete(&header).is_err());
assert!(req_hdrs.append(&header, &b"no".to_vec()).is_err());

assert!(
!req_hdrs.get(&header).is_empty(),
"delete of forbidden header succeeded"
req_hdrs.get(&header).is_empty(),
"append of forbidden header succeeded"
);

let hdrs = bindings::wasi::http::types::Headers::new();
Expand Down
59 changes: 53 additions & 6 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,18 @@ pub trait WasiHttpView: Send {
fn new_incoming_request(
&mut self,
req: hyper::Request<HyperIncomingBody>,
) -> wasmtime::Result<Resource<HostIncomingRequest>> {
) -> wasmtime::Result<Resource<HostIncomingRequest>>
where
Self: Sized,
{
let (parts, body) = req.into_parts();
let body = HostIncomingBody::new(
body,
// TODO: this needs to be plumbed through
std::time::Duration::from_millis(600 * 1000),
);
Ok(self.table().push(HostIncomingRequest {
parts,
body: Some(body),
})?)
let incoming_req = HostIncomingRequest::new(self, parts, Some(body));
Ok(self.table().push(incoming_req)?)
}

fn new_response_outparam(
Expand Down Expand Up @@ -73,6 +74,41 @@ pub trait WasiHttpView: Send {
}
}

/// Returns `true` when the header is forbidden according to this [`WasiHttpView`] implementation.
pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

/// Removes forbidden headers from a [`hyper::HeaderMap`].
pub(crate) fn remove_forbidden_headers(
view: &mut dyn WasiHttpView,
headers: &mut hyper::HeaderMap,
) {
let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
if is_forbidden_header(view, name) {
Some(name.clone())
} else {
None
}
}));

for name in forbidden_keys {
headers.remove(name);
}
}

pub fn default_send_request(
view: &mut dyn WasiHttpView,
OutgoingRequest {
Expand Down Expand Up @@ -264,10 +300,21 @@ impl TryInto<http::Method> for types::Method {
}

pub struct HostIncomingRequest {
pub parts: http::request::Parts,
pub(crate) parts: http::request::Parts,
pub body: Option<HostIncomingBody>,
}

impl HostIncomingRequest {
pub fn new(
view: &mut dyn WasiHttpView,
mut parts: http::request::Parts,
body: Option<HostIncomingBody>,
) -> Self {
remove_forbidden_headers(view, &mut parts.headers);
Self { parts, body }
}
}

pub struct HostResponseOutparam {
pub result:
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
Expand Down
28 changes: 7 additions & 21 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use crate::{
bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers},
body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody},
types::{
FieldMap, HostFields, HostFutureIncomingResponse, HostIncomingRequest,
HostIncomingResponse, HostOutgoingRequest, HostOutgoingResponse, HostResponseOutparam,
is_forbidden_header, remove_forbidden_headers, FieldMap, HostFields,
HostFutureIncomingResponse, HostIncomingRequest, HostIncomingResponse, HostOutgoingRequest,
HostOutgoingResponse, HostResponseOutparam,
},
WasiHttpView,
};
use anyhow::Context;
use hyper::header::HeaderName;
use std::any::Any;
use std::str::FromStr;
use wasmtime::component::Resource;
Expand Down Expand Up @@ -88,22 +88,6 @@ fn get_fields_mut<'a>(
}
}

fn is_forbidden_header<T: WasiHttpView>(view: &mut T, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fn new(&mut self) -> wasmtime::Result<Resource<HostFields>> {
let id = self
Expand Down Expand Up @@ -833,11 +817,13 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFutureIncomingResponse f
Ok(Err(e)) => return Ok(Some(Ok(Err(e)))),
};

let (parts, body) = resp.resp.into_parts();
let (mut parts, body) = resp.resp.into_parts();

remove_forbidden_headers(self, &mut parts.headers);

let resp = self.table().push(HostIncomingResponse {
status: parts.status.as_u16(),
headers: FieldMap::from(parts.headers),
headers: parts.headers,
body: Some({
let mut body = HostIncomingBody::new(body, resp.between_bytes_timeout);
body.retain_worker(&resp.worker);
Expand Down