diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 85803855..3bdcf64a 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -366,6 +366,16 @@ pub trait ServiceBuilderExt: sealed::Sealed + Sized { fn trim_trailing_slash( self, ) -> ServiceBuilder>; + + /// Append trailing slash to paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn append_trailing_slash( + self, + ) -> ServiceBuilder>; } impl sealed::Sealed for ServiceBuilder {} @@ -596,4 +606,11 @@ impl ServiceBuilderExt for ServiceBuilder { ) -> ServiceBuilder> { self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash()) } + + #[cfg(feature = "normalize-path")] + fn append_trailing_slash( + self, + ) -> ServiceBuilder> { + self.layer(crate::normalize_path::NormalizePathLayer::append_trailing_slash()) + } } diff --git a/tower-http/src/normalize_path.rs b/tower-http/src/normalize_path.rs index efc7be52..f9b9dd2e 100644 --- a/tower-http/src/normalize_path.rs +++ b/tower-http/src/normalize_path.rs @@ -1,8 +1,5 @@ //! Middleware that normalizes paths. //! -//! Any trailing slashes from request paths will be removed. For example, a request with `/foo/` -//! will be changed to `/foo` before reaching the inner service. -//! //! # Example //! //! ``` @@ -45,11 +42,22 @@ use std::{ use tower_layer::Layer; use tower_service::Service; +/// Different modes of normalizing paths +#[derive(Debug, Copy, Clone)] +enum NormalizeMode { + /// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo + Trim, + /// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/ + Append, +} + /// Layer that applies [`NormalizePath`] which normalizes paths. /// /// See the [module docs](self) for more details. #[derive(Debug, Copy, Clone)] -pub struct NormalizePathLayer {} +pub struct NormalizePathLayer { + mode: NormalizeMode, +} impl NormalizePathLayer { /// Create a new [`NormalizePathLayer`]. @@ -57,7 +65,19 @@ impl NormalizePathLayer { /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` /// will be changed to `/foo` before reaching the inner service. pub fn trim_trailing_slash() -> Self { - NormalizePathLayer {} + NormalizePathLayer { + mode: NormalizeMode::Trim, + } + } + + /// Create a new [`NormalizePathLayer`]. + /// + /// Request paths without trailing slash will be appended with a trailing slash. For example, a request with `/foo` + /// will be changed to `/foo/` before reaching the inner service. + pub fn append_trailing_slash() -> Self { + NormalizePathLayer { + mode: NormalizeMode::Append, + } } } @@ -65,7 +85,10 @@ impl Layer for NormalizePathLayer { type Service = NormalizePath; fn layer(&self, inner: S) -> Self::Service { - NormalizePath::trim_trailing_slash(inner) + NormalizePath { + mode: self.mode, + inner, + } } } @@ -74,16 +97,25 @@ impl Layer for NormalizePathLayer { /// See the [module docs](self) for more details. #[derive(Debug, Copy, Clone)] pub struct NormalizePath { + mode: NormalizeMode, inner: S, } impl NormalizePath { - /// Create a new [`NormalizePath`]. - /// - /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` - /// will be changed to `/foo` before reaching the inner service. + /// Construct a new [`NormalizePath`] with trim mode. pub fn trim_trailing_slash(inner: S) -> Self { - Self { inner } + Self { + mode: NormalizeMode::Trim, + inner, + } + } + + /// Construct a new [`NormalizePath`] with append mode. + pub fn append_trailing_slash(inner: S) -> Self { + Self { + mode: NormalizeMode::Append, + inner, + } } define_inner_service_accessors!(); @@ -103,12 +135,15 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - normalize_trailing_slash(req.uri_mut()); + match self.mode { + NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()), + NormalizeMode::Append => append_trailing_slash(req.uri_mut()), + } self.inner.call(req) } } -fn normalize_trailing_slash(uri: &mut Uri) { +fn trim_trailing_slash(uri: &mut Uri) { if !uri.path().ends_with('/') && !uri.path().starts_with("//") { return; } @@ -137,6 +172,40 @@ fn normalize_trailing_slash(uri: &mut Uri) { } } +fn append_trailing_slash(uri: &mut Uri) { + if uri.path().ends_with("/") && !uri.path().ends_with("//") { + return; + } + + let trimmed = uri.path().trim_matches('/'); + let new_path = if trimmed.is_empty() { + "/".to_string() + } else { + format!("/{trimmed}/") + }; + + let mut parts = uri.clone().into_parts(); + + let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query { + let new_path_and_query = if let Some(query) = path_and_query.query() { + Cow::Owned(format!("{new_path}?{query}")) + } else { + new_path.into() + } + .parse() + .unwrap(); + + Some(new_path_and_query) + } else { + Some(new_path.parse().unwrap()) + }; + + parts.path_and_query = new_path_and_query; + if let Ok(new_uri) = Uri::from_parts(parts) { + *uri = new_uri; + } +} + #[cfg(test)] mod tests { use super::*; @@ -144,7 +213,7 @@ mod tests { use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] - async fn works() { + async fn trim_works() { async fn handle(request: Request<()>) -> Result, Infallible> { Ok(Response::new(request.uri().to_string())) } @@ -168,63 +237,148 @@ mod tests { #[test] fn is_noop_if_no_trailing_slash() { let mut uri = "/foo".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } #[test] fn maintains_query() { let mut uri = "/foo/?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn removes_multiple_trailing_slashes() { let mut uri = "/foo////".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } #[test] fn removes_multiple_trailing_slashes_even_with_query() { let mut uri = "/foo////?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn is_noop_on_index() { let mut uri = "/".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/"); } #[test] fn removes_multiple_trailing_slashes_on_index() { let mut uri = "////".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/"); } #[test] fn removes_multiple_trailing_slashes_on_index_even_with_query() { let mut uri = "////?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/?a=a"); } #[test] fn removes_multiple_preceding_slashes_even_with_query() { let mut uri = "///foo//?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn removes_multiple_preceding_slashes() { let mut uri = "///foo".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } + + #[tokio::test] + async fn append_works() { + async fn handle(request: Request<()>) -> Result, Infallible> { + Ok(Response::new(request.uri().to_string())) + } + + let mut svc = ServiceBuilder::new() + .layer(NormalizePathLayer::append_trailing_slash()) + .service_fn(handle); + + let body = svc + .ready() + .await + .unwrap() + .call(Request::builder().uri("/foo").body(()).unwrap()) + .await + .unwrap() + .into_body(); + + assert_eq!(body, "/foo/"); + } + + #[test] + fn is_noop_if_trailing_slash() { + let mut uri = "/foo/".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } + + #[test] + fn append_maintains_query() { + let mut uri = "/foo?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_only_keeps_one_slash() { + let mut uri = "/foo////".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } + + #[test] + fn append_only_keeps_one_slash_even_with_query() { + let mut uri = "/foo////?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_is_noop_on_index() { + let mut uri = "/".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn append_removes_multiple_trailing_slashes_on_index() { + let mut uri = "////".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn append_removes_multiple_trailing_slashes_on_index_even_with_query() { + let mut uri = "////?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/?a=a"); + } + + #[test] + fn append_removes_multiple_preceding_slashes_even_with_query() { + let mut uri = "///foo//?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_removes_multiple_preceding_slashes() { + let mut uri = "///foo".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } } diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs index 3221afab..8973d8a4 100644 --- a/tower-http/src/service_ext.rs +++ b/tower-http/src/service_ext.rs @@ -413,6 +413,19 @@ pub trait ServiceExt { { crate::normalize_path::NormalizePath::trim_trailing_slash(self) } + + /// Append trailing slash to paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn append_trailing_slash(self) -> crate::normalize_path::NormalizePath + where + Self: Sized, + { + crate::normalize_path::NormalizePath::append_trailing_slash(self) + } } impl ServiceExt for T {}