diff --git a/Cargo.toml b/Cargo.toml index e0bdd17..7184002 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ rustdoc-args = ["--cfg", "docsrs"] fred = { version = "10.1", optional = true, default-features = false, features = [ "i-keys", "i-hashes", + "i-sets", ] } rand = "0.8" retainer = "0.3" diff --git a/README.md b/README.md index c62633e..50d605f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ Add to your `Cargo.toml`: ```toml [dependencies] -... rocket = "0.5" rocket-flex-session = { version = "0.1" } ``` diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..3096d76 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,40 @@ +//! Error types + +/// Result type for session operations +pub type SessionResult = Result; + +/// Errors that can happen during session retrieval/handling +#[derive(Debug, thiserror::Error)] +pub enum SessionError { + /// There was no session cookie, or decryption of the cookie failed + #[error("No session cookie")] + NoSessionCookie, + /// Session wasn't found in storage + #[error("Session not found")] + NotFound, + /// Session was found but it was expired + #[error("Session expired")] + Expired, + /// Error serializing or deserializing the session data + #[error("Failed to serialize/deserialize session: {0}")] + Serialization(Box), + /// An indexing operation failed because the storage provider doesn't + /// implement [SessionStorageIndexed](crate::storage::SessionStorageIndexed) + #[error("Storage doesn't support indexing")] + NonIndexedStorage, + /// A generic error from the storage backend. This error type can be + /// used when implementing a custom session storage. + #[error("Storage backend error: {0}")] + Backend(Box), + /// Error occurred while setting up or tearing down the session storage + #[error("Error during storage setup or teardown: {0}")] + SetupTeardown(String), + + #[cfg(feature = "redis_fred")] + #[error("fred.rs client error: {0}")] + RedisFredError(#[from] fred::error::Error), + + #[cfg(feature = "sqlx_postgres")] + #[error("Sqlx error: {0}")] + SqlxError(#[from] sqlx::Error), +} diff --git a/src/fairing.rs b/src/fairing.rs index d7b6f90..7a31978 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -1,11 +1,129 @@ use std::{ marker::{Send, Sync}, - sync::Arc, + sync::{Arc, Mutex}, }; use rocket::{fairing::Fairing, Build, Orbit, Request, Response, Rocket}; -use crate::{guard::LocalCachedSession, RocketFlexSession}; +use crate::{ + guard::LocalCachedSession, + storage::{memory::MemoryStorage, SessionStorage}, + RocketFlexSessionOptions, +}; + +/** +A Rocket fairing that enables sessions. + +# Type Parameters +* `T` - The type of your session data. Must be thread-safe and + implement Clone. The storage provider you use may have additional + trait bounds as well. + +# Example +```rust +use rocket_flex_session::{RocketFlexSession, storage::cookie::CookieStorage}; +use rocket::time::Duration; +use rocket::serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize)] +struct MySession { + user_id: String, + role: String, +} + +#[rocket::launch] +fn rocket() -> _ { + // Use default settings with in-memory storage + let session_fairing = RocketFlexSession::::default(); + + // Or customize settings with the builder + let custom_session = RocketFlexSession::::builder() + .storage(CookieStorage::default()) // or a custom storage provider + .with_options(|opt| { + opt.cookie_name = "my_cookie".to_string(); + opt.path = "/app".to_string(); + opt.max_age = 7 * 24 * 60 * 60; // 7 days + }) + .build(); + + rocket::build() + .attach(session_fairing) + // ... other configuration ... +} +``` +*/ +#[derive(Clone)] +pub struct RocketFlexSession { + pub(crate) options: RocketFlexSessionOptions, + pub(crate) storage: Arc>, +} +impl RocketFlexSession +where + T: Send + Sync + Clone + 'static, +{ + /// Build a session configuration + pub fn builder() -> RocketFlexSessionBuilder { + RocketFlexSessionBuilder::default() + } +} +impl Default for RocketFlexSession +where + T: Send + Sync + Clone + 'static, +{ + fn default() -> Self { + Self { + options: Default::default(), + storage: Arc::new(MemoryStorage::default()), + } + } +} + +/// Builder to configure the [RocketFlexSession] fairing +pub struct RocketFlexSessionBuilder +where + T: Send + Sync + Clone + 'static, +{ + fairing: RocketFlexSession, +} +impl Default for RocketFlexSessionBuilder +where + T: Send + Sync + Clone + 'static, +{ + fn default() -> Self { + Self { + fairing: Default::default(), + } + } +} +impl RocketFlexSessionBuilder +where + T: Send + Sync + Clone + 'static, +{ + /// Set the session options via a closure. If you're using a cookie-based storage + /// provider, make sure to set the corresponding cookie settings + /// in the storage configuration as well. + pub fn with_options(&mut self, options_fn: OptionsFn) -> &mut Self + where + OptionsFn: FnOnce(&mut RocketFlexSessionOptions), + { + options_fn(&mut self.fairing.options); + self + } + + /// Set the session storage provider + pub fn storage(&mut self, storage: S) -> &mut Self + where + S: SessionStorage + 'static, + { + self.fairing.storage = Arc::new(storage); + self + } + + /// Build the fairing + pub fn build(&self) -> RocketFlexSession { + self.fairing.clone() + } +} #[rocket::async_trait] impl Fairing for RocketFlexSession @@ -34,24 +152,29 @@ where async fn on_response<'r>(&self, req: &'r Request<'_>, _res: &mut Response<'r>) { // Get session data from request local cache, or generate a default empty one - let (session_inner, _): &LocalCachedSession = req.local_cache(|| (Arc::default(), None)); + let (session_inner, _): &LocalCachedSession = + req.local_cache(|| (Mutex::default(), None)); // Take inner session data let (updated, deleted) = session_inner.lock().unwrap().take_for_storage(); // Handle deleted session if let Some(deleted_id) = deleted { - let delete_result = self.storage.delete(&deleted_id).await; - if let Err(e) = delete_result { - rocket::error!("Error while deleting session '{}': {}", deleted_id, e); + rocket::debug!("Found deleted session. Deleting session '{deleted_id}'..."); + if let Err(e) = self.storage.delete(&deleted_id, req.cookies()).await { + rocket::warn!("Error while deleting session '{deleted_id}': {e}"); + } else { + rocket::debug!("Deleted session '{deleted_id}' successfully"); } } // Handle updated session if let Some((id, pending_data, ttl)) = updated { - let save_result = self.storage.save(&id, pending_data, ttl).await; - if let Err(e) = save_result { - rocket::error!("Error while saving session '{}': {}", &id, e); + rocket::debug!("Found updated session. Saving session '{id}'..."); + if let Err(e) = self.storage.save(&id, pending_data, ttl).await { + rocket::error!("Error while saving session '{id}': {e}"); + } else { + rocket::debug!("Saved session '{id}' successfully"); } } } @@ -59,7 +182,7 @@ where async fn on_shutdown(&self, _rocket: &Rocket) { rocket::debug!("Shutting down session resources..."); if let Err(e) = self.storage.shutdown().await { - rocket::warn!("Error during session storage shutdown: {}", e); + rocket::warn!("Error during session storage shutdown: {e}"); } } } diff --git a/src/guard.rs b/src/guard.rs index 7792aba..32cd2b8 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,23 +1,18 @@ -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Mutex}; use rocket::{ - http::{Cookie, CookieJar}, + http::CookieJar, request::{FromRequest, Outcome}, Request, }; use crate::{ - session::Session, - session_inner::SessionInner, - storage::interface::{SessionError, SessionStorage}, - RocketFlexSession, + error::SessionError, session_inner::SessionInner, storage::SessionStorage, RocketFlexSession, + Session, }; /// Type of the cached inner session data in Rocket's request local cache -pub(crate) type LocalCachedSession = (Arc>>, Option); +pub(crate) type LocalCachedSession = (Mutex>, Option); #[rocket::async_trait] impl<'r, T> FromRequest<'r> for Session<'r, T> @@ -31,24 +26,24 @@ where let fairing = get_fairing::(req.rocket()); let cookie_jar = req.cookies(); + // Use rocket's local cache so that the session data is only fetched once per request let (cached_inner, session_error): &LocalCachedSession = req .local_cache_async(async { - let session_cookie = cookie_jar.get_private(&fairing.options.cookie_name); - get_session_data( - session_cookie, + fetch_session_data( + cookie_jar, + &fairing.options.cookie_name, fairing .options .rolling .then(|| fairing.options.ttl.unwrap_or(fairing.options.max_age)), fairing.storage.as_ref(), - cookie_jar, ) .await }) .await; Outcome::Success(Session::new( - cached_inner.clone(), + cached_inner, session_error.as_ref(), cookie_jar, &fairing.options, @@ -71,33 +66,32 @@ where }) } -/// Get session data from storage +/// Fetch session data from storage #[inline(always)] -async fn get_session_data<'r, T: Send + Sync + Clone>( - session_cookie: Option>, +async fn fetch_session_data<'r, T: Send + Sync + Clone>( + cookie_jar: &'r CookieJar<'_>, + cookie_name: &str, rolling_ttl: Option, storage: &'r dyn SessionStorage, - cookie_jar: &'r CookieJar<'_>, ) -> LocalCachedSession { + let session_cookie = cookie_jar.get_private(cookie_name); if let Some(cookie) = session_cookie { let id = cookie.value(); - rocket::debug!("Got session id '{}' from cookie. Retrieving session...", id); + rocket::debug!("Got session id '{id}' from cookie. Retrieving session..."); match storage.load(id, rolling_ttl, cookie_jar).await { Ok((data, ttl)) => { rocket::debug!("Session found. Creating existing session..."); - ( - Arc::new(Mutex::new(SessionInner::new_existing(id, data, ttl))), - None, - ) + let session_inner = SessionInner::new_existing(id, data, ttl); + (Mutex::new(session_inner), None) } Err(e) => { - rocket::debug!("Error from session storage, creating empty session: {}", e); - (Arc::default(), Some(e)) + rocket::info!("Error from session storage, creating empty session: {e}"); + (Mutex::default(), Some(e)) } } } else { rocket::debug!("No valid session cookie found. Creating empty session..."); - (Arc::default(), Some(SessionError::NoSessionCookie)) + (Mutex::default(), Some(SessionError::NoSessionCookie)) } } diff --git a/src/lib.rs b/src/lib.rs index ad99697..641ff57 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,9 @@ Simple, extensible session library for Rocket applications. call will be made to get the session data, and if the session is updated multiple times during the request, only one call will be made at the end of the request to save the session. - Multiple storage providers available, or you can - use your own session storage by implementing the [SessionStorage] trait. + use your own session storage by implementing the [`SessionStorage`](crate::storage::SessionStorage) trait. +- Optional session indexing support for advanced features like multi-device login tracking, + bulk session invalidation, and security auditing. # Usage While technically not needed for development, it is highly recommended to @@ -102,160 +104,211 @@ For more info and examples of this powerful pattern, please see Rocket's documen ## HashMap session data -Instead of a custom struct, you can use a [HashMap](std::collections::HashMap) as your Session data type. This is -particularly useful if you expect your session data structure to be inconsistent and/or change frequently. -When using a HashMap, there are [some additional helper functions](file:///Users/farshad/Projects/pg-user-manager/api/target/doc/rocket_flex_session/struct.Session.html#method.get_key) -to read and set keys. +If your session data has a hashmap data structure, you can implement [`SessionHashMap`] which will +add [additional helper methods](Session::get_key) to Session to read and set keys. This is particularly useful if you expect your +session data structure to be inconsistent and/or change frequently. ``` -use rocket_flex_session::Session; +use rocket_flex_session::{Session, SessionHashMap}; use std::collections::HashMap; -type MySessionData = HashMap; +#[derive(Clone, Default)] +struct MySession(HashMap); + +impl SessionHashMap for MySession { + type Value = String; + + fn get(&self, key: &str) -> Option<&Self::Value> { + self.0.get(key) + } + fn insert(&mut self, key: String, value: Self::Value) { + self.0.insert(key, value); + } + fn remove(&mut self, key: &str) { + self.0.remove(key); + } +} #[rocket::post("/login")] -fn login(mut session: Session) { +fn login(mut session: Session) { let user_id: Option = session.get_key("user_id"); session.set_key("name".to_owned(), "Bob".to_owned()); + session.remove_key("foobar"); } ``` -# Feature flags - -These features can be enabled as shown -[in Cargo's documentation](https://doc.rust-lang.org/cargo/reference/features.html). - -| Name | Description | -|---------|----------------| -| `cookie` | A cookie-based session store. Data is serialized using serde_json and then encrypted into the value of a cookie. | -| `redis_fred` | A session store for Redis (and Redis-compatible databases), using the [fred.rs](https://docs.rs/crate/fred) crate. | -| `sqlx_postgres` | A session store using PostgreSQL via the [sqlx](https://docs.rs/crate/sqlx) crate. | -| `rocket_okapi` | Enables support for the [rocket_okapi](https://docs.rs/crate/rocket_okapi) crate if needed. | -*/ +## Session Indexing -mod fairing; -mod guard; -mod options; -mod session; -mod session_inner; +For use cases like multi-device login tracking or other security features, you can use a storage +provider that supports indexing, and then group sessions by an identifier (such as a user ID) using the [`SessionIdentifier`] trait: -pub mod storage; -pub use options::SessionOptions; -pub use session::Session; +```rust +use rocket::routes; +use rocket_flex_session::{Session, SessionIdentifier, RocketFlexSession}; +use rocket_flex_session::storage::memory::MemoryStorageIndexed; -use crate::storage::{interface::SessionStorage, memory::MemoryStorage}; -use std::sync::Arc; +#[derive(Clone)] +struct UserSession { + user_id: String, + device_name: String, +} -/** -A Rocket fairing that enables sessions. +impl SessionIdentifier for UserSession { + type Id = String; -# Type Parameters -* `T` - The type of your session data. Must be thread-safe and - implement Clone. The storage provider you use may have additional - trait bounds as well. + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) // Group sessions by user_id + } +} -# Example -```rust -use rocket_flex_session::{RocketFlexSession, SessionOptions, storage::cookie::CookieStorage}; -use rocket::time::Duration; -use rocket::serde::{Deserialize, Serialize}; +#[rocket::get("/user/sessions")] +async fn get_all_user_sessions(session: Session<'_, UserSession>) -> String { + match session.get_all_sessions().await { + Ok(Some(sessions)) => format!("Found {} active sessions", sessions.len()), + Ok(None) => "No active session".to_string(), + Err(e) => format!("Error: {}", e), + } +} -#[derive(Clone, Serialize, Deserialize)] -struct MySession { - user_id: String, - role: String, +#[rocket::get("/user/logout-everywhere")] +async fn logout_everywhere(session: Session<'_, UserSession>) -> String { + match session.invalidate_all_sessions(false).await { + Ok(Some(n)) => format!("Logged out from {n} sessions"), + Ok(None) => "No active session".to_string(), + Err(e) => format!("Error: {}", e), + } } #[rocket::launch] fn rocket() -> _ { - // Use default settings - let session_fairing = RocketFlexSession::::default(); - - // Or customize settings with the builder - let custom_session = RocketFlexSession::::builder() - .storage(CookieStorage::default()) // or a custom storage provider - .with_options(|opt| { - opt.cookie_name = "my_cookie".to_string(); - opt.path = "/app".to_string(); - opt.max_age = 7 * 24 * 60 * 60; // 7 days - }) - .build(); - rocket::build() - .attach(session_fairing) - // ... other configuration ... + .attach( + RocketFlexSession::::builder() + .storage(MemoryStorageIndexed::default()) + .build() + ) + .mount("/", routes![get_all_user_sessions, logout_everywhere]) } ``` -*/ -#[derive(Clone)] -pub struct RocketFlexSession { - pub(crate) options: SessionOptions, - pub(crate) storage: Arc>, -} -impl RocketFlexSession + +# Storage Providers + +This crate supports multiple storage backends with different capabilities: + +## Available Storage Providers + +| Storage | Feature Flag | Indexing support | Use Cases | +|---------|-------------|------------------|------------| +| [`storage::memory::MemoryStorage`] | Built-in | ❌ | Development, testing | +| [`storage::memory::MemoryStorageIndexed`] | Built-in | ✅ | Development with indexing features | +| [`storage::cookie::CookieStorage`] | `cookie` | ❌ | Client-side storage, stateless servers | +| [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | Production, distributed systems | +| [`storage::redis::RedisFredStorageIndexed`] | `redis_fred` | ✅ | Production, distributed systems | +| [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅ | Production, existing database | + + +## Custom Storage + +To implement a custom storage provider, implement the [`SessionStorage`](crate::storage::SessionStorage) trait: + +```rust +use rocket_flex_session::{error::SessionResult, storage::SessionStorage}; +use rocket::{async_trait, http::CookieJar}; + +pub struct MyCustomStorage {} + +#[async_trait] +impl SessionStorage for MyCustomStorage where T: Send + Sync + Clone + 'static, { - /// Build a session configuration - pub fn builder() -> RocketFlexSessionBuilder { - RocketFlexSessionBuilder::default() + async fn load(&self, id: &str, ttl: Option, cookie_jar: &CookieJar) -> SessionResult<(T, u32)> { + // Load session from your storage + todo!() } -} -impl Default for RocketFlexSession -where - T: Send + Sync + Clone + 'static, -{ - fn default() -> Self { - Self { - options: Default::default(), - storage: Arc::new(MemoryStorage::default()), - } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + // Save session to your storage + todo!() } -} -/// Builder to configure the [RocketFlexSession] fairing -pub struct RocketFlexSessionBuilder -where - T: Send + Sync + Clone + 'static, -{ - fairing: RocketFlexSession, + async fn delete(&self, id: &str, cookie_jar: &CookieJar) -> SessionResult<()> { + // Delete session from your storage + todo!() + } } -impl Default for RocketFlexSessionBuilder +``` + +### Adding Indexing Support + +To support session indexing, also implement [`SessionStorageIndexed`](crate::storage::SessionStorageIndexed) and add the `as_indexed_storage` method +to the [`SessionStorage`](crate::storage::SessionStorage) trait: + + +```rust,ignore +use rocket_flex_session::{error::SessionResult, storage::{SessionStorage, SessionStorageIndexed, SessionIdentifier}}; + +struct MyCustomStorage; + +#[async_trait] +impl SessionStorageIndexed for MyCustomStorage where - T: Send + Sync + Clone + 'static, + T: SessionIdentifier + Send + Sync + Clone + 'static, { - fn default() -> Self { - Self { - fairing: Default::default(), - } + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + // Return all sessions (session_id, session_data, session_ttl) for the given identifier + todo!() } + // etc... } -impl RocketFlexSessionBuilder + +// Make sure to also add this to the `SessionStorage` trait to enable indexing support +#[async_trait] +impl SessionStorage for MyCustomStorage where T: Send + Sync + Clone + 'static, { - /// Set the session options via a closure. If you're using a cookie-based storage - /// provider, make sure to set the corresponding cookie settings - /// in the storage configuration as well. - pub fn with_options(&mut self, options_fn: OptionsFn) -> &mut Self - where - OptionsFn: FnOnce(&mut SessionOptions), - { - options_fn(&mut self.fairing.options); - self - } - - /// Set the session storage provider - pub fn storage(&mut self, storage: S) -> &mut Self - where - S: SessionStorage + 'static, - { - self.fairing.storage = Arc::new(storage); - self - } + // ... other methods ... - /// Build the fairing - pub fn build(&self) -> RocketFlexSession { - self.fairing.clone() + fn as_indexed_storage(&self) -> Option<&dyn SessionStorageIndexed> { + Some(self) // Enable indexing support } } +``` + +### Implementation Tips + +1. **Trait bounds**: Add additional trait bounds to the session data type `` as needed +2. **Error Handling**: Use [`error::SessionError::Backend`] for custom errors +3. **TTL Handling**: Respect the TTL parameters in `load` and `save` for session expiration +4. **Indexing Consistency**: Keep identifier indexes in sync with session data +5. **Cleanup**: Implement proper cleanup in `shutdown()` if needed + +# Feature flags + +These features can be enabled as shown +[in Cargo's documentation](https://doc.rust-lang.org/cargo/reference/features.html). + +| Name | Description | +|---------|----------------| +| `cookie` | A cookie-based session store. Data is serialized using serde_json and then encrypted into the value of a cookie. | +| `redis_fred` | A session store for Redis (and Redis-compatible databases), using the [fred.rs](https://docs.rs/crate/fred) crate. | +| `sqlx_postgres` | A session store using PostgreSQL via the [sqlx](https://docs.rs/crate/sqlx) crate. | +| `rocket_okapi` | Enables support for the [rocket_okapi](https://docs.rs/crate/rocket_okapi) crate if needed. | +*/ + +mod fairing; +mod guard; +mod options; +mod session; +mod session_hash; +mod session_index; +mod session_inner; + +pub mod error; +pub mod storage; +pub use fairing::{RocketFlexSession, RocketFlexSessionBuilder}; +pub use options::RocketFlexSessionOptions; +pub use session::Session; +pub use session_hash::SessionHashMap; +pub use session_index::SessionIdentifier; diff --git a/src/options.rs b/src/options.rs index 5bb7ab0..53204aa 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,6 +1,6 @@ /// Options for configuring the session. #[derive(Clone, Debug)] -pub struct SessionOptions { +pub struct RocketFlexSessionOptions { /// The name of the cookie used to store the session ID (default: `"rocket"`) pub cookie_name: String, /// The session cookie's `Domain` attribute (default: `None`) @@ -26,7 +26,7 @@ pub struct SessionOptions { pub ttl: Option, } -impl Default for SessionOptions { +impl Default for RocketFlexSessionOptions { fn default() -> Self { Self { cookie_name: "rocket".to_owned(), diff --git a/src/session.rs b/src/session.rs index 8b0883d..ce09bd2 100644 --- a/src/session.rs +++ b/src/session.rs @@ -3,23 +3,20 @@ use rocket::{ time::{Duration, OffsetDateTime}, }; use std::{ - collections::HashMap, fmt::Display, - hash::Hash, marker::{Send, Sync}, - sync::{Arc, Mutex, MutexGuard}, + sync::{Mutex, MutexGuard}, }; use crate::{ - options::SessionOptions, - session_inner::SessionInner, - storage::interface::{SessionError, SessionStorage}, + error::SessionError, options::RocketFlexSessionOptions, session_inner::SessionInner, + storage::SessionStorage, }; /** Represents the current session state. When used as a request guard, it will attempt to retrieve the session. The request guard will always succeed - if a -valid session wasn't found, `session.get()` will return `None` indicating an +valid session wasn't found, the data functions will return `None` indicating an inactive session. # Type Parameters @@ -47,18 +44,18 @@ fn profile(session: Session) -> String { */ pub struct Session<'a, T> where - T: Send + Sync, + T: Send + Sync + Clone, { /// Internal mutable state of the session - inner: Arc>>, + inner: &'a Mutex>, /// Error (if any) when retrieving from storage error: Option<&'a SessionError>, /// Rocket's cookie jar for managing cookies cookie_jar: &'a CookieJar<'a>, /// User's session options - options: &'a SessionOptions, + options: &'a RocketFlexSessionOptions, /// Configured storage provider for sessions - storage: &'a dyn SessionStorage, + pub(crate) storage: &'a dyn SessionStorage, } impl Display for Session<'_, T> @@ -66,7 +63,7 @@ where T: Send + Sync + Clone, { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Session(id: {:?})", self.get_inner().get_id()) + write!(f, "Session(id: {:?})", self.get_inner_lock().get_id()) } } @@ -76,10 +73,10 @@ where { /// Create a new session instance to keep track of the session state in a request pub(crate) fn new( - inner: Arc>>, + inner: &'a Mutex>, error: Option<&'a SessionError>, cookie_jar: &'a CookieJar<'a>, - options: &'a SessionOptions, + options: &'a RocketFlexSessionOptions, storage: &'a dyn SessionStorage, ) -> Self { Self { @@ -93,30 +90,56 @@ where /// Get the session ID (alphanumeric string). Will be `None` if there's no active session. pub fn id(&self) -> Option { - self.get_inner().get_id().map(|s| s.to_owned()) + self.get_inner_lock().get_id().map(|s| s.to_owned()) } /// Get the current session data via cloning. Will be `None` if there's no active session. pub fn get(&self) -> Option { - self.get_inner().get_current_data().map(|d| d.to_owned()) + self.get_inner_lock() + .get_current_data() + .map(|d| d.to_owned()) } /// Get a reference to the current session data via a closure. - /// Data will be `None` if there's no active session. + /// The closure's argument will be `None` if there's no active session. + /// + /// # Example + /// ```rust,ignore + /// session.tap(|data| { + /// if let Some(data) = data { + /// println!("Session data: {:?}", data); + /// } else { + /// println!("No active session"); + /// } + /// }); + /// ``` pub fn tap(&self, f: F) -> R where F: FnOnce(Option<&T>) -> R, { - f(self.get_inner().get_current_data()) + f(self.get_inner_lock().get_current_data()) } /// Get a mutable reference to the current session data via a closure. - /// The function's argument will be `None` if there's no active session. + /// The closure's argument will be `None` if there's no active session. + /// + /// # Example + /// ```rust,ignore + /// session.tap_mut(|data| { + /// if let Some(data) = data { + /// data.foo = new_value; + /// } else { + /// println!("No active session"); + /// } + /// }); + /// ``` pub fn tap_mut(&mut self, f: UpdateFn) -> R where UpdateFn: FnOnce(&mut Option) -> R, { - let (response, is_deleted) = self.get_inner().tap_data_mut(f, self.get_default_ttl()); + let (response, is_deleted) = self + .get_inner_lock() + .tap_data_mut(f, self.get_default_ttl()); if is_deleted { self.delete(); } else { @@ -126,22 +149,24 @@ where response } - /// Set/update the session data. Will create a new active session if needed. + /// Set/replace the session data. Will create a new active session if there isn't one. pub fn set(&mut self, new_data: T) { - self.get_inner().set_data(new_data, self.get_default_ttl()); + self.get_inner_lock() + .set_data(new_data, self.get_default_ttl()); self.update_cookies(); } /// Set the TTL of the session in seconds. This can be used to extend the length - /// of the session if needed. This has no effect if there is no active session. + /// of the session if needed. This has no effect if there is no active session, or + /// if you have enabled "rolling" sessions in the [`options`](RocketFlexSessionOptions::rolling). pub fn set_ttl(&mut self, new_ttl: u32) { - self.get_inner().set_ttl(new_ttl); + self.get_inner_lock().set_ttl(new_ttl); self.update_cookies(); } /// Get the session TTL in seconds. pub fn ttl(&self) -> u32 { - self.get_inner() + self.get_inner_lock() .get_current_ttl() .unwrap_or(self.get_default_ttl()) } @@ -151,10 +176,10 @@ where OffsetDateTime::now_utc().saturating_add(Duration::seconds(self.ttl().into())) } - /// Delete the session. + /// Delete the current session. pub fn delete(&mut self) { // Delete inner session data - let mut inner = self.get_inner(); + let mut inner = self.get_inner_lock(); inner.delete(); // Remove the session cookie @@ -183,18 +208,18 @@ where self.error } - fn get_inner(&self) -> MutexGuard<'_, SessionInner> { + pub(crate) fn get_inner_lock(&self) -> MutexGuard<'_, SessionInner> { self.inner.lock().expect("Failed to get session data lock") } - fn get_default_ttl(&self) -> u32 { + pub(super) fn get_default_ttl(&self) -> u32 { self.options.ttl.unwrap_or(self.options.max_age) } - fn update_cookies(&self) { - let inner = self.get_inner(); + pub(super) fn update_cookies(&self) { + let inner = self.get_inner_lock(); let Some(id) = inner.get_id() else { - rocket::info!("Cookies not updated: no active session"); + rocket::warn!("Cookies not updated: no active session"); return; }; @@ -215,52 +240,8 @@ where } } -impl Session<'_, HashMap> -where - K: Eq + Hash + Send + Sync + Clone, - V: Send + Sync + Clone, -{ - /// Get the value of a key in the session data via cloning - pub fn get_key(&self, key: &Q) -> Option - where - Q: ?Sized + Eq + Hash, - K: std::borrow::Borrow, - { - self.get_inner() - .get_current_data() - .and_then(|h| h.get(key).cloned()) - } - - /// Set the value of a key in the session data. Will create - /// a new session if needed. - pub fn set_key(&mut self, key: K, value: V) { - let mut new_data = self - .get_inner() - .get_current_data() - .cloned() - .unwrap_or_default(); - new_data.insert(key, value); - self.set(new_data); - } - - /// Set multiple keys and values in the session data. Will create - /// a new session if needed. - pub fn set_keys(&mut self, kv_iter: I) - where - I: IntoIterator, - { - let mut new_data = self - .get_inner() - .get_current_data() - .cloned() - .unwrap_or_default(); - new_data.extend(kv_iter); - self.set(new_data); - } -} - /// Create the session cookie -fn create_session_cookie(id: &str, options: &SessionOptions) -> Cookie<'static> { +fn create_session_cookie(id: &str, options: &RocketFlexSessionOptions) -> Cookie<'static> { let mut cookie = Cookie::build((options.cookie_name.to_owned(), id.to_owned())) .http_only(options.http_only) .max_age(Duration::seconds(options.max_age.into())) diff --git a/src/session_hash.rs b/src/session_hash.rs new file mode 100644 index 0000000..55b7f8f --- /dev/null +++ b/src/session_hash.rs @@ -0,0 +1,71 @@ +use crate::Session; + +/// Optional trait for sessions with a hashmap-like data structure. +pub trait SessionHashMap: Send + Sync + Clone + Default { + /// The type of values stored in the session hashmap. + type Value: Send + Sync + Clone; + + /// Get a reference to the value associated with the given key. + fn get(&self, key: &str) -> Option<&Self::Value>; + + /// Inserts or updates a key-value pair into the map. + fn insert(&mut self, key: String, value: Self::Value); + + /// Removes a key from the map. + fn remove(&mut self, key: &str); + + // /// Returns the number of keys in the map. + // fn len(&self) -> usize; + + // /// Returns an iterator over the key-value pairs in the map. + // fn iter(&self) -> std::slice::Iter<'_, (&str, &Self::Value)>; + + // /// Returns an iterator over the key-value pairs in the map, with mutable references. + // fn iter_mut(&mut self) -> std::slice::IterMut<'_, (&str, &mut Self::Value)>; +} + +/// Implementation block for sessions with hashmap-like data structures +impl Session<'_, T> +where + T: SessionHashMap, +{ + /// Get the value of a key in the session data via cloning + pub fn get_key(&self, key: &str) -> Option { + self.get_inner_lock() + .get_current_data() + .and_then(|h| h.get(key).cloned()) + } + + /// Get the value of a key in the session data via a closure + pub fn tap_key(&self, key: &str, f: F) -> R + where + F: FnOnce(Option<&T::Value>) -> R, + { + f(self + .get_inner_lock() + .get_current_data() + .and_then(|d| d.get(key))) + } + + /// Set the value of a key in the session data. Will create a new session if there isn't one. + pub fn set_key(&mut self, key: String, value: T::Value) { + self.get_inner_lock().tap_data_mut( + |data| data.get_or_insert_default().insert(key, value), + self.get_default_ttl(), + ); + self.update_cookies(); + } + + /// Remove a key from the session data. + pub fn remove_key(&mut self, key: &str) { + self.get_inner_lock().tap_data_mut( + |data| { + if let Some(data) = data { + data.remove(key); + } + }, + self.get_default_ttl(), + ); + self.update_cookies(); + } +} diff --git a/src/session_index.rs b/src/session_index.rs new file mode 100644 index 0000000..6077045 --- /dev/null +++ b/src/session_index.rs @@ -0,0 +1,135 @@ +use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; + +/// Optional trait for session data types that can be grouped by an identifier. +/// This enables features like retrieving all sessions for a user or invalidating +/// all sessions when a user's password changes. +/// +/// The storage provider must support indexing sessions (check the docs for the +/// provider you're using). +/// +/// # Example +/// ```rust +/// use rocket_flex_session::SessionIdentifier; +/// +/// #[derive(Clone)] +/// struct MySession { +/// user_id: String, +/// role: String, +/// } +/// +/// impl SessionIdentifier for MySession { +/// const IDENTIFIER: &str = "user_id"; +/// type Id = String; +/// +/// fn identifier(&self) -> Option<&Self::Id> { +/// Some(&self.user_id) +/// } +/// } +/// ``` +pub trait SessionIdentifier: Send + Sync + Clone { + /// The name of the identifier (default: `"user_id"`), that may be used as a field/key name by the storage backend. + const IDENTIFIER: &str = "user_id"; + + /// The type of the identifier + type Id: Send + Sync + Clone; + + /// Extract the identifier from the session data. This identifier + /// should be immutable for the lifetime of the session. + /// Can return `None` if a session doesn't have an identifier and/or + /// shouldn't be indexed. + fn identifier(&self) -> Option<&Self::Id>; +} + +/// Session implementation block for indexing operations +impl<'a, T> Session<'a, T> +where + T: SessionIdentifier, +{ + /// Get all active sessions for the same user/identifier as the current session. + /// Returns the session ID, data, and TTL (in seconds) for each session. + /// Returns `None` if there's no current session or the session isn't indexed. + pub async fn get_all_sessions(&self) -> Result>, SessionError> { + let Some(identifier) = self.get_identifier() else { + return Ok(None); + }; + let storage = self.get_indexed_storage()?; + let sessions = storage.get_sessions_by_identifier(&identifier).await?; + + Ok(Some(sessions)) + } + + /// Get all active session IDs for the same user/identifier as the current session. + /// Returns `None` if there's no current session or the session isn't indexed. + pub async fn get_all_session_ids(&self) -> Result>, SessionError> { + let Some(identifier) = self.get_identifier() else { + return Ok(None); + }; + let storage = self.get_indexed_storage()?; + let session_ids = storage.get_session_ids_by_identifier(&identifier).await?; + + Ok(Some(session_ids)) + } + + /// Invalidate all sessions with the same user/identifier as the current session, optionally keeping the current session active. + /// Returns the number of sessions invalidated, or `None` if there's no current session or the session isn't indexed. + pub async fn invalidate_all_sessions( + &self, + keep_current: bool, + ) -> Result, SessionError> { + let Some((session_id, identifier)) = self.id().zip(self.get_identifier()) else { + return Ok(None); + }; + let storage = self.get_indexed_storage()?; + let num_sessions = storage + .invalidate_sessions_by_identifier( + &identifier, + keep_current.then_some(session_id.as_str()), + ) + .await?; + + Ok(Some(num_sessions)) + } + + /// Get all session IDs, data, and TTL (in seconds) for a specific user/identifier. + pub async fn get_sessions_by_identifier( + &self, + identifier: &T::Id, + ) -> Result, SessionError> { + let storage = self.get_indexed_storage()?; + storage.get_sessions_by_identifier(identifier).await + } + + /// Get all session IDs for a specific user/identifier. + pub async fn get_session_ids_by_identifier( + &self, + identifier: &T::Id, + ) -> Result, SessionError> { + let storage = self.get_indexed_storage()?; + storage.get_session_ids_by_identifier(identifier).await + } + + /// Invalidate all sessions for a specific user/identifier, returning the number of sessions invalidated. + pub async fn invalidate_sessions_by_identifier( + &self, + identifier: &T::Id, + ) -> Result { + let storage = self.get_indexed_storage()?; + storage + .invalidate_sessions_by_identifier(identifier, None) + .await + } + + /// Get the current session's identifier, if there is one. + fn get_identifier(&self) -> Option { + self.get_inner_lock().get_current_identifier().cloned() + } + + /// Try to cast the storage as an indexed storage + fn get_indexed_storage(&self) -> Result<&dyn SessionStorageIndexed, SessionError> { + let indexed_storage = self + .storage + .as_indexed_storage() + .ok_or(SessionError::NonIndexedStorage)?; + Ok(indexed_storage) + } +} diff --git a/src/session_inner.rs b/src/session_inner.rs index f148191..a6cf405 100644 --- a/src/session_inner.rs +++ b/src/session_inner.rs @@ -1,5 +1,7 @@ use rand::distributions::{Alphanumeric, DistString}; +use crate::SessionIdentifier; + /// Represents a current, active session struct ActiveSession { /// Session ID (20-character alphanumeric string) @@ -151,3 +153,12 @@ where ) } } + +impl SessionInner +where + T: SessionIdentifier + Clone, +{ + pub(crate) fn get_current_identifier(&self) -> Option<&T::Id> { + self.get_current_data().and_then(|data| data.identifier()) + } +} diff --git a/src/storage.rs b/src/storage.rs index c0bc63e..61b19e3 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,6 +1,27 @@ //! Storage implementations for sessions +//! +//! This module provides various storage backends for session data, with optional +//! support for session indexing by identifier. +//! +//! ## Session Indexing +//! +//! Some storage backends support indexing sessions by an identifier (like user ID). +//! This enables advanced features such as: +//! +//! - Finding all active sessions for a user +//! - Bulk invalidation of sessions (e.g., "log out everywhere") +//! - Security auditing and monitoring +//! +//! To use indexing, your session type must implement [`crate::SessionIdentifier`]. +//! +//! ## Custom Storage +//! +//! Implement [`SessionStorage`] to create custom storage backends. For indexing +//! support, also implement [`SessionStorageIndexed`]. + +mod interface; +pub use interface::*; -pub mod interface; pub mod memory; #[cfg(feature = "cookie")] diff --git a/src/storage/cookie.rs b/src/storage/cookie.rs index 759537d..90fed2f 100644 --- a/src/storage/cookie.rs +++ b/src/storage/cookie.rs @@ -7,7 +7,9 @@ use rocket::{ time::{Duration, OffsetDateTime}, }; -use super::interface::{SessionError, SessionResult, SessionStorage}; +use crate::error::{SessionError, SessionResult}; + +use super::interface::SessionStorage; /** Storage provider for sessions backed by cookies. All session data is serialized to JSON @@ -170,7 +172,7 @@ where Ok(()) // no-op (cookie session should already be saved by `save_cookie`) } - async fn delete(&self, _id: &str) -> SessionResult<()> { + async fn delete(&self, _id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { Ok(()) // no-op (cookie session should already be deleted by `save_cookie`) } } diff --git a/src/storage/interface.rs b/src/storage/interface.rs index 9af0e23..51788e5 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -1,52 +1,8 @@ //! Shared interface for session storage -use std::fmt::Debug; - use rocket::{async_trait, http::CookieJar}; -/// Errors that can happen during session retrieval/handling -#[derive(Debug, thiserror::Error)] -pub enum SessionError { - /// There was no session cookie, or decryption of the cookie failed - #[error("No session cookie")] - NoSessionCookie, - /// Session wasn't found in storage - #[error("Session not found")] - NotFound, - /// Session was found but it was expired - #[error("Session expired")] - Expired, - /// Error serializing or deserializing the session data - #[error("Failed to serialize/deserialize session")] - Serialization(Box), - /// An unexpected error from the storage backend - #[error("Storage backend error: {0}")] - Backend(Box), - - #[cfg(feature = "redis_fred")] - #[error("fred.rs client error: {0}")] - RedisFredError(fred::error::Error), - - #[cfg(feature = "sqlx_postgres")] - #[error("Sqlx error: {0}")] - SqlxError(sqlx::Error), -} - -#[cfg(feature = "redis_fred")] -impl From for SessionError { - fn from(value: fred::error::Error) -> Self { - SessionError::RedisFredError(value) - } -} - -#[cfg(feature = "sqlx_postgres")] -impl From for SessionError { - fn from(value: sqlx::Error) -> Self { - SessionError::SqlxError(value) - } -} - -pub type SessionResult = Result; +use crate::{error::SessionResult, SessionIdentifier}; /// Trait representing a session backend storage. You can use your own session storage /// by implementing this trait. @@ -57,7 +13,7 @@ where { /// Load session data and TTL (time-to-live in seconds) from storage. If a TTL value is provided, /// it should be set upon retreiving the session. If session is already expired - /// or otherwise invalid, a [SessionError] should be returned instead. + /// or otherwise invalid, a [`SessionError`](crate::error::SessionError) should be returned instead. async fn load( &self, id: &str, @@ -69,7 +25,7 @@ where async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()>; /// Delete a session in storage. This will be performed at the end of the request lifecycle. - async fn delete(&self, id: &str) -> SessionResult<()>; + async fn delete(&self, id: &str, cookie_jar: &CookieJar) -> SessionResult<()>; /// Optional callback when there's a pending change to the session data. A `data` value /// of `None` indicates a deleted session. This callback can be used by cookie-based @@ -85,6 +41,12 @@ where Ok(()) // Default no-op } + /// Storages that support indexing (by implementing [`SessionStorageIndexed`]) must + /// also implement this. Implementation should be trivial: `Some(self)` + fn as_indexed_storage(&self) -> Option<&dyn SessionStorageIndexed> { + None // Default not supported + } + /// Optional setup of resources that will be called on server startup async fn setup(&self) -> SessionResult<()> { Ok(()) // Default no-op @@ -95,3 +57,28 @@ where Ok(()) // Default no-op } } + +/// Extended trait for storage backends that support session indexing by identifier. +/// This allows operations like finding all sessions for a user or bulk invalidation. +/// +/// Not all storage backends can support this - for example, cookie-based storage +/// cannot implement this trait since cookies are only persisted on the client-side. +#[async_trait] +pub trait SessionStorageIndexed: SessionStorage +where + T: SessionIdentifier + Send + Sync, +{ + /// Retrieve all tracked session IDs associated with the given identifier. + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult>; + + /// Retrieve all tracked session IDs, data, and TTL for the given identifier. + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult>; + + /// Invalidate all tracked sessions associated with the given identifier, optionally excluding one session ID. + /// Returns the number of sessions invalidated. + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult; +} diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 73bc03a..7d4be50 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -1,6 +1,7 @@ //! In-memory session storage implementation use std::{ + collections::{HashMap, HashSet}, sync::{Arc, Mutex}, time::Duration, }; @@ -12,11 +13,18 @@ use rocket::{ tokio::{select, spawn, sync::oneshot}, }; -use super::interface::{SessionError, SessionResult, SessionStorage}; +use crate::{ + error::{SessionError, SessionResult}, + SessionIdentifier, +}; + +use super::interface::{SessionStorage, SessionStorageIndexed}; /// In-memory storage provider for sessions. This is designed mostly for local /// development, and not for production use. It uses the [retainer] crate to /// create an async cache. +/// +/// For session indexing support, see [`MemoryStorageIndexed`]. pub struct MemoryStorage { shutdown_tx: Mutex>>, cache: Arc>, @@ -54,10 +62,8 @@ where ) .await; } - Ok(( - data.to_owned(), - ttl.unwrap_or(data.expiration().remaining().unwrap().as_secs() as u32), - )) + let ttl = ttl.unwrap_or(data.expiration().remaining().unwrap().as_secs() as u32); + Ok((data.to_owned(), ttl)) } async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { @@ -67,7 +73,7 @@ where Ok(()) } - async fn delete(&self, id: &str) -> SessionResult<()> { + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { self.cache.remove(&id.to_owned()).await; Ok(()) } @@ -94,3 +100,202 @@ where Ok(()) } } + +/// Extended in-memory storage that supports session indexing by identifier. +/// This allows for operations like retrieving all sessions for a user or +/// bulk invalidation of sessions. +/// +/// You must implement the [`SessionIdentifier`] trait for your session type, +/// and the [`SessionIdentifier::Id`] type must implement [`ToString`]. +/// +/// # Example +/// ```rust +/// use rocket_flex_session::storage::memory::MemoryStorageIndexed; +/// use rocket_flex_session::{SessionIdentifier, RocketFlexSession}; +/// +/// #[derive(Clone)] +/// struct UserSession { +/// user_id: String, +/// data: String, +/// } +/// +/// impl SessionIdentifier for UserSession { +/// type Id = String; +/// fn identifier(&self) -> Option<&Self::Id> { +/// Some(&self.user_id) +/// } +/// } +/// +/// let storage = MemoryStorageIndexed::::default(); +/// let fairing = RocketFlexSession::builder() +/// .storage(storage) +/// .build(); +/// ``` +pub struct MemoryStorageIndexed +where + T: SessionIdentifier, +{ + base_storage: MemoryStorage, + // Index from identifier to set of session IDs + identifier_index: Arc>>>, +} + +impl Default for MemoryStorageIndexed +where + T: SessionIdentifier, + ::Id: ToString, +{ + fn default() -> Self { + Self { + base_storage: MemoryStorage::default(), + identifier_index: Arc::default(), + } + } +} + +impl MemoryStorageIndexed +where + T: SessionIdentifier, + T::Id: ToString, +{ + /// Update the identifier index when session data is saved + fn update_identifier_index(&self, session_id: &str, data: &T) { + if let Some(id) = data.identifier() { + let mut index = self.identifier_index.lock().unwrap(); + index + .entry(id.to_string()) + .or_insert_with(HashSet::new) + .insert(session_id.to_owned()); + } + } + + /// Remove from identifier index when session is deleted + fn remove_from_identifier_index(&self, session_id: &str, data: &T) { + if let Some(id) = data.identifier() { + let mut index = self.identifier_index.lock().unwrap(); + let key = id.to_string(); + if let Some(session_ids) = index.get_mut(&key) { + session_ids.remove(session_id); + if session_ids.is_empty() { + index.remove(&key); + } + } + } + } +} + +#[async_trait] +impl SessionStorage for MemoryStorageIndexed +where + T: SessionIdentifier + Clone + Send + Sync + 'static, + T::Id: ToString, +{ + async fn load( + &self, + id: &str, + ttl: Option, + cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + self.base_storage.load(id, ttl, cookie_jar).await + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + // Update identifier index before saving + self.update_identifier_index(id, &data); + + // Save using base storage + self.base_storage.save(id, data, ttl).await + } + + async fn delete(&self, id: &str, cookie_jar: &CookieJar) -> SessionResult<()> { + // Get the data first so we can update the index + if let Ok((data, _)) = self.base_storage.load(id, None, cookie_jar).await { + self.remove_from_identifier_index(id, &data); + } + + // Delete using base storage + self.base_storage.delete(id, cookie_jar).await + } + + fn as_indexed_storage(&self) -> Option<&dyn SessionStorageIndexed> { + Some(self) + } + + async fn setup(&self) -> SessionResult<()> { + self.base_storage.setup().await + } + + async fn shutdown(&self) -> SessionResult<()> { + self.base_storage.shutdown().await + } +} + +#[async_trait] +impl SessionStorageIndexed for MemoryStorageIndexed +where + Self: SessionStorage, + T: SessionIdentifier + Clone + Send + Sync, + T::Id: ToString, +{ + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let session_ids = { + let index = self.identifier_index.lock().unwrap(); + index.get(&id.to_string()).cloned().unwrap_or_default() + }; + + let mut sessions: Vec<(String, T, u32)> = Vec::new(); + for session_id in session_ids { + if let Some(data) = self.base_storage.cache.get(&session_id).await { + let secs = data.expiration().remaining().unwrap().as_secs(); + sessions.push((session_id, data.value().to_owned(), secs as u32)); + } + } + + Ok(sessions) + } + + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let id_str = id.to_string(); + let session_ids = { + let index = self.identifier_index.lock().unwrap(); + index.get(&id_str).cloned().unwrap_or_default() + }; + + Ok(session_ids.into_iter().collect()) + } + + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let id_str = id.to_string(); + let mut session_ids_to_remove = { + let index = self.identifier_index.lock().unwrap(); + index.get(&id_str).cloned().unwrap_or_default() + }; + if let Some(session_id) = excluded_session_id { + session_ids_to_remove.retain(|id| id != session_id); + } + + // Remove all sessions from cache + for session_id in &session_ids_to_remove { + self.base_storage.cache.remove(session_id).await; + } + + // Remove all sessions from index + { + let mut index = self.identifier_index.lock().unwrap(); + if let Some(session_set) = index.get_mut(&id_str) { + for session_id in &session_ids_to_remove { + session_set.remove(session_id); + } + if session_set.is_empty() { + index.remove(&id_str); + } + } + } + + Ok(session_ids_to_remove.len() as u64) + } +} diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 06a096c..c4d2048 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -1,157 +1,8 @@ //! Session storage with Redis (and Redis-compatible databases) -use fred::{ - prelude::{HashesInterface, KeysInterface, Pool, Value}, - types::Expiration, -}; -use rocket::{async_trait, http::CookieJar}; +mod base; +mod storage; +mod storage_indexed; -use super::interface::{SessionError, SessionResult, SessionStorage}; - -#[derive(Debug)] -pub enum RedisType { - String, - Hash, -} - -/** -Session storage with Redis (and Redis-compatible databases) using the [fred.rs](https://docs.rs/fred) crate. -You can store the data as a Redis string or hash. Your session data type must implement `TryFrom` -using the fred.rs [Value](https://docs.rs/fred/latest/fred/types/enum.Value.html) type, as well as the -inverse `TryFrom for Value`, in order to dictate how the data will be converted to/from the Redis data type. -- For `RedisType::String`, convert to/from `Value::String` -- For `RedisType::Hash`, convert to/from `Value::Map` - -```rust -use fred::prelude::{Builder, ClientLike, Config, Value}; -use rocket_flex_session::storage::{interface::SessionError, redis::{RedisFredStorage, RedisType}}; - -async fn setup_storage() -> RedisFredStorage { - // Setup and initialize a fred.rs Redis pool. - let redis_pool = Builder::default_centralized() - .set_config(Config::from_url("redis://localhost").expect("Valid Redis URL")) - .build_pool(4) - .expect("Should build Redis pool"); - redis_pool.init().await.expect("Should initialize Redis pool"); - let storage = RedisFredStorage::new( - redis_pool, - RedisType::String, // or RedisType::Hash - "sess:" // Prefix for Redis keys - ); - - storage -} - -struct MySessionData { - user_id: String, -} - -// Implement `TryFrom` to convert from the Redis value to your session data type -impl TryFrom for MySessionData { - type Error = SessionError; // or use your own error type - fn try_from(value: Value) -> Result { - match value { - Value::String(id) => Ok(MySessionData { - user_id: id.to_string(), - }), - _ => Err(SessionError::NotFound), - } - } -} -// You can use From or TryFrom for the inverse conversion -impl From for Value { - fn from(data: MySessionData) -> Self { - Value::String(data.user_id.into()) - } -} -``` -*/ -pub struct RedisFredStorage { - pool: Pool, - prefix: String, - redis_type: RedisType, -} -impl RedisFredStorage { - pub fn new(pool: Pool, redis_type: RedisType, key_prefix: &str) -> Self { - Self { - pool, - prefix: key_prefix.to_owned(), - redis_type, - } - } - - fn key(&self, id: &str) -> String { - format!("{}{id}", self.prefix) - } -} - -#[async_trait] -impl SessionStorage for RedisFredStorage -where - T: TryFrom + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, -{ - async fn load( - &self, - id: &str, - ttl: Option, - _cookie_jar: &CookieJar, - ) -> SessionResult<(T, u32)> { - let key = self.key(id); - let pipeline = self.pool.next().pipeline(); - let _: () = match self.redis_type { - RedisType::String => pipeline.get(&key).await?, - RedisType::Hash => pipeline.hgetall(&key).await?, - }; - let _: () = pipeline.ttl(&key).await?; - - let (value, orig_ttl): (Option, i64) = match ttl { - None => pipeline.all().await?, - Some(new_ttl) => { - let _: () = pipeline.expire(&key, new_ttl.into(), None).await?; - let (value, orig_ttl, _expire_result): (Option, i64, Option) = - pipeline.all().await?; - (value, orig_ttl) - } - }; - - let found_value = value.ok_or(SessionError::NotFound)?; - let data = - T::try_from(found_value).map_err(|e| SessionError::Serialization(Box::new(e)))?; - - Ok((data, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) - } - - async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - let key = self.key(id); - let value: Value = data - .try_into() - .map_err(|e| SessionError::Serialization(Box::new(e)))?; - let _: () = match self.redis_type { - RedisType::String => { - self.pool - .set(&key, value, Some(Expiration::EX(ttl.into())), None, false) - .await? - } - RedisType::Hash => { - let Value::Map(map) = value else { - return Err(SessionError::Serialization(Box::new(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Converted Redis value wasn't a Map: {:?}", value), - )))); - }; - let pipeline = self.pool.next().pipeline(); - let _: () = pipeline.hset(&key, map).await?; - let _: () = pipeline.expire(&key, ttl.into(), None).await?; - pipeline.all().await? - } - }; - Ok(()) - } - - async fn delete(&self, id: &str) -> SessionResult<()> { - let _: u8 = self.pool.del(self.key(id)).await?; - Ok(()) - } -} +pub use base::{RedisFredStorage, RedisType}; +pub use storage_indexed::RedisFredStorageIndexed; diff --git a/src/storage/redis/base.rs b/src/storage/redis/base.rs new file mode 100644 index 0000000..d5778f5 --- /dev/null +++ b/src/storage/redis/base.rs @@ -0,0 +1,138 @@ +use fred::{ + prelude::{HashesInterface, KeysInterface, Pool, Value}, + types::Expiration, +}; + +use crate::error::{SessionError, SessionResult}; + +/// The Redis type to use for the session data +#[derive(Debug)] +pub enum RedisType { + String, + Hash, +} + +/** +Redis session storage using the [fred.rs](https://docs.rs/fred) crate. + +You can store the data as a Redis string or hash. Your session data type must implement [`FromValue`](https://docs.rs/fred/latest/fred/types/trait.FromValue.html) +from the fred.rs crate, as well as the inverse `From` or `TryFrom` for [`Value`](https://docs.rs/fred/latest/fred/types/enum.Value.html) in order +to dictate how the data will be converted to/from the Redis data type. +- For Redis string types, convert to/from `Value::String` +- For Redis hash types, convert to/from `Value::Map` + +💡 Common hashmap types like `HashMap` are automatically supported - make sure to use `RedisType::Hash` +when constructing the storage to ensure they are properly converted and stored as Redis hashes. + +```rust +use fred::prelude::{Builder, ClientLike, Config, FromValue, Value}; +use rocket_flex_session::{error::SessionError, storage::{redis::{RedisFredStorage, RedisType}}}; + +async fn setup_storage() -> RedisFredStorage { + // Setup and initialize a fred.rs Redis pool. + let redis_pool = Builder::default_centralized() + .set_config(Config::from_url("redis://localhost").expect("Valid Redis URL")) + .build_pool(4) + .expect("Should build Redis pool"); + redis_pool.init().await.expect("Should initialize Redis pool"); + + // Construct the storage + let storage = RedisFredStorage::new( + redis_pool, + RedisType::String, // or RedisType::Hash + "sess:" // Prefix for Redis keys + ); + + storage +} + +// If using a custom struct for your session data, implement the following... +struct MySessionData { + user_id: String, +} +// Implement `FromValue` to convert from the Redis value to your session data type +impl FromValue for MySessionData { + fn from_value(value: Value) -> Result { + let data: String = value.convert()?; // fred.rs provides several conversion methods on the Value type + Ok(MySessionData { + user_id: data, + }) + } +} +// Implement the inverse conversion +impl From for Value { + fn from(data: MySessionData) -> Self { + Value::String(data.user_id.into()) + } +} +``` +*/ +pub struct RedisFredStorage { + pub(super) pool: fred::prelude::Pool, + pub(super) prefix: String, + pub(super) redis_type: RedisType, +} + +impl RedisFredStorage { + /// Create the storage instance. + /// # Parameters + /// * `pool` - The initialized fred.rs connection pool. + /// * `redis_type` - The Redis data type to use for storing sessions. + /// * `key_prefix` - The prefix to use for session keys. (e.g. "sess:") + pub fn new(pool: Pool, redis_type: RedisType, key_prefix: &str) -> Self { + Self { + pool, + prefix: key_prefix.to_owned(), + redis_type, + } + } + + pub(super) fn session_key(&self, id: &str) -> String { + format!("{}{id}", self.prefix) + } + + pub(super) async fn fetch_session( + &self, + id: &str, + ttl: Option, + ) -> SessionResult<(Value, u32)> { + let key = self.session_key(id); + let pipeline = self.pool.next().pipeline(); + let _: () = match self.redis_type { + RedisType::String => pipeline.get(&key).await?, + RedisType::Hash => pipeline.hgetall(&key).await?, + }; + let _: () = pipeline.ttl(&key).await?; + + let (value, orig_ttl): (Option, i64) = match ttl { + None => pipeline.all().await?, + Some(new_ttl) => { + let _: () = pipeline.expire(&key, new_ttl.into(), None).await?; + let (value, orig_ttl, _expire_result): (Option, i64, Option) = + pipeline.all().await?; + (value, orig_ttl) + } + }; + + let found_value = value.ok_or(SessionError::NotFound)?; + Ok((found_value, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) + } + + pub(super) async fn save_session(&self, id: &str, value: Value, ttl: u32) -> SessionResult<()> { + let key = self.session_key(id); + let _: () = match self.redis_type { + RedisType::String => { + self.pool + .set(&key, value, Some(Expiration::EX(ttl.into())), None, false) + .await? + } + RedisType::Hash => { + let pipeline = self.pool.next().pipeline(); + let _: () = pipeline.hset(&key, value.into_map()?).await?; + let _: () = pipeline.expire(&key, ttl.into(), None).await?; + pipeline.all().await? + } + }; + Ok(()) + } +} diff --git a/src/storage/redis/storage.rs b/src/storage/redis/storage.rs new file mode 100644 index 0000000..a07ffee --- /dev/null +++ b/src/storage/redis/storage.rs @@ -0,0 +1,40 @@ +use fred::prelude::{FromValue, KeysInterface, Value}; +use rocket::http::CookieJar; + +use crate::{ + error::{SessionError, SessionResult}, + storage::SessionStorage, +}; + +use super::RedisFredStorage; + +#[rocket::async_trait] +impl SessionStorage for RedisFredStorage +where + T: FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let (value, ttl) = self.fetch_session(id, ttl).await?; + let data = T::from_value(value)?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + let value: Value = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + self.save_session(id, value, ttl).await?; + Ok(()) + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let _: () = self.pool.del(self.session_key(id)).await?; + Ok(()) + } +} diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs new file mode 100644 index 0000000..363799a --- /dev/null +++ b/src/storage/redis/storage_indexed.rs @@ -0,0 +1,211 @@ +use fred::prelude::{FromValue, HashesInterface, KeysInterface, SetsInterface, Value}; +use rocket::http::CookieJar; + +use crate::{ + error::{SessionError, SessionResult}, + storage::{SessionStorage, SessionStorageIndexed}, + SessionIdentifier, +}; + +use super::RedisFredStorage; + +const DEFAULT_INDEX_TTL: u32 = 60 * 60 * 24 * 7 * 2; // 2 weeks + +/// Redis session storage using the [fred.rs](https://docs.rs/fred) crate. This is a wrapper around +/// [`RedisFredStorage`] that adds support for indexing sessions by an identifier (e.g. `user_id`). +/// +/// In addition to the requirements for `RedisFredStorage`, your session data type must +/// implement [`SessionIdentifier`], and its [Id](`SessionIdentifier::Id`) type +/// must implement `ToString`. Sessions are tracked in Redis sets, with a key format of +/// `:`. e.g.: `sess:user_id:1` +pub struct RedisFredStorageIndexed { + base_storage: RedisFredStorage, + index_ttl: u32, +} + +impl RedisFredStorageIndexed { + /// Create the indexed storage. + /// + /// # Parameters: + /// - `base_storage`: The [`RedisFredStorage`] instance to use. + /// - `index_ttl`: The TTL for the session index - should match + /// your longest expected session duration (default: 2 weeks). + pub fn new(base_storage: RedisFredStorage, index_ttl: Option) -> Self { + Self { + base_storage, + index_ttl: index_ttl.unwrap_or(DEFAULT_INDEX_TTL), + } + } + + fn session_index_key(&self, identifier_name: &str, identifier: &impl ToString) -> String { + format!( + "{}{identifier_name}:{}", + self.base_storage.prefix, + identifier.to_string() + ) + } + + async fn fetch_session_index( + &self, + identifier_name: &str, + identifier: &impl ToString, + ) -> SessionResult<(Vec, String)> { + let index_key = self.session_index_key(identifier_name, identifier); + let session_ids = self.base_storage.pool.smembers(&index_key).await?; + Ok((session_ids, index_key)) + } + + async fn cleanup_session_index( + &self, + index_key: &str, + stale_ids: Vec, + ) -> SessionResult<()> { + Ok(self.base_storage.pool.srem(index_key, stale_ids).await?) + } +} + +#[rocket::async_trait] +impl SessionStorage for RedisFredStorageIndexed +where + T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, + ::Id: ToString, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let (value, ttl) = self.base_storage.fetch_session(id, ttl).await?; + let data = T::from_value(value)?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + if let Some(identifier) = data.identifier() { + let index_key = self.session_index_key(T::IDENTIFIER, identifier); + let pipeline = self.base_storage.pool.next().pipeline(); + let _: () = pipeline.sadd(&index_key, id).await?; + let _: () = pipeline + .expire(&index_key, self.index_ttl.into(), None) + .await?; + let _: () = pipeline.all().await?; + } + + let value: Value = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + self.base_storage.save_session(id, value, ttl).await + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let (value, _) = self.base_storage.fetch_session(id, None).await?; + let data = T::from_value(value)?; + + let pipeline = self.base_storage.pool.next().pipeline(); + let _: () = pipeline.del(self.base_storage.session_key(id)).await?; + if let Some(identifier) = data.identifier() { + let session_idx_key = self.session_index_key(T::IDENTIFIER, identifier); + let _: () = pipeline.srem(&session_idx_key, id).await?; + } + Ok(pipeline.all().await?) + } +} + +#[rocket::async_trait] +impl SessionStorageIndexed for RedisFredStorageIndexed +where + T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, + ::Id: ToString, +{ + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; + + let session_exist_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.session_key(&session_id); + let _: () = session_exist_pipeline.exists(&session_key).await?; + } + let session_exist_results: Vec = session_exist_pipeline.all().await?; + + let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids + .into_iter() + .zip(session_exist_results.into_iter()) + .partition(|(_, exists)| *exists); + if !stale_sessions.is_empty() { + let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); + self.cleanup_session_index(&index_key, stale_ids).await?; + } + + let sessions = existing_sessions.into_iter().map(|(id, _)| id).collect(); + Ok(sessions) + } + + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; + + let session_value_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.session_key(&session_id); + let _: () = match self.base_storage.redis_type { + super::RedisType::String => session_value_pipeline.get(&session_key).await?, + super::RedisType::Hash => session_value_pipeline.hgetall(&session_key).await?, + }; + let _: () = session_value_pipeline.ttl(&session_key).await?; + } + let mut raw_values_and_ttls: Vec> = session_value_pipeline.all().await?; + + let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids + .into_iter() + .zip(raw_values_and_ttls.chunks_exact_mut(2)) + .map(|(id, raw)| { + let data_and_ttl = raw[0].take().and_then(|val| { + let data = T::from_value(val).ok()?; + let ttl = raw[1].as_ref().and_then(Value::as_i64)?; + Some((data, ttl)) + }); + (id, data_and_ttl) + }) + .partition(|(_, data_and_ttl)| data_and_ttl.is_some()); + if !stale_sessions.is_empty() { + let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); + self.cleanup_session_index(&index_key, stale_ids).await?; + } + + let sessions = existing_sessions + .into_iter() + .map(|(id, data_and_ttl)| { + let (data, ttl) = data_and_ttl.expect("already checked by partition"); + (id, data, ttl.try_into().unwrap_or(0)) + }) + .collect(); + Ok(sessions) + } + + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let (mut session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; + if let Some(excluded_id) = excluded_session_id { + session_ids.retain(|id| id != excluded_id); + } + if session_ids.is_empty() { + return Ok(0); + } + + let session_keys: Vec<_> = session_ids + .iter() + .map(|id| self.base_storage.session_key(id)) + .collect(); + let delete_pipeline = self.base_storage.pool.next().pipeline(); + let _: () = delete_pipeline.del(session_keys).await?; + let _: () = delete_pipeline.srem(index_key, session_ids).await?; + let (del_num, _srem_num): (u64, u64) = delete_pipeline.all().await?; + + Ok(del_num) + } +} diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 1f68686..68ac06b 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -1,120 +1,5 @@ -//! Session storage in PostgreSQL via sqlx +//! Session storage via sqlx -use rocket::{async_trait, http::CookieJar}; -use sqlx::{PgPool, Row}; -use time::{Duration, OffsetDateTime}; +mod postgres; -use super::interface::{SessionError, SessionResult, SessionStorage}; - -/** -Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx). Stores the session data as a string, so you'll need -to implement `TryFrom for String` and `TryFrom for YourSession` -for your session data type. Expects a table to already exist with the following columns: -| Name | Type | -|------|---------| -| id | text PRIMARY KEY | -| data | text NOT NULL (or jsonb if using JSON) | -| expires | timestamptz NOT NULL | -*/ -pub struct SqlxPostgresStorage { - pool: PgPool, - table_name: String, -} - -impl SqlxPostgresStorage { - pub fn new(pool: PgPool, table_name: &str) -> SqlxPostgresStorage { - Self { - pool, - table_name: table_name.to_owned(), - } - } -} - -#[async_trait] -impl SessionStorage for SqlxPostgresStorage -where - T: TryFrom + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, -{ - async fn load( - &self, - id: &str, - ttl: Option, - _cookie_jar: &CookieJar, - ) -> SessionResult<(T, u32)> { - let row = match ttl { - Some(new_ttl) => { - sqlx::query(&format!( - r#" - UPDATE "{}" SET expires = $1 - WHERE id = $2 AND expires > CURRENT_TIMESTAMP - RETURNING data, expires"#, - &self.table_name - )) - .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) - .bind(id) - .fetch_optional(&self.pool) - .await? - } - None => { - sqlx::query(&format!( - r#" - SELECT data, expires FROM "{}" - WHERE id = $1 AND expires > CURRENT_TIMESTAMP"#, - &self.table_name - )) - .bind(id) - .fetch_optional(&self.pool) - .await? - } - }; - - let (raw_str, expires) = match row { - Some(row) => { - let data: String = row.try_get("data")?; - let expires: OffsetDateTime = row.try_get("expires")?; - (data, expires) - } - None => return Err(SessionError::NotFound), - }; - let data = T::try_from(raw_str).map_err(|e| SessionError::Serialization(Box::new(e)))?; - let ttl = (expires - OffsetDateTime::now_utc()).whole_seconds(); - - Ok((data, ttl.try_into().unwrap_or(0))) - } - - async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - let raw_str: String = data - .try_into() - .map_err(|e| SessionError::Serialization(Box::new(e)))?; - let expires = OffsetDateTime::now_utc() + Duration::seconds(ttl.into()); - - sqlx::query(&format!( - r#" - INSERT INTO "{}" (id, data, expires) - VALUES ($1, $2, $3) - ON CONFLICT (id) DO UPDATE SET - data = EXCLUDED.data, - expires = EXCLUDED.expires - "#, - self.table_name - )) - .bind(id) - .bind(raw_str) - .bind(expires) - .execute(&self.pool) - .await?; - - Ok(()) - } - - async fn delete(&self, id: &str) -> SessionResult<()> { - sqlx::query(&format!("DELETE FROM {} WHERE id = $1", &self.table_name)) - .bind(id) - .execute(&self.pool) - .await?; - - Ok(()) - } -} +pub use postgres::SqlxPostgresStorage; diff --git a/src/storage/sqlx/postgres.rs b/src/storage/sqlx/postgres.rs new file mode 100644 index 0000000..702ec1e --- /dev/null +++ b/src/storage/sqlx/postgres.rs @@ -0,0 +1,287 @@ +use rocket::{ + async_trait, + http::CookieJar, + tokio::{ + self, + sync::{oneshot, Mutex}, + time::interval, + }, +}; +use sqlx::{postgres::PgRow, PgPool, Row}; +use time::{Duration, OffsetDateTime}; + +use crate::{ + error::{SessionError, SessionResult}, + storage::{SessionStorage, SessionStorageIndexed}, + SessionIdentifier, +}; + +const ID_COLUMN: &str = "id"; +const DATA_COLUMN: &str = "data"; +const EXPIRES_COLUMN: &str = "expires"; + +/** +Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx) that stores session data as a string, and supports session indexing. + +You'll need to implement `TryInto` and `TryFrom` for your session data type. You'll also need to implement [`SessionIdentifier`], +and its [`Id`](crate::SessionIdentifier::Id) must be a [type supported by sqlx](https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html). +Expects a table to already exist with the following columns: + +| Name | Type | +|------|---------| +| id | `text` PRIMARY KEY | +| data | `text` NOT NULL (or `jsonb` if using JSON) | +| `` | `` (the name and type should match your [`SessionIdentifier`] impl) | +| expires | `timestamptz` NOT NULL | +*/ +pub struct SqlxPostgresStorage { + pool: PgPool, + table_name: String, + cleanup_interval: Option, + shutdown_tx: Mutex>>, +} + +impl SqlxPostgresStorage { + /// Creates a new [`SqlxPostgresStorage`]. + /// + /// Parameters: + /// - `pool`: An initialized Postgres connection pool. + /// - `table_name`: The name of the table to use for storing sessions. + /// - `cleanup_interval`: Interval to check for and clean up expired sessions. If `None`, + /// expired sessions won't be cleaned up automatically. + pub fn new( + pool: PgPool, + table_name: &str, + cleanup_interval: Option, + ) -> SqlxPostgresStorage { + Self { + pool, + table_name: table_name.to_owned(), + cleanup_interval, + shutdown_tx: Mutex::default(), + } + } + + fn id_from_row(&self, row: &PgRow) -> sqlx::Result { + row.try_get(ID_COLUMN) + } + + fn raw_data_from_row(&self, row: &PgRow) -> sqlx::Result { + row.try_get(DATA_COLUMN) + } + + fn ttl_from_row(&self, row: &PgRow) -> sqlx::Result { + let expires: OffsetDateTime = row.try_get(EXPIRES_COLUMN)?; + let ttl = (expires - OffsetDateTime::now_utc()) + .whole_seconds() + .try_into() + .unwrap_or(0); + Ok(ttl) + } +} + +#[async_trait] +impl SessionStorage for SqlxPostgresStorage +where + T: SessionIdentifier + TryFrom + TryInto + Clone + Send + Sync + 'static, + ::Id: + for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, + >::Error: std::error::Error + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let row = match ttl { + Some(new_ttl) => { + let sql = format!( + "UPDATE \"{}\" SET {EXPIRES_COLUMN} = $1 \ + WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP \ + RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}", + &self.table_name, + ); + sqlx::query(&sql) + .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) + .bind(id) + .fetch_optional(&self.pool) + .await? + } + None => { + let sql = format!( + "SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ + WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", + &self.table_name, + ); + sqlx::query(&sql) + .bind(id) + .fetch_optional(&self.pool) + .await? + } + }; + + let (raw_str, ttl) = match row { + Some(row) => { + let data = self.raw_data_from_row(&row)?; + let ttl = self.ttl_from_row(&row)?; + (data, ttl) + } + None => return Err(SessionError::NotFound), + }; + let data = T::try_from(raw_str).map_err(|e| SessionError::Serialization(Box::new(e)))?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + let sql = format!( + "INSERT INTO \"{}\" ({ID_COLUMN}, {}, {DATA_COLUMN}, {EXPIRES_COLUMN}) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT ({ID_COLUMN}) DO UPDATE SET \ + {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \ + {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}", + self.table_name, + T::IDENTIFIER + ); + let identifier = data.identifier().cloned(); + let data_str: String = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + sqlx::query(&sql) + .bind(id) + .bind(identifier) + .bind(data_str) + .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let sql = format!("DELETE FROM {} WHERE {ID_COLUMN} = $1", &self.table_name); + sqlx::query(&sql).bind(id).execute(&self.pool).await?; + + Ok(()) + } + + async fn setup(&self) -> SessionResult<()> { + let Some(cleanup_interval) = self.cleanup_interval else { + return Ok(()); + }; + let (tx, mut rx) = oneshot::channel(); + let pool = self.pool.clone(); + let table_name = self.table_name.clone(); + tokio::spawn(async move { + rocket::info!("Starting session cleanup monitor"); + let mut interval = interval(cleanup_interval); + loop { + tokio::select! { + _ = interval.tick() => { + rocket::debug!("Cleaning up expired sessions"); + if let Err(e) = cleanup_expired_sessions(&table_name, &pool).await { + rocket::error!("Error deleting expired sessions: {e}"); + } + } + _ = &mut rx => { + rocket::info!("Session cleanup monitor shutdown"); + } + } + } + }); + self.shutdown_tx.lock().await.replace(tx); + + Ok(()) + } + + async fn shutdown(&self) -> SessionResult<()> { + if let Some(tx) = self.shutdown_tx.lock().await.take() { + tx.send(()).map_err(|_| { + SessionError::SetupTeardown("Failed to send shutdown signal".to_string()) + })?; + } + Ok(()) + } +} + +async fn cleanup_expired_sessions(table_name: &str, pool: &PgPool) -> Result { + rocket::debug!("Cleaning up expired sessions"); + let sql = format!("DELETE FROM \"{table_name}\" WHERE {EXPIRES_COLUMN} < $1"); + let rows = sqlx::query(&sql) + .bind(OffsetDateTime::now_utc()) + .execute(pool) + .await?; + Ok(rows.rows_affected()) +} + +#[async_trait] +impl SessionStorageIndexed for SqlxPostgresStorage +where + T: SessionIdentifier + TryFrom + TryInto + Clone + Send + Sync + 'static, + ::Id: + for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, + >::Error: std::error::Error + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let sql = format!( + "SELECT {ID_COLUMN} FROM \"{}\" \ + WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", + &self.table_name, + T::IDENTIFIER + ); + let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; + let parsed_rows = rows + .into_iter() + .filter_map(|row| self.id_from_row(&row).ok()) + .collect(); + + Ok(parsed_rows) + } + + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let sql = format!( + "SELECT {ID_COLUMN}, {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ + WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", + &self.table_name, + T::IDENTIFIER + ); + let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; + let parsed_rows = rows + .into_iter() + .filter_map(|row| { + let id = self.id_from_row(&row).ok()?; + let raw_data = self.raw_data_from_row(&row).ok()?; + let data = T::try_from(raw_data).ok()?; + let ttl = self.ttl_from_row(&row).ok()?; + Some((id, data, ttl)) + }) + .collect(); + + Ok(parsed_rows) + } + + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let mut sql = format!( + "DELETE FROM \"{}\" WHERE {} = $1", + &self.table_name, + T::IDENTIFIER + ); + if excluded_session_id.is_some() { + sql.push_str(&format!(" AND {ID_COLUMN} != $2")); + } + + let mut query = sqlx::query(&sql).bind(id); + if let Some(excluded_id) = excluded_session_id { + query = query.bind(excluded_id); + } + let rows = query.execute(&self.pool).await?; + + Ok(rows.rows_affected()) + } +} diff --git a/tests/basic.rs b/tests/basic.rs index 38169f8..e7494a9 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -6,7 +6,10 @@ use rocket::{ local::blocking::Client, {routes, Build, Rocket}, }; -use rocket_flex_session::{storage::cookie::CookieStorage, RocketFlexSession, Session}; +use rocket_flex_session::{ + storage::cookie::CookieStorage, RocketFlexSession, Session, SessionHashMap, +}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Clone, Debug, PartialEq)] @@ -15,6 +18,23 @@ struct User { name: String, } +#[derive(Clone, Default, Serialize, Deserialize)] +struct SessionHash(HashMap); + +impl SessionHashMap for SessionHash { + type Value = String; + + fn get(&self, key: &str) -> Option<&Self::Value> { + self.0.get(key) + } + fn insert(&mut self, key: String, value: Self::Value) { + self.0.insert(key, value); + } + fn remove(&mut self, key: &str) { + self.0.remove(key); + } +} + #[get("/get_session")] fn get_session(session: Session) -> String { match session.get() { @@ -39,19 +59,15 @@ fn delete_session(mut session: Session) -> &'static str { } #[get("/get_hash_session/")] -fn get_hash_session(session: Session>, key: &str) -> String { +fn get_hash_session(session: Session, key: &str) -> String { match session.get_key(key) { - Some(value) => value.clone(), + Some(value) => value, None => "No value".to_string(), } } #[post("/set_hash_session//")] -fn set_hash_session( - mut session: Session>, - key: &str, - value: &str, -) -> &'static str { +fn set_hash_session(mut session: Session, key: &str, value: &str) -> &'static str { session.set_key(key.to_owned(), value.to_owned()); "Hash session value set" } @@ -60,7 +76,7 @@ fn create_rocket() -> Rocket { rocket::build() .attach(RocketFlexSession::::default()) .attach( - RocketFlexSession::>::builder() + RocketFlexSession::::builder() .with_options(|opt| opt.cookie_name = "hash_session".to_owned()) .storage( CookieStorage::builder() diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..d2413f3 --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,70 @@ +use fred::prelude::{ClientLike, KeysInterface, ReconnectPolicy}; +use sqlx::{Connection, PgPool}; + +pub const POSTGRES_URL: &str = "postgres://postgres:postgres@localhost"; + +fn random_string(n: usize) -> String { + (0..n) + .map(|_| (b'a' + (rand::random::() % 26)) as char) + .collect() +} + +/// Setup a test Postgres database +pub async fn setup_postgres(base_url: &str) -> (PgPool, String) { + let db_name = format!("test_{}", random_string(6)); + let mut cxn = sqlx::PgConnection::connect(base_url).await.unwrap(); + sqlx::query(&format!("CREATE DATABASE {}", db_name)) + .execute(&mut cxn) + .await + .expect("Should create test database"); + let _ = cxn.close().await; + + let db_url = format!("{}/{}", base_url, db_name); + let pool = sqlx::PgPool::connect(&db_url).await.unwrap(); + sqlx::query( + r#"CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + data TEXT NOT NULL, + user_id TEXT, + expires TIMESTAMPTZ NOT NULL + )"#, + ) + .execute(&pool) + .await + .expect("Should create sessions table"); + + (pool, db_name) +} + +pub async fn teardown_postgres(pool: sqlx::Pool, db_name: String) { + pool.close().await; + drop(pool); + let mut cxn = sqlx::PgConnection::connect(POSTGRES_URL).await.unwrap(); + sqlx::query(&format!("DROP DATABASE {} WITH (FORCE)", db_name)) + .execute(&mut cxn) + .await + .expect("Should drop test database"); +} + +pub async fn setup_redis_fred() -> (fred::prelude::Pool, String) { + let pool = fred::prelude::Builder::default_centralized() + .set_policy(ReconnectPolicy::new_linear(3, 5, 1)) + .with_performance_config(|c| c.default_command_timeout = std::time::Duration::from_secs(5)) + .build_pool(3) + .expect("Should build Redis pool"); + pool.init().await.expect("Should initialize Redis pool"); + let prefix = format!("test_{}:sess:", random_string(6)); + + (pool, prefix) +} + +pub async fn teardown_redis_fred(pool: fred::prelude::Pool, prefix: String) { + let (_cursor, keys): (String, Vec) = pool + .scan_page("0", format!("{prefix}*"), Some(50), None) + .await + .expect("Should scan keys"); + if !keys.is_empty() { + let _: () = pool.del(keys).await.expect("Should delete keys"); + } + pool.quit().await.expect("Should quit Redis pool"); +} diff --git a/tests/session_indexed.rs b/tests/session_indexed.rs new file mode 100644 index 0000000..ef1b131 --- /dev/null +++ b/tests/session_indexed.rs @@ -0,0 +1,333 @@ +use rocket::{ + get, routes, + serde::{Deserialize, Serialize}, + Build, Rocket, +}; +use rocket_flex_session::{ + storage::memory::MemoryStorageIndexed, RocketFlexSession, Session, SessionIdentifier, +}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +struct UserSession { + user_id: String, + username: String, + login_time: u64, +} + +impl SessionIdentifier for UserSession { + const IDENTIFIER: &str = "user_id"; + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +struct AdminSession { + admin_id: String, + role: String, + permissions: Vec, +} + +impl SessionIdentifier for AdminSession { + const IDENTIFIER: &str = "admin_id"; + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.admin_id) + } +} + +// Routes for testing user sessions +#[get("/user/login//")] +async fn user_login( + mut session: Session<'_, UserSession>, + user_id: String, + username: String, +) -> String { + let user_session = UserSession { + user_id: user_id.clone(), + username: username.clone(), + login_time: 1234567890, + }; + + session.set(user_session); + format!("User {} logged in", username) +} + +#[get("/user/sessions")] +async fn get_user_sessions(session: Session<'_, UserSession>) -> String { + match session.get_all_sessions().await { + Ok(Some(sessions)) => { + format!("Found {} sessions for current user", sessions.len()) + } + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error getting sessions: {e}"), + } +} + +#[get("/user/sessions/")] +async fn get_sessions_for_user(session: Session<'_, UserSession>, user_id: String) -> String { + match session.get_sessions_by_identifier(&user_id).await { + Ok(sessions) => { + format!("Sessions for user {user_id}: {:?}", sessions) + } + Err(e) => format!("Error getting sessions: {e}"), + } +} + +#[get("/user/invalidate-all")] +async fn invalidate_all_user_sessions(session: Session<'_, UserSession>) -> String { + match session.invalidate_all_sessions(false).await { + Ok(Some(n)) => format!("{n} session(s) for current user invalidated."), + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error invalidating sessions: {e}"), + } +} + +#[get("/user/invalidate-other")] +async fn invalidate_other_user_sessions(session: Session<'_, UserSession>) -> String { + match session.invalidate_all_sessions(true).await { + Ok(Some(n)) => format!("{n} session(s) for current user invalidated."), + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error invalidating sessions: {e}"), + } +} + +#[get("/user/invalidate-all/")] +async fn invalidate_sessions_for_user( + session: Session<'_, UserSession>, + user_id: String, +) -> String { + match session.invalidate_sessions_by_identifier(&user_id).await { + Ok(n) => format!("{n} session(s) for user {user_id} invalidated"), + Err(e) => format!("Error invalidating sessions: {e}"), + } +} + +#[get("/user/session-ids")] +async fn get_user_session_ids(session: Session<'_, UserSession>) -> String { + match session.get_all_session_ids().await { + Ok(Some(session_ids)) => { + format!("Session IDs for current user: {:?}", session_ids) + } + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error getting session IDs: {e}"), + } +} + +#[get("/user/profile")] +async fn user_profile(session: Session<'_, UserSession>) -> String { + match session.get() { + Some(user_session) => { + format!( + "Profile for {}: logged in at {}", + user_session.username, user_session.login_time + ) + } + None => "No active session".to_string(), + } +} + +fn rocket() -> Rocket { + let user_storage = MemoryStorageIndexed::::default(); + + rocket::build() + .attach( + RocketFlexSession::::builder() + .storage(user_storage) + .build(), + ) + .mount( + "/", + routes![ + user_login, + get_user_sessions, + get_sessions_for_user, + invalidate_all_user_sessions, + invalidate_other_user_sessions, + invalidate_sessions_for_user, + get_user_session_ids, + user_profile, + ], + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use rocket::http::Status; + use rocket::local::blocking::Client; + + fn create_test_client() -> Client { + Client::tracked(rocket()).expect("valid rocket instance") + } + + #[test] + fn user_login_and_profile() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.into_string().unwrap(), "User alice logged in"); + + // Check profile + let response = client.get("/user/profile").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("Profile for alice")); + } + + #[test] + fn multiple_sessions_same_user() { + let client = create_test_client(); + + // First session for user1 + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Check that we can see current user's sessions + let response = client.get("/user/sessions").dispatch(); + assert_eq!(response.status(), Status::Ok); + // Note: This might show 0 or 1 sessions depending on whether the current + // session cookie is being tracked properly in the test + } + + #[test] + fn get_sessions_by_user_id() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Get sessions for specific user ID + let response = client.get("/user/sessions/user1").dispatch(); + assert_eq!(response.status(), Status::Ok); + let body = response.into_string().unwrap(); + assert!(body.contains("Sessions for user user1")); + } + + #[test] + fn test_session_ids_retrieval() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Get session IDs + let response = client.get("/user/session-ids").dispatch(); + assert_eq!(response.status(), Status::Ok); + let body = response.into_string().unwrap(); + assert!(body.contains("Session IDs for current user")); + } + + #[test] + fn test_invalidate_sessions() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Invalidate all sessions for current user + let response = client.get("/user/invalidate-all").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("1 session(s) for current user invalidated")); + + // Profile should now show no session + let response = client.get("/user/profile").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.into_string().unwrap(), "No active session"); + } + + #[test] + fn test_invalidate_other_sessions() { + let client = create_test_client(); + + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Create two more sessions for the same user, simulating a different device + let response = client + .get("/user/login/user1/alice") + .private_cookie("rocket") // empty cookie + .dispatch(); + assert_eq!(response.status(), Status::Ok); + let response = client + .get("/user/login/user1/alice") + .private_cookie("rocket") + .dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Invalidate all sessions except the current one + let response = client.get("/user/invalidate-other").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!( + response.into_string().unwrap(), + "2 session(s) for current user invalidated." + ); + + // Profile should still show active session + let response = client.get("/user/profile").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("Profile for alice")); + } + + #[test] + fn test_invalidate_sessions_by_user_id() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user2/bob").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Invalidate sessions for specific user + let response = client.get("/user/invalidate-all/user2").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("1 session(s) for user user2 invalidated")); + } + + #[test] + fn test_no_session_scenarios() { + let client = create_test_client(); + + // Try to get sessions without being logged in + let response = client.get("/user/sessions").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("No current session")); + + // Try to get session IDs without being logged in + let response = client.get("/user/session-ids").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("No current session")); + + // Try to invalidate sessions without being logged in + let response = client.get("/user/invalidate-all").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("No current session")); + } +} diff --git a/tests/storages.rs b/tests/storages_basic.rs similarity index 63% rename from tests/storages.rs rename to tests/storages_basic.rs index 98c37f7..1f21a52 100644 --- a/tests/storages.rs +++ b/tests/storages_basic.rs @@ -1,24 +1,29 @@ +mod common; + #[macro_use] extern crate rocket; use std::{future::Future, pin::Pin}; -use fred::prelude::{ClientLike, ReconnectPolicy}; -use rocket::{http::Status, local::asynchronous::Client, tokio::time::sleep, Build, Rocket}; +use rocket::{ + futures::FutureExt, http::Status, local::asynchronous::Client, tokio::time::sleep, Build, + Rocket, +}; use rocket_flex_session::{ + error::SessionError, storage::{ cookie::CookieStorage, - interface::SessionError, - redis::{RedisFredStorage, RedisType}, + redis::{RedisFredStorage, RedisFredStorageIndexed, RedisType}, sqlx::SqlxPostgresStorage, }, - RocketFlexSession, Session, + RocketFlexSession, Session, SessionIdentifier, }; use serde::{Deserialize, Serialize}; -use sqlx::{Connection, PgPool}; use test_case::test_case; -const POSTGRES_URL: &str = "postgres://postgres:postgres@localhost"; +use crate::common::{ + setup_postgres, setup_redis_fred, teardown_postgres, teardown_redis_fred, POSTGRES_URL, +}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] struct SessionData { @@ -35,11 +40,11 @@ impl From for String { value.user_id } } -impl TryFrom for SessionData { - type Error = SessionError; - fn try_from(value: fred::types::Value) -> Result { - let user_id = value.as_string().ok_or(SessionError::NotFound)?; - Ok(Self { user_id }) +impl fred::types::FromValue for SessionData { + fn from_value(value: fred::prelude::Value) -> Result { + Ok(Self { + user_id: value.convert()?, + }) } } impl From for fred::types::Value { @@ -47,6 +52,13 @@ impl From for fred::types::Value { Self::String(value.user_id.into()) } } +impl SessionIdentifier for SessionData { + const IDENTIFIER: &str = "user_id"; + type Id = String; + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) + } +} #[get("/get_session")] fn get_session(session: Session) -> String { @@ -75,40 +87,12 @@ fn expire_session(mut session: Session) { session.set_ttl(1); } -/// Setup a test Postgres database -async fn setup_postgres(base_url: &str) -> (PgPool, String) { - let db_name = format!( - "test_{}", - (0..6) - .map(|_| (b'a' + (rand::random::() % 26)) as char) - .collect::() - ); - let mut cxn = sqlx::PgConnection::connect(base_url).await.unwrap(); - sqlx::query(&format!("CREATE DATABASE {}", db_name)) - .execute(&mut cxn) - .await - .expect("Should create test database"); - let _ = cxn.close().await; - - let db_url = format!("{}/{}", base_url, db_name); - let pool = sqlx::PgPool::connect(&db_url).await.unwrap(); - sqlx::query( - r#"CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - data TEXT NOT NULL, - expires TIMESTAMPTZ NOT NULL - )"#, - ) - .execute(&pool) - .await - .expect("Should create sessions table"); - - (pool, db_name) -} - async fn create_rocket( storage_case: &str, -) -> (Rocket, Option>>>) { +) -> ( + Rocket, + Option + Send>>>, +) { let (fairing, cleanup_task) = match storage_case { "cookie" => ( RocketFlexSession::::builder() @@ -117,41 +101,31 @@ async fn create_rocket( None, ), "redis" => { - let pool = fred::prelude::Builder::default_centralized() - .set_policy(ReconnectPolicy::new_linear(3, 5, 1)) - .with_performance_config(|c| { - c.default_command_timeout = std::time::Duration::from_secs(5) - }) - .build_pool(3) - .expect("Should build Redis pool"); - pool.init().await.expect("Should initialize Redis pool"); - let storage = RedisFredStorage::new(pool.clone(), RedisType::String, "sess:"); + let (pool, prefix) = setup_redis_fred().await; + let storage = RedisFredStorage::new(pool.clone(), RedisType::String, &prefix); let fairing = RocketFlexSession::::builder() .storage(storage) .build(); - - let cleanup_task: Pin>> = Box::pin(async move { - pool.quit().await.ok(); - drop(pool); - }); + let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); + (fairing, Some(cleanup_task)) + } + "redis_indexed" => { + let (pool, prefix) = setup_redis_fred().await; + let base_storage = RedisFredStorage::new(pool.clone(), RedisType::String, &prefix); + let storage = RedisFredStorageIndexed::new(base_storage, None); + let fairing = RocketFlexSession::::builder() + .storage(storage) + .build(); + let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); (fairing, Some(cleanup_task)) } "sqlx" => { let (pool, db_name) = setup_postgres(POSTGRES_URL).await; - let storage = SqlxPostgresStorage::new(pool.clone(), "sessions"); + let storage = SqlxPostgresStorage::new(pool.clone(), "sessions", None); let fairing = RocketFlexSession::::builder() .storage(storage) .build(); - - let cleanup_task: Pin>> = Box::pin(async move { - pool.close().await; - drop(pool); - let mut cxn = sqlx::PgConnection::connect(POSTGRES_URL).await.unwrap(); - sqlx::query(&format!("DROP DATABASE {} WITH (FORCE)", db_name)) - .execute(&mut cxn) - .await - .expect("Should drop test database"); - }); + let cleanup_task = teardown_postgres(pool, db_name).boxed(); (fairing, Some(cleanup_task)) } _ => unimplemented!(), @@ -166,7 +140,8 @@ async fn create_rocket( } #[test_case("cookie"; "Cookie")] -#[test_case("redis"; "Fred Redis")] +#[test_case("redis"; "Redis Fred")] +#[test_case("redis_indexed"; "Redis Fred Indexed")] #[test_case("sqlx"; "Sqlx Postgres")] #[rocket::async_test] async fn test_storages(storage_case: &str) { diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs new file mode 100644 index 0000000..3181827 --- /dev/null +++ b/tests/storages_indexed.rs @@ -0,0 +1,385 @@ +mod common; + +use std::{collections::HashMap, future::Future, pin::Pin}; + +use rocket::{futures::FutureExt, local::asynchronous::Client}; +use rocket_flex_session::{ + storage::{ + memory::MemoryStorageIndexed, + redis::{RedisFredStorage, RedisFredStorageIndexed, RedisType}, + sqlx::SqlxPostgresStorage, + SessionStorageIndexed, + }, + SessionIdentifier, +}; +use test_case::test_case; + +use crate::common::{ + setup_postgres, setup_redis_fred, teardown_postgres, teardown_redis_fred, POSTGRES_URL, +}; + +#[derive(Clone, Debug, PartialEq)] +struct TestSession { + user_id: String, + data: String, +} +impl SessionIdentifier for TestSession { + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) + } +} + +// Impls for Sqlx +impl TryFrom for String { + type Error = std::io::Error; + + fn try_from(value: TestSession) -> Result { + Ok(format!("{}:{}", value.user_id, value.data)) + } +} +impl TryFrom for TestSession { + type Error = std::io::Error; + + fn try_from(value: String) -> Result { + let (user_id, data) = value.split_once(':').ok_or(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid session format", + ))?; + Ok(TestSession { + user_id: user_id.to_string(), + data: data.to_string(), + }) + } +} + +// Impls for fred.rs Redis +const USER_ID_KEY: fred::prelude::Key = fred::types::Key::from_static_str("user_id"); +const DATA_KEY: fred::prelude::Key = fred::types::Key::from_static_str("data"); +impl fred::types::FromValue for TestSession { + fn from_value(value: fred::prelude::Value) -> Result { + let mut map = value.into_map()?; + Ok(Self { + user_id: map.remove(&USER_ID_KEY).unwrap().convert()?, + data: map.remove(&DATA_KEY).unwrap().convert()?, + }) + } +} +impl From for fred::types::Value { + fn from(value: TestSession) -> Self { + let hash: HashMap = + HashMap::from([(USER_ID_KEY, value.user_id), (DATA_KEY, value.data)]); + let fred_map = fred::types::Map::try_from(hash).unwrap(); + fred::types::Value::Map(fred_map) + } +} + +async fn create_storage( + storage_case: &str, +) -> ( + Box>, + Option>>>, +) { + match storage_case { + "memory" => { + let storage = MemoryStorageIndexed::::default(); + (Box::new(storage), None) + } + "redis" => { + let (pool, prefix) = setup_redis_fred().await; + let base_storage = RedisFredStorage::new(pool.clone(), RedisType::Hash, &prefix); + let storage = RedisFredStorageIndexed::new(base_storage, None); + let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); + (Box::new(storage), Some(cleanup_task)) + } + "sqlx" => { + let (pool, db_name) = setup_postgres(POSTGRES_URL).await; + let storage = SqlxPostgresStorage::new(pool.clone(), "sessions", None); + let cleanup_task = teardown_postgres(pool, db_name).boxed(); + (Box::new(storage), Some(cleanup_task)) + } + _ => unimplemented!(), + } +} + +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] +#[rocket::async_test] +async fn basic_operations(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + let session3 = TestSession { + user_id: "user2".to_string(), + data: "session3_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1.clone(), 3600).await.unwrap(); + storage.save("sid2", session2.clone(), 3600).await.unwrap(); + storage.save("sid3", session3.clone(), 3600).await.unwrap(); + + // Test get_sessions_by_identifier + let user1_sessions = storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(user1_sessions.len(), 2); + assert!(user1_sessions + .iter() + .any(|(id, data, ttl)| id == "sid1" && data == &session1 && *ttl <= 3600)); + assert!(user1_sessions + .iter() + .any(|(id, data, ttl)| id == "sid2" && data == &session2 && *ttl <= 3600)); + + let user2_sessions = storage + .get_sessions_by_identifier(&"user2".to_string()) + .await + .unwrap(); + assert_eq!(user2_sessions.len(), 1); + assert!(user2_sessions + .iter() + .any(|(id, data, ttl)| id == "sid3" && data == &session3 && *ttl <= 3600)); + + // Test get_session_ids_by_identifier + let user1_session_ids = storage + .get_session_ids_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(user1_session_ids.len(), 2); + assert!(user1_session_ids.contains(&"sid1".to_string())); + assert!(user1_session_ids.contains(&"sid2".to_string())); + + storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } +} + +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] +#[rocket::async_test] +async fn invalidate_by_identifier(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + let session3 = TestSession { + user_id: "user2".to_string(), + data: "session3_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1, 3600).await.unwrap(); + storage.save("sid2", session2, 3600).await.unwrap(); + storage.save("sid3", session3.clone(), 3600).await.unwrap(); + + // Verify sessions exist + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 2 + ); + + // Invalidate all sessions for user1 + assert_eq!( + storage + .invalidate_sessions_by_identifier(&"user1".to_string(), None) + .await + .unwrap(), + 2 + ); + + // Verify user1 sessions are gone + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 0 + ); + + // Verify user2 session still exists + let user2_sessions = storage + .get_sessions_by_identifier(&"user2".to_string()) + .await + .unwrap(); + assert_eq!(user2_sessions.len(), 1); + assert_eq!(user2_sessions[0].0, "sid3"); + assert_eq!(user2_sessions[0].1, session3); + + storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } +} + +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] +#[rocket::async_test] +async fn invalidate_all_but_one_by_identifier(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + let session3 = TestSession { + user_id: "user1".to_string(), + data: "session3_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1, 3600).await.unwrap(); + storage.save("sid2", session2, 3600).await.unwrap(); + storage.save("sid3", session3.clone(), 3600).await.unwrap(); + + // Verify sessions exist + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 3 + ); + + // Invalidate all sessions for user1 except the last one + assert_eq!( + storage + .invalidate_sessions_by_identifier(&"user1".to_string(), Some("sid3")) + .await + .unwrap(), + 2 + ); + + // Verify the last user1 session still exists + let user1_sessions = storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(user1_sessions.len(), 1); + assert_eq!(user1_sessions[0].0, "sid3"); + assert_eq!(user1_sessions[0].1, session3); + + storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } +} + +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] +#[rocket::async_test] +async fn delete_single_session(storage_case: &str) { + let client = Client::tracked(rocket::build()).await.unwrap(); + let (storage, cleanup_task) = create_storage(storage_case).await; + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1.clone(), 3600).await.unwrap(); + storage.save("sid2", session2.clone(), 3600).await.unwrap(); + + // Verify both sessions exist + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 2 + ); + + // Delete one session + storage.delete("sid1", &client.cookies()).await.unwrap(); + + // Verify only one session remains + let remaining_sessions = storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(remaining_sessions.len(), 1); + assert!(remaining_sessions + .iter() + .any(|(id, data, ttl)| id == "sid2" && data == &session2 && *ttl <= 3600)); + + storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } +} + +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] +#[rocket::async_test] +async fn nonexistent_identifier(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; + storage.setup().await.unwrap(); + + // Try to get sessions for non-existent identifier + let sessions = storage + .get_sessions_by_identifier(&"nonexistent".to_string()) + .await + .unwrap(); + assert_eq!(sessions.len(), 0); + + // Try to get session IDs for non-existent identifier + let session_ids = storage + .get_session_ids_by_identifier(&"nonexistent".to_string()) + .await + .unwrap(); + assert_eq!(session_ids.len(), 0); + + // Try to invalidate sessions for non-existent identifier (should not error) + assert_eq!( + storage + .invalidate_sessions_by_identifier(&"nonexistent".to_string(), None) + .await + .unwrap(), + 0 + ); + + storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } +}