From 96fbe6bd0d5cb9b2553600d207dc5bb35423d546 Mon Sep 17 00:00:00 2001 From: Trevor Elliott Date: Tue, 14 Nov 2023 11:32:44 -0800 Subject: [PATCH 1/2] wasi-http: Implement http-error-code, and centralize error conversions (#7534) * Implement the missing http-error-code function * Add concrete error conversion functions * Test the error returned when writing too much * Remove an unused import * Only log errors when they will get the default representation * Use `source` instead of `into_cause` --- .../http_outbound_request_content_length.rs | 12 ++++- crates/wasi-http/src/http_impl.rs | 16 +++--- crates/wasi-http/src/lib.rs | 52 +++++++++++++++++++ crates/wasi-http/src/types.rs | 29 +++++------ crates/wasi-http/src/types_impl.rs | 5 +- src/commands/serve.rs | 15 ++---- 6 files changed, 91 insertions(+), 38 deletions(-) diff --git a/crates/test-programs/src/bin/http_outbound_request_content_length.rs b/crates/test-programs/src/bin/http_outbound_request_content_length.rs index c4c4e4a4a7fa..ea7d69c58f4b 100644 --- a/crates/test-programs/src/bin/http_outbound_request_content_length.rs +++ b/crates/test-programs/src/bin/http_outbound_request_content_length.rs @@ -70,11 +70,19 @@ fn main() { { let request_body = outgoing_body.write().unwrap(); - request_body + let e = request_body .blocking_write_and_flush("more than 11 bytes".as_bytes()) .expect_err("write should fail"); - // TODO: show how to use http-error-code to unwrap this error + let e = match e { + test_programs::wasi::io::streams::StreamError::LastOperationFailed(e) => e, + test_programs::wasi::io::streams::StreamError::Closed => panic!("request closed"), + }; + + assert!(matches!( + http_types::http_error_code(&e), + Some(http_types::ErrorCode::InternalError(Some(msg))) + if msg == "too much written to output stream")); } let e = diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index 7e5594712c35..7804edffc3a2 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -3,6 +3,7 @@ use crate::{ outgoing_handler, types::{self, Scheme}, }, + http_request_error, internal_error, types::{HostFutureIncomingResponse, HostOutgoingRequest, OutgoingRequest}, WasiHttpView, }; @@ -77,22 +78,21 @@ impl outgoing_handler::Host for T { uri = uri.path_and_query(path); } - builder = builder.uri( - uri.build() - .map_err(|_| types::ErrorCode::HttpRequestUriInvalid)?, - ); + builder = builder.uri(uri.build().map_err(http_request_error)?); for (k, v) in req.headers.iter() { builder = builder.header(k, v); } - let body = req - .body - .unwrap_or_else(|| Empty::::new().map_err(|_| todo!("thing")).boxed()); + let body = req.body.unwrap_or_else(|| { + Empty::::new() + .map_err(|_| unreachable!("Infallible error")) + .boxed() + }); let request = builder .body(body) - .map_err(|err| types::ErrorCode::InternalError(Some(err.to_string())))?; + .map_err(|err| internal_error(err.to_string()))?; Ok(Ok(self.send_request(OutgoingRequest { use_tls, diff --git a/crates/wasi-http/src/lib.rs b/crates/wasi-http/src/lib.rs index 46a3a9024df2..66852af8985e 100644 --- a/crates/wasi-http/src/lib.rs +++ b/crates/wasi-http/src/lib.rs @@ -17,6 +17,7 @@ pub mod bindings { tracing: true, async: false, with: { + "wasi:io/error": wasmtime_wasi::preview2::bindings::io::error, "wasi:io/streams": wasmtime_wasi::preview2::bindings::io::streams, "wasi:io/poll": wasmtime_wasi::preview2::bindings::io::poll, @@ -47,3 +48,54 @@ pub(crate) fn dns_error(rcode: String, info_code: u16) -> bindings::http::types: pub(crate) fn internal_error(msg: String) -> bindings::http::types::ErrorCode { bindings::http::types::ErrorCode::InternalError(Some(msg)) } + +/// Translate a [`http::Error`] to a wasi-http `ErrorCode` in the context of a request. +pub fn http_request_error(err: http::Error) -> bindings::http::types::ErrorCode { + use bindings::http::types::ErrorCode; + + if err.is::() { + return ErrorCode::HttpRequestUriInvalid; + } + + tracing::warn!("http request error: {err:?}"); + + ErrorCode::HttpProtocolError +} + +/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request. +pub fn hyper_request_error(err: hyper::Error) -> bindings::http::types::ErrorCode { + use bindings::http::types::ErrorCode; + use std::error::Error; + + // If there's a source, we might be able to extract a wasi-http error from it. + if let Some(cause) = err.source() { + if let Some(err) = cause.downcast_ref::() { + return err.clone(); + } + } + + tracing::warn!("hyper request error: {err:?}"); + + ErrorCode::HttpProtocolError +} + +/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a response. +pub fn hyper_response_error(err: hyper::Error) -> bindings::http::types::ErrorCode { + use bindings::http::types::ErrorCode; + use std::error::Error; + + if err.is_timeout() { + return ErrorCode::HttpResponseTimeout; + } + + // If there's a source, we might be able to extract a wasi-http error from it. + if let Some(cause) = err.source() { + if let Some(err) = cause.downcast_ref::() { + return err.clone(); + } + } + + tracing::warn!("hyper response error: {err:?}"); + + ErrorCode::HttpProtocolError +} diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index a5bc098dada1..1a44b7910869 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -4,7 +4,7 @@ use crate::{ bindings::http::types::{self, Method, Scheme}, body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody}, - dns_error, + dns_error, hyper_request_error, }; use http_body_util::BodyExt; use hyper::header::HeaderName; @@ -156,12 +156,14 @@ async fn handler( let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config)); let mut parts = authority.split(":"); let host = parts.next().unwrap_or(&authority); - let domain = rustls::ServerName::try_from(host) - .map_err(|_| dns_error("invalid dns name".to_string(), 0))?; - let stream = connector - .connect(domain, tcp_stream) - .await - .map_err(|_| types::ErrorCode::TlsProtocolError)?; + let domain = rustls::ServerName::try_from(host).map_err(|e| { + tracing::warn!("dns lookup error: {e:?}"); + dns_error("invalid dns name".to_string(), 0) + })?; + let stream = connector.connect(domain, tcp_stream).await.map_err(|e| { + tracing::warn!("tls protocol error: {e:?}"); + types::ErrorCode::TlsProtocolError + })?; let (sender, conn) = timeout( connect_timeout, @@ -169,7 +171,7 @@ async fn handler( ) .await .map_err(|_| types::ErrorCode::ConnectionTimeout)? - .map_err(|_| types::ErrorCode::ConnectionTimeout)?; + .map_err(hyper_request_error)?; let worker = preview2::spawn(async move { match conn.await { @@ -190,7 +192,7 @@ async fn handler( ) .await .map_err(|_| types::ErrorCode::ConnectionTimeout)? - .map_err(|_| types::ErrorCode::HttpProtocolError)?; + .map_err(hyper_request_error)?; let worker = preview2::spawn(async move { match conn.await { @@ -206,11 +208,8 @@ async fn handler( let resp = timeout(first_byte_timeout, sender.send_request(request)) .await .map_err(|_| types::ErrorCode::ConnectionReadTimeout)? - .map_err(|_| types::ErrorCode::HttpProtocolError)? - .map(|body| { - body.map_err(|_| types::ErrorCode::HttpProtocolError) - .boxed() - }); + .map_err(hyper_request_error)? + .map(|body| body.map_err(hyper_request_error).boxed()); Ok(IncomingResponseInternal { resp, @@ -318,7 +317,7 @@ impl TryFrom for hyper::Response { Some(body) => builder.body(body), None => builder.body( Empty::::new() - .map_err(|_| unreachable!()) + .map_err(|_| unreachable!("Infallible error")) .boxed(), ), } diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index 6ee24e790478..6529381c3fb8 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -20,9 +20,10 @@ use wasmtime_wasi::preview2::{ impl crate::bindings::http::types::Host for T { fn http_error_code( &mut self, - _err: wasmtime::component::Resource, + err: wasmtime::component::Resource, ) -> wasmtime::Result> { - todo!() + let e = self.table().get(&err)?; + Ok(e.downcast_ref::().cloned()) } } diff --git a/src/commands/serve.rs b/src/commands/serve.rs index 5aad87decdbb..fe5a8f240857 100644 --- a/src/commands/serve.rs +++ b/src/commands/serve.rs @@ -15,7 +15,7 @@ use wasmtime_wasi::preview2::{ self, StreamError, StreamResult, Table, WasiCtx, WasiCtxBuilder, WasiView, }; use wasmtime_wasi_http::{ - bindings::http::types as http_types, body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView, + body::HyperOutgoingBody, hyper_response_error, WasiHttpCtx, WasiHttpView, }; #[cfg(feature = "wasi-nn")] @@ -365,16 +365,9 @@ impl hyper::service::Service for ProxyHandler { let mut store = inner.cmd.new_store(&inner.engine, req_id)?; - let req = store.data_mut().new_incoming_request(req.map(|body| { - body.map_err(|err| { - if err.is_timeout() { - http_types::ErrorCode::HttpResponseTimeout - } else { - http_types::ErrorCode::InternalError(Some(err.message().to_string())) - } - }) - .boxed() - }))?; + let req = store + .data_mut() + .new_incoming_request(req.map(|body| body.map_err(hyper_response_error).boxed()))?; let out = store.data_mut().new_response_outparam(sender)?; From e72b7493c6b7478579ca2da1bec9320ea4b456c8 Mon Sep 17 00:00:00 2001 From: Trevor Elliott Date: Tue, 14 Nov 2023 14:06:14 -0800 Subject: [PATCH 2/2] Filter out forbidden headers on incoming request and response resources (#7538) --- crates/test-programs/src/bin/api_proxy.rs | 9 ++-- crates/wasi-http/src/types.rs | 59 ++++++++++++++++++++--- crates/wasi-http/src/types_impl.rs | 28 +++-------- 3 files changed, 65 insertions(+), 31 deletions(-) diff --git a/crates/test-programs/src/bin/api_proxy.rs b/crates/test-programs/src/bin/api_proxy.rs index c9f292330dc6..d9ca3936fab0 100644 --- a/crates/test-programs/src/bin/api_proxy.rs +++ b/crates/test-programs/src/bin/api_proxy.rs @@ -20,15 +20,16 @@ impl bindings::exports::wasi::http::incoming_handler::Guest for T { let req_hdrs = request.headers(); assert!( - !req_hdrs.get(&header).is_empty(), - "missing `custom-forbidden-header` from request" + req_hdrs.get(&header).is_empty(), + "forbidden `custom-forbidden-header` found in request" ); assert!(req_hdrs.delete(&header).is_err()); + assert!(req_hdrs.append(&header, &b"no".to_vec()).is_err()); assert!( - !req_hdrs.get(&header).is_empty(), - "delete of forbidden header succeeded" + req_hdrs.get(&header).is_empty(), + "append of forbidden header succeeded" ); let hdrs = bindings::wasi::http::types::Headers::new(); diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 1a44b7910869..4473b3912574 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -35,17 +35,18 @@ pub trait WasiHttpView: Send { fn new_incoming_request( &mut self, req: hyper::Request, - ) -> wasmtime::Result> { + ) -> wasmtime::Result> + where + Self: Sized, + { let (parts, body) = req.into_parts(); let body = HostIncomingBody::new( body, // TODO: this needs to be plumbed through std::time::Duration::from_millis(600 * 1000), ); - Ok(self.table().push(HostIncomingRequest { - parts, - body: Some(body), - })?) + let incoming_req = HostIncomingRequest::new(self, parts, Some(body)); + Ok(self.table().push(incoming_req)?) } fn new_response_outparam( @@ -73,6 +74,41 @@ pub trait WasiHttpView: Send { } } +/// Returns `true` when the header is forbidden according to this [`WasiHttpView`] implementation. +pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool { + static FORBIDDEN_HEADERS: [HeaderName; 9] = [ + hyper::header::CONNECTION, + HeaderName::from_static("keep-alive"), + hyper::header::PROXY_AUTHENTICATE, + hyper::header::PROXY_AUTHORIZATION, + HeaderName::from_static("proxy-connection"), + hyper::header::TE, + hyper::header::TRANSFER_ENCODING, + hyper::header::UPGRADE, + HeaderName::from_static("http2-settings"), + ]; + + FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name) +} + +/// Removes forbidden headers from a [`hyper::HeaderMap`]. +pub(crate) fn remove_forbidden_headers( + view: &mut dyn WasiHttpView, + headers: &mut hyper::HeaderMap, +) { + let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| { + if is_forbidden_header(view, name) { + Some(name.clone()) + } else { + None + } + })); + + for name in forbidden_keys { + headers.remove(name); + } +} + pub fn default_send_request( view: &mut dyn WasiHttpView, OutgoingRequest { @@ -263,10 +299,21 @@ impl TryInto for types::Method { } pub struct HostIncomingRequest { - pub parts: http::request::Parts, + pub(crate) parts: http::request::Parts, pub body: Option, } +impl HostIncomingRequest { + pub fn new( + view: &mut dyn WasiHttpView, + mut parts: http::request::Parts, + body: Option, + ) -> Self { + remove_forbidden_headers(view, &mut parts.headers); + Self { parts, body } + } +} + pub struct HostResponseOutparam { pub result: tokio::sync::oneshot::Sender, types::ErrorCode>>, diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index 6529381c3fb8..ca11e87fe0ea 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -2,13 +2,13 @@ use crate::{ bindings::http::types::{self, Headers, Method, Scheme, StatusCode, Trailers}, body::{HostFutureTrailers, HostIncomingBody, HostOutgoingBody}, types::{ - FieldMap, HostFields, HostFutureIncomingResponse, HostIncomingRequest, - HostIncomingResponse, HostOutgoingRequest, HostOutgoingResponse, HostResponseOutparam, + is_forbidden_header, remove_forbidden_headers, FieldMap, HostFields, + HostFutureIncomingResponse, HostIncomingRequest, HostIncomingResponse, HostOutgoingRequest, + HostOutgoingResponse, HostResponseOutparam, }, WasiHttpView, }; use anyhow::Context; -use hyper::header::HeaderName; use std::any::Any; use std::str::FromStr; use wasmtime::component::Resource; @@ -89,22 +89,6 @@ fn get_fields_mut<'a>( } } -fn is_forbidden_header(view: &mut T, name: &HeaderName) -> bool { - static FORBIDDEN_HEADERS: [HeaderName; 9] = [ - hyper::header::CONNECTION, - HeaderName::from_static("keep-alive"), - hyper::header::PROXY_AUTHENTICATE, - hyper::header::PROXY_AUTHORIZATION, - HeaderName::from_static("proxy-connection"), - hyper::header::TE, - hyper::header::TRANSFER_ENCODING, - hyper::header::UPGRADE, - HeaderName::from_static("http2-settings"), - ]; - - FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name) -} - impl crate::bindings::http::types::HostFields for T { fn new(&mut self) -> wasmtime::Result> { let id = self @@ -834,11 +818,13 @@ impl crate::bindings::http::types::HostFutureIncomingResponse f Ok(Err(e)) => return Ok(Some(Ok(Err(e)))), }; - let (parts, body) = resp.resp.into_parts(); + let (mut parts, body) = resp.resp.into_parts(); + + remove_forbidden_headers(self, &mut parts.headers); let resp = self.table().push(HostIncomingResponse { status: parts.status.as_u16(), - headers: FieldMap::from(parts.headers), + headers: parts.headers, body: Some({ let mut body = HostIncomingBody::new(body, resp.between_bytes_timeout); body.retain_worker(&resp.worker);