Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pointercrate-core-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod error;
pub mod etag;
pub mod localization;
pub mod maintenance;
pub mod normalize_uri;
pub mod pagination;
pub mod preferences;
pub mod query;
Expand Down
66 changes: 66 additions & 0 deletions pointercrate-core-api/src/normalize_uri.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::sync::OnceLock;

use rocket::{
fairing::{Fairing, Info, Kind},
Data, Orbit, Request, Rocket, Route,
};

// heavily inspired by rocket's `rocket::fairing::AdHoc::uri_normalizer()` implementation
// only difference is that this applies a trailing slash internally as opposed to omitting it
// https://api.rocket.rs/master/src/rocket/fairing/ad_hoc.rs#315
pub fn uri_normalizer() -> impl Fairing {
#[derive(Default)]
struct Normalizer {
routes: OnceLock<Vec<Route>>,
}

impl Normalizer {
fn routes(&self, rocket: &Rocket<Orbit>) -> &[Route] {
// gather all defined routes which have a trailing slash
self.routes.get_or_init(|| {
rocket
.routes()
.filter(|r| r.uri.has_trailing_slash() || r.uri.path() == "/")
.cloned()
.collect()
})
}
}

#[rocket::async_trait]
impl Fairing for Normalizer {
fn info(&self) -> Info {
Info {
name: "URI Normalizer",
kind: Kind::Liftoff | Kind::Request,
}
}

async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let _ = self.routes(rocket);
}

async fn on_request(&self, request: &mut Request<'_>, _: &mut Data<'_>) {
if request.uri().has_trailing_slash() {
return;
}

if let Some(normalized) = request.uri().map_path(|p| format!("{}/", p)) {
// check if the normalized uri (the request uri with a trailing slash) matches one of our defined routes
let mut normalized_req = request.clone();
normalized_req.set_uri(normalized.clone());

if self.routes(request.rocket()).iter().any(|r| {
// we need to leverage rocket's route matching otherwise this will suck
r.matches(&normalized_req)
}) {
// the request doesn't have a trailing slash AND it's trying to reach one of our defined routes
// so just point it to our defined route
request.set_uri(normalized);
}
}
}
}

Normalizer::default()
}
2 changes: 2 additions & 0 deletions pointercrate-example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use maud::html;
use pointercrate_core::localization::LocalesLoader;
use pointercrate_core::pool::PointercratePool;
use pointercrate_core::{error::CoreError, localization::tr};
use pointercrate_core_api::normalize_uri::uri_normalizer;
use pointercrate_core_api::{error::ErrorResponder, maintenance::MaintenanceFairing, preferences::PreferenceManager};
use pointercrate_core_macros::localized_catcher;
use pointercrate_core_pages::{
Expand Down Expand Up @@ -178,6 +179,7 @@ async fn rocket() -> _ {
// static files.

rocket
.attach(uri_normalizer())
.mount("/static/core", FileServer::new("pointercrate-core-pages/static"))
.mount("/static/demonlist", FileServer::new("pointercrate-demonlist-pages/static"))
.mount("/static/user", FileServer::new("pointercrate-user-pages/static"))
Expand Down
Loading