From 6ea66e66ad36d0185dbbf592d85d3ac9c8a3d4c1 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Wed, 3 Sep 2025 17:05:45 -0400 Subject: [PATCH 01/28] fix: surface storage serialization error --- src/storage/interface.rs | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/src/storage/interface.rs b/src/storage/interface.rs index 9af0e23..b7247c6 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -17,7 +17,7 @@ pub enum SessionError { #[error("Session expired")] Expired, /// Error serializing or deserializing the session data - #[error("Failed to serialize/deserialize session")] + #[error("Failed to serialize/deserialize session: {0}")] Serialization(Box), /// An unexpected error from the storage backend #[error("Storage backend error: {0}")] @@ -25,25 +25,11 @@ pub enum SessionError { #[cfg(feature = "redis_fred")] #[error("fred.rs client error: {0}")] - RedisFredError(fred::error::Error), + RedisFredError(#[from] fred::error::Error), #[cfg(feature = "sqlx_postgres")] #[error("Sqlx error: {0}")] - SqlxError(sqlx::Error), -} - -#[cfg(feature = "redis_fred")] -impl From for SessionError { - fn from(value: fred::error::Error) -> Self { - SessionError::RedisFredError(value) - } -} - -#[cfg(feature = "sqlx_postgres")] -impl From for SessionError { - fn from(value: sqlx::Error) -> Self { - SessionError::SqlxError(value) - } + SqlxError(#[from] sqlx::Error), } pub type SessionResult = Result; From 4eb9ebc8535b4f2a37fa1c887e9fa22c03f51244 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 00:24:34 -0400 Subject: [PATCH 02/28] preliminary identifier implementation for memory storage --- src/fairing.rs | 2 +- src/storage/cookie.rs | 2 +- src/storage/interface.rs | 65 ++++++++++- src/storage/memory.rs | 162 +++++++++++++++++++++++++- src/storage/redis.rs | 2 +- src/storage/sqlx.rs | 2 +- tests/indexed_storage.rs | 241 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 468 insertions(+), 8 deletions(-) create mode 100644 tests/indexed_storage.rs diff --git a/src/fairing.rs b/src/fairing.rs index d7b6f90..9e4f8ac 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -41,7 +41,7 @@ where // Handle deleted session if let Some(deleted_id) = deleted { - let delete_result = self.storage.delete(&deleted_id).await; + let delete_result = self.storage.delete(&deleted_id, req.cookies()).await; if let Err(e) = delete_result { rocket::error!("Error while deleting session '{}': {}", deleted_id, e); } diff --git a/src/storage/cookie.rs b/src/storage/cookie.rs index 759537d..7a4969e 100644 --- a/src/storage/cookie.rs +++ b/src/storage/cookie.rs @@ -170,7 +170,7 @@ where Ok(()) // no-op (cookie session should already be saved by `save_cookie`) } - async fn delete(&self, _id: &str) -> SessionResult<()> { + async fn delete(&self, _id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { Ok(()) // no-op (cookie session should already be deleted by `save_cookie`) } } diff --git a/src/storage/interface.rs b/src/storage/interface.rs index b7247c6..72b80a7 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -55,7 +55,7 @@ where async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()>; /// Delete a session in storage. This will be performed at the end of the request lifecycle. - async fn delete(&self, id: &str) -> SessionResult<()>; + async fn delete(&self, id: &str, cookie_jar: &CookieJar) -> SessionResult<()>; /// Optional callback when there's a pending change to the session data. A `data` value /// of `None` indicates a deleted session. This callback can be used by cookie-based @@ -81,3 +81,66 @@ where Ok(()) // Default no-op } } + +/// Optional trait for session data types that can be grouped by an identifier. +/// This enables features like retrieving all sessions for a user or invalidating +/// all sessions when a user's password changes. +/// +/// The identifier should be stable for the lifetime of a session - it should not +/// change while the session is active. +/// +/// # Example +/// ```rust +/// use rocket_flex_session::storage::SessionIdentifier; +/// +/// #[derive(Clone)] +/// struct MySession { +/// user_id: String, +/// role: String, +/// } +/// +/// impl SessionIdentifier for MySession { +/// type Id = String; +/// +/// fn identifier(&self) -> Option { +/// Some(self.user_id.clone()) +/// } +/// } +/// ``` +pub trait SessionIdentifier { + /// The type of the identifier (e.g., user ID, account ID, etc.) + type Id: Send + Sync; + + /// Extract the identifier from the session data. + /// Returns `None` if the session doesn't have an identifier or + /// shouldn't be indexed. + fn identifier(&self) -> Option<&Self::Id>; +} + +/// Extended trait for storage backends that support session indexing by identifier. +/// This allows operations like finding all sessions for a user or bulk invalidation. +/// +/// Not all storage backends support this - for example, cookie-based storage +/// cannot implement this trait since cookies are only persisted on the client-side. +pub trait IndexedSessionStorage: SessionStorage +where + T: SessionIdentifier + Send + Sync, +{ + /// Retrieve all session data for the given identifier. + fn get_sessions_by_identifier( + &self, + id: &T::Id, + ) -> impl std::future::Future>>; + + /// Get all session IDs associated with the given identifier. + fn get_session_ids_by_identifier( + &self, + id: &T::Id, + ) -> impl std::future::Future>>; + + /// Remove all sessions associated with the given identifier. + fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + ) -> impl std::future::Future>; +} diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 73bc03a..d3fc39a 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -1,6 +1,7 @@ //! In-memory session storage implementation use std::{ + collections::{HashMap, HashSet}, sync::{Arc, Mutex}, time::Duration, }; @@ -12,14 +13,16 @@ use rocket::{ tokio::{select, spawn, sync::oneshot}, }; -use super::interface::{SessionError, SessionResult, SessionStorage}; +use super::interface::{ + IndexedSessionStorage, SessionError, SessionIdentifier, SessionResult, SessionStorage, +}; /// In-memory storage provider for sessions. This is designed mostly for local /// development, and not for production use. It uses the [retainer] crate to /// create an async cache. pub struct MemoryStorage { shutdown_tx: Mutex>>, - cache: Arc>, + pub(crate) cache: Arc>, } impl Default for MemoryStorage { @@ -67,7 +70,7 @@ where Ok(()) } - async fn delete(&self, id: &str) -> SessionResult<()> { + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { self.cache.remove(&id.to_owned()).await; Ok(()) } @@ -94,3 +97,156 @@ where Ok(()) } } + +impl MemoryStorage { + /// Get access to the underlying cache for indexed operations + pub(crate) fn cache(&self) -> &Cache { + &self.cache + } +} + +/// Extended in-memory storage that supports session indexing by identifier. +/// This allows for operations like retrieving all sessions for a user or +/// bulk invalidation of sessions. +pub struct IndexedMemoryStorage +where + T: SessionIdentifier, +{ + base_storage: MemoryStorage, + // Index from identifier to set of session IDs + identifier_index: Arc>>>, +} + +impl Default for IndexedMemoryStorage +where + T: SessionIdentifier, + T::Id: ToString, +{ + fn default() -> Self { + Self { + base_storage: MemoryStorage::default(), + identifier_index: Arc::default(), + } + } +} + +impl IndexedMemoryStorage +where + T: SessionIdentifier, + T::Id: ToString, +{ + /// Update the identifier index when session data is saved + fn update_identifier_index(&self, session_id: &str, data: &T) { + if let Some(id) = data.identifier() { + let mut index = self.identifier_index.lock().unwrap(); + index + .entry(id.to_string()) + .or_insert_with(HashSet::new) + .insert(session_id.to_owned()); + } + } + + /// Remove from identifier index when session is deleted + fn remove_from_identifier_index(&self, session_id: &str, data: &T) { + if let Some(id) = data.identifier() { + let mut index = self.identifier_index.lock().unwrap(); + let key = id.to_string(); + if let Some(session_ids) = index.get_mut(&key) { + session_ids.remove(session_id); + if session_ids.is_empty() { + index.remove(&key); + } + } + } + } +} + +#[async_trait] +impl SessionStorage for IndexedMemoryStorage +where + T: SessionIdentifier + Clone + Send + Sync + 'static, + T::Id: ToString, +{ + async fn load( + &self, + id: &str, + ttl: Option, + cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + self.base_storage.load(id, ttl, cookie_jar).await + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + // Update identifier index before saving + self.update_identifier_index(id, &data); + + // Save using base storage + self.base_storage.save(id, data, ttl).await + } + + async fn delete(&self, id: &str, cookie_jar: &CookieJar) -> SessionResult<()> { + // Get the data first so we can update the index + if let Ok((data, _)) = self.base_storage.load(id, None, cookie_jar).await { + self.remove_from_identifier_index(id, &data); + } + + // Delete using base storage + self.base_storage.delete(id, cookie_jar).await + } + + async fn setup(&self) -> SessionResult<()> { + self.base_storage.setup().await + } + + async fn shutdown(&self) -> SessionResult<()> { + self.base_storage.shutdown().await + } +} + +impl IndexedSessionStorage for IndexedMemoryStorage +where + Self: SessionStorage, + T: SessionIdentifier + Clone + Send + Sync, + T::Id: ToString, +{ + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let session_ids = { + let index = self.identifier_index.lock().unwrap(); + index.get(&id.to_string()).cloned().unwrap_or_default() + }; + + let mut sessions: Vec = Vec::new(); + for session_id in session_ids { + if let Some(data) = self.base_storage.cache().get(&session_id).await { + sessions.push(data.value().to_owned()); + } + } + + Ok(sessions) + } + + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let id_str = id.to_string(); + let session_ids = { + let index = self.identifier_index.lock().unwrap(); + index.get(&id_str).cloned().unwrap_or_default() + }; + + Ok(session_ids.into_iter().collect()) + } + + async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult<()> { + let id_str = id.to_string(); + let session_ids = { + let mut index = self.identifier_index.lock().unwrap(); + index.remove(&id_str).unwrap_or_default() + }; + + // Remove all sessions from cache + for session_id in session_ids { + self.base_storage.cache().remove(&session_id).await; + } + + Ok(()) + } +} diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 06a096c..2b75ef1 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -150,7 +150,7 @@ where Ok(()) } - async fn delete(&self, id: &str) -> SessionResult<()> { + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { let _: u8 = self.pool.del(self.key(id)).await?; Ok(()) } diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 1f68686..b9941dc 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -109,7 +109,7 @@ where Ok(()) } - async fn delete(&self, id: &str) -> SessionResult<()> { + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { sqlx::query(&format!("DELETE FROM {} WHERE id = $1", &self.table_name)) .bind(id) .execute(&self.pool) diff --git a/tests/indexed_storage.rs b/tests/indexed_storage.rs new file mode 100644 index 0000000..8de78a4 --- /dev/null +++ b/tests/indexed_storage.rs @@ -0,0 +1,241 @@ +use rocket::local::asynchronous::Client; +use rocket_flex_session::storage::{ + interface::{IndexedSessionStorage, SessionIdentifier, SessionStorage}, + memory::IndexedMemoryStorage, +}; + +#[derive(Clone, Debug, PartialEq)] +struct TestSession { + user_id: String, + data: String, +} + +impl SessionIdentifier for TestSession { + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) + } +} + +#[derive(Clone, Debug, PartialEq)] +struct SessionWithoutId { + data: String, +} + +impl SessionIdentifier for SessionWithoutId { + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + None // This session type doesn't have an identifier + } +} + +#[rocket::async_test] +async fn indexed_memory_storage_basic_operations() { + let storage = IndexedMemoryStorage::::default(); + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + let session3 = TestSession { + user_id: "user2".to_string(), + data: "session3_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1.clone(), 3600).await.unwrap(); + storage.save("sid2", session2.clone(), 3600).await.unwrap(); + storage.save("sid3", session3.clone(), 3600).await.unwrap(); + + // Test get_sessions_by_identifier + let user1_sessions = storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(user1_sessions.len(), 2); + assert!(user1_sessions.contains(&session1)); + assert!(user1_sessions.contains(&session2)); + + let user2_sessions = storage + .get_sessions_by_identifier(&"user2".to_string()) + .await + .unwrap(); + assert_eq!(user2_sessions.len(), 1); + assert!(user2_sessions.contains(&session3)); + + // Test get_session_ids_by_identifier + let user1_session_ids = storage + .get_session_ids_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(user1_session_ids.len(), 2); + assert!(user1_session_ids.contains(&"sid1".to_string())); + assert!(user1_session_ids.contains(&"sid2".to_string())); + + storage.shutdown().await.unwrap(); +} + +#[rocket::async_test] +async fn indexed_memory_storage_invalidate_by_identifier() { + let storage = IndexedMemoryStorage::::default(); + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + let session3 = TestSession { + user_id: "user2".to_string(), + data: "session3_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1, 3600).await.unwrap(); + storage.save("sid2", session2, 3600).await.unwrap(); + storage.save("sid3", session3.clone(), 3600).await.unwrap(); + + // Verify sessions exist + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 2 + ); + + // Invalidate all sessions for user1 + storage + .invalidate_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + + // Verify user1 sessions are gone + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 0 + ); + + // Verify user2 session still exists + let user2_sessions = storage + .get_sessions_by_identifier(&"user2".to_string()) + .await + .unwrap(); + assert_eq!(user2_sessions.len(), 1); + assert!(user2_sessions.contains(&session3)); + + storage.shutdown().await.unwrap(); +} + +#[rocket::async_test] +async fn indexed_memory_storage_delete_single_session() { + let client = Client::tracked(rocket::build()).await.unwrap(); + let storage = IndexedMemoryStorage::::default(); + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1.clone(), 3600).await.unwrap(); + storage.save("sid2", session2.clone(), 3600).await.unwrap(); + + // Verify both sessions exist + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 2 + ); + + // Delete one session + storage.delete("sid1", &client.cookies()).await.unwrap(); + + // Verify only one session remains + let remaining_sessions = storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(remaining_sessions.len(), 1); + assert!(remaining_sessions.contains(&session2)); + + storage.shutdown().await.unwrap(); +} + +#[rocket::async_test] +async fn indexed_memory_storage_session_without_identifier() { + let client = Client::tracked(rocket::build()).await.unwrap(); + let storage = IndexedMemoryStorage::::default(); + storage.setup().await.unwrap(); + + let session = SessionWithoutId { + data: "test_data".to_string(), + }; + + // Save session (should not be indexed) + storage.save("sid1", session.clone(), 3600).await.unwrap(); + + // Try to get sessions by identifier (should return empty) + let sessions = storage + .get_sessions_by_identifier(&"any_id".to_string()) + .await + .unwrap(); + assert_eq!(sessions.len(), 0); + + // Regular session operations should still work + let (loaded_session, _ttl) = storage.load("sid1", None, &client.cookies()).await.unwrap(); + assert_eq!(loaded_session, session); + + storage.shutdown().await.unwrap(); +} + +#[rocket::async_test] +async fn indexed_memory_storage_nonexistent_identifier() { + let storage = IndexedMemoryStorage::::default(); + storage.setup().await.unwrap(); + + // Try to get sessions for non-existent identifier + let sessions = storage + .get_sessions_by_identifier(&"nonexistent".to_string()) + .await + .unwrap(); + assert_eq!(sessions.len(), 0); + + // Try to get session IDs for non-existent identifier + let session_ids = storage + .get_session_ids_by_identifier(&"nonexistent".to_string()) + .await + .unwrap(); + assert_eq!(session_ids.len(), 0); + + // Try to invalidate sessions for non-existent identifier (should not error) + storage + .invalidate_sessions_by_identifier(&"nonexistent".to_string()) + .await + .unwrap(); + + storage.shutdown().await.unwrap(); +} From 846e5f04e133dbc47ae466267d59c8c47c089b39 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 05:10:29 -0400 Subject: [PATCH 03/28] refactor and more tests --- src/error.rs | 37 +++++ src/guard.rs | 14 +- src/lib.rs | 5 +- src/session.rs | 48 ++++--- src/session_index.rs | 125 +++++++++++++++++ src/session_inner.rs | 12 ++ src/storage.rs | 4 +- src/storage/cookie.rs | 4 +- src/storage/interface.rs | 95 ++----------- src/storage/memory.rs | 22 ++- src/storage/redis.rs | 4 +- src/storage/sqlx.rs | 4 +- tests/indexed_session.rs | 286 +++++++++++++++++++++++++++++++++++++++ tests/indexed_storage.rs | 26 ++-- tests/storages.rs | 2 +- 15 files changed, 555 insertions(+), 133 deletions(-) create mode 100644 src/error.rs create mode 100644 src/session_index.rs create mode 100644 tests/indexed_session.rs diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..3160247 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,37 @@ +//! Error types + +/// Result type for session operations +pub type SessionResult = Result; + +/// Errors that can happen during session retrieval/handling +#[derive(Debug, thiserror::Error)] +pub enum SessionError { + /// There was no session cookie, or decryption of the cookie failed + #[error("No session cookie")] + NoSessionCookie, + /// Session wasn't found in storage + #[error("Session not found")] + NotFound, + /// Session was found but it was expired + #[error("Session expired")] + Expired, + /// Error serializing or deserializing the session data + #[error("Failed to serialize/deserialize session: {0}")] + Serialization(Box), + /// An indexing operation failed because the storage provider doesn't + /// implement [SessionStorageIndexed](crate::storage::SessionStorageIndexed) + #[error("Storage doesn't support indexing")] + NonIndexedStorage, + /// A generic error from the storage backend. This error type can be + /// used when implementing a custom session storage. + #[error("Storage backend error: {0}")] + Backend(Box), + + #[cfg(feature = "redis_fred")] + #[error("fred.rs client error: {0}")] + RedisFredError(#[from] fred::error::Error), + + #[cfg(feature = "sqlx_postgres")] + #[error("Sqlx error: {0}")] + SqlxError(#[from] sqlx::Error), +} diff --git a/src/guard.rs b/src/guard.rs index 7792aba..e84ab7c 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -10,10 +10,8 @@ use rocket::{ }; use crate::{ - session::Session, - session_inner::SessionInner, - storage::interface::{SessionError, SessionStorage}, - RocketFlexSession, + error::SessionError, session_inner::SessionInner, storage::SessionStorage, RocketFlexSession, + Session, }; /// Type of the cached inner session data in Rocket's request local cache @@ -48,7 +46,7 @@ where .await; Outcome::Success(Session::new( - cached_inner.clone(), + cached_inner.as_ref(), session_error.as_ref(), cookie_jar, &fairing.options, @@ -85,10 +83,8 @@ async fn get_session_data<'r, T: Send + Sync + Clone>( match storage.load(id, rolling_ttl, cookie_jar).await { Ok((data, ttl)) => { rocket::debug!("Session found. Creating existing session..."); - ( - Arc::new(Mutex::new(SessionInner::new_existing(id, data, ttl))), - None, - ) + let session_inner = SessionInner::new_existing(id, data, ttl); + (Arc::new(Mutex::new(session_inner)), None) } Err(e) => { rocket::debug!("Error from session storage, creating empty session: {}", e); diff --git a/src/lib.rs b/src/lib.rs index ad99697..66989b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,13 +137,16 @@ mod fairing; mod guard; mod options; mod session; +mod session_index; mod session_inner; +pub mod error; pub mod storage; pub use options::SessionOptions; pub use session::Session; +pub use session_index::SessionIdentifier; -use crate::storage::{interface::SessionStorage, memory::MemoryStorage}; +use crate::storage::{memory::MemoryStorage, SessionStorage}; use std::sync::Arc; /** diff --git a/src/session.rs b/src/session.rs index 8b0883d..978f094 100644 --- a/src/session.rs +++ b/src/session.rs @@ -7,13 +7,12 @@ use std::{ fmt::Display, hash::Hash, marker::{Send, Sync}, - sync::{Arc, Mutex, MutexGuard}, + sync::{Mutex, MutexGuard}, }; use crate::{ - options::SessionOptions, - session_inner::SessionInner, - storage::interface::{SessionError, SessionStorage}, + error::SessionError, options::SessionOptions, session_inner::SessionInner, + storage::SessionStorage, }; /** @@ -47,10 +46,10 @@ fn profile(session: Session) -> String { */ pub struct Session<'a, T> where - T: Send + Sync, + T: Send + Sync + Clone, { /// Internal mutable state of the session - inner: Arc>>, + inner: &'a Mutex>, /// Error (if any) when retrieving from storage error: Option<&'a SessionError>, /// Rocket's cookie jar for managing cookies @@ -58,7 +57,7 @@ where /// User's session options options: &'a SessionOptions, /// Configured storage provider for sessions - storage: &'a dyn SessionStorage, + pub(crate) storage: &'a dyn SessionStorage, } impl Display for Session<'_, T> @@ -66,7 +65,7 @@ where T: Send + Sync + Clone, { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "Session(id: {:?})", self.get_inner().get_id()) + write!(f, "Session(id: {:?})", self.get_inner_lock().get_id()) } } @@ -76,7 +75,7 @@ where { /// Create a new session instance to keep track of the session state in a request pub(crate) fn new( - inner: Arc>>, + inner: &'a Mutex>, error: Option<&'a SessionError>, cookie_jar: &'a CookieJar<'a>, options: &'a SessionOptions, @@ -93,12 +92,14 @@ where /// Get the session ID (alphanumeric string). Will be `None` if there's no active session. pub fn id(&self) -> Option { - self.get_inner().get_id().map(|s| s.to_owned()) + self.get_inner_lock().get_id().map(|s| s.to_owned()) } /// Get the current session data via cloning. Will be `None` if there's no active session. pub fn get(&self) -> Option { - self.get_inner().get_current_data().map(|d| d.to_owned()) + self.get_inner_lock() + .get_current_data() + .map(|d| d.to_owned()) } /// Get a reference to the current session data via a closure. @@ -107,7 +108,7 @@ where where F: FnOnce(Option<&T>) -> R, { - f(self.get_inner().get_current_data()) + f(self.get_inner_lock().get_current_data()) } /// Get a mutable reference to the current session data via a closure. @@ -116,7 +117,9 @@ where where UpdateFn: FnOnce(&mut Option) -> R, { - let (response, is_deleted) = self.get_inner().tap_data_mut(f, self.get_default_ttl()); + let (response, is_deleted) = self + .get_inner_lock() + .tap_data_mut(f, self.get_default_ttl()); if is_deleted { self.delete(); } else { @@ -128,20 +131,21 @@ where /// Set/update the session data. Will create a new active session if needed. pub fn set(&mut self, new_data: T) { - self.get_inner().set_data(new_data, self.get_default_ttl()); + self.get_inner_lock() + .set_data(new_data, self.get_default_ttl()); self.update_cookies(); } /// Set the TTL of the session in seconds. This can be used to extend the length /// of the session if needed. This has no effect if there is no active session. pub fn set_ttl(&mut self, new_ttl: u32) { - self.get_inner().set_ttl(new_ttl); + self.get_inner_lock().set_ttl(new_ttl); self.update_cookies(); } /// Get the session TTL in seconds. pub fn ttl(&self) -> u32 { - self.get_inner() + self.get_inner_lock() .get_current_ttl() .unwrap_or(self.get_default_ttl()) } @@ -154,7 +158,7 @@ where /// Delete the session. pub fn delete(&mut self) { // Delete inner session data - let mut inner = self.get_inner(); + let mut inner = self.get_inner_lock(); inner.delete(); // Remove the session cookie @@ -183,7 +187,7 @@ where self.error } - fn get_inner(&self) -> MutexGuard<'_, SessionInner> { + pub(crate) fn get_inner_lock(&self) -> MutexGuard<'_, SessionInner> { self.inner.lock().expect("Failed to get session data lock") } @@ -192,7 +196,7 @@ where } fn update_cookies(&self) { - let inner = self.get_inner(); + let inner = self.get_inner_lock(); let Some(id) = inner.get_id() else { rocket::info!("Cookies not updated: no active session"); return; @@ -226,7 +230,7 @@ where Q: ?Sized + Eq + Hash, K: std::borrow::Borrow, { - self.get_inner() + self.get_inner_lock() .get_current_data() .and_then(|h| h.get(key).cloned()) } @@ -235,7 +239,7 @@ where /// a new session if needed. pub fn set_key(&mut self, key: K, value: V) { let mut new_data = self - .get_inner() + .get_inner_lock() .get_current_data() .cloned() .unwrap_or_default(); @@ -250,7 +254,7 @@ where I: IntoIterator, { let mut new_data = self - .get_inner() + .get_inner_lock() .get_current_data() .cloned() .unwrap_or_default(); diff --git a/src/session_index.rs b/src/session_index.rs new file mode 100644 index 0000000..b78af34 --- /dev/null +++ b/src/session_index.rs @@ -0,0 +1,125 @@ +use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; + +/// Optional trait for session data types that can be grouped by an identifier. +/// This enables features like retrieving all sessions for a user or invalidating +/// all sessions when a user's password changes. +/// +/// The storage provider must support indexing sessions (check the docs for the +/// provider you're using). +/// +/// # Example +/// ```rust +/// use rocket_flex_session::storage::SessionIdentifier; +/// +/// #[derive(Clone)] +/// struct MySession { +/// user_id: String, +/// role: String, +/// } +/// +/// impl SessionIdentifier for MySession { +/// type Id = String; +/// +/// fn identifier(&self) -> Option<&Self::Id> { +/// Some(&self.user_id) +/// } +/// } +/// ``` +pub trait SessionIdentifier { + /// The type of the identifier (e.g., user ID, account ID, etc.) + type Id: Send + Sync + Clone; + + /// Extract the identifier from the session data. + /// Returns `None` if the session doesn't have an identifier and/or + /// shouldn't be indexed. + fn identifier(&self) -> Option<&Self::Id>; +} + +/// Session implementation block for indexing operations +impl<'a, T> Session<'a, T> +where + T: SessionIdentifier + Send + Sync + Clone, +{ + /// Get all session IDs and data for the same identifier as the current session. + /// Returns `None` if there's no session or the session isn't indexed. + pub async fn get_all_sessions(&self) -> Result>, SessionError> { + let Some(identifier) = self.get_identifier() else { + return Ok(None); + }; + let storage = self.get_indexed_storage()?; + let sessions = storage.get_sessions_by_identifier(&identifier).await?; + + Ok(Some(sessions)) + } + + /// Get all session IDs for the same identifier as the current session. + /// Returns `None` if there's no session or the session isn't indexed. + pub async fn get_all_session_ids(&self) -> Result>, SessionError> { + let Some(identifier) = self.get_identifier() else { + return Ok(None); + }; + let storage = self.get_indexed_storage()?; + let session_ids = storage.get_session_ids_by_identifier(&identifier).await?; + + Ok(Some(session_ids)) + } + + /// Invalidate all sessions with the same identifier as the current session. + /// Returns `None` if there's no session or the session isn't indexed. + pub async fn invalidate_all_sessions(&self) -> Result, SessionError> { + let Some(identifier) = self.get_identifier() else { + return Ok(None); + }; + let storage = self.get_indexed_storage()?; + storage + .invalidate_sessions_by_identifier(&identifier) + .await?; + + Ok(Some(())) + } + + /// Get all session IDs and data for a specific identifier. + pub async fn get_sessions_by_identifier( + &self, + identifier: &T::Id, + ) -> Result, SessionError> { + let storage = self.get_indexed_storage()?; + storage.get_sessions_by_identifier(identifier).await + } + + /// Get all session IDs for a specific identifier. + pub async fn get_session_ids_by_identifier( + &self, + identifier: &T::Id, + ) -> Result, SessionError> { + let storage = self.get_indexed_storage()?; + storage.get_session_ids_by_identifier(identifier).await + } + + /// Invalidate all sessions for a specific identifier. + pub async fn invalidate_sessions_by_identifier( + &self, + identifier: &T::Id, + ) -> Result<(), SessionError> { + let storage = self.get_indexed_storage()?; + storage.invalidate_sessions_by_identifier(identifier).await + } + + /// Get the current session's identifier + fn get_identifier(&self) -> Option { + let identifier = { + let inner = self.get_inner_lock(); + inner.get_current_identifier() + }; + identifier + } + + /// Try to cast the storage as an indexed storage + fn get_indexed_storage(&self) -> Result<&dyn SessionStorageIndexed, SessionError> { + let indexed_storage = self + .storage + .as_indexed_storage() + .ok_or(SessionError::NonIndexedStorage)?; + Ok(indexed_storage) + } +} diff --git a/src/session_inner.rs b/src/session_inner.rs index f148191..d092b86 100644 --- a/src/session_inner.rs +++ b/src/session_inner.rs @@ -1,5 +1,7 @@ use rand::distributions::{Alphanumeric, DistString}; +use crate::SessionIdentifier; + /// Represents a current, active session struct ActiveSession { /// Session ID (20-character alphanumeric string) @@ -151,3 +153,13 @@ where ) } } + +impl SessionInner +where + T: SessionIdentifier + Clone, +{ + pub(crate) fn get_current_identifier(&self) -> Option { + self.get_current_data() + .and_then(|data| data.identifier().cloned()) + } +} diff --git a/src/storage.rs b/src/storage.rs index c0bc63e..afd62d5 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,6 +1,8 @@ //! Storage implementations for sessions -pub mod interface; +mod interface; +pub use interface::*; + pub mod memory; #[cfg(feature = "cookie")] diff --git a/src/storage/cookie.rs b/src/storage/cookie.rs index 7a4969e..90fed2f 100644 --- a/src/storage/cookie.rs +++ b/src/storage/cookie.rs @@ -7,7 +7,9 @@ use rocket::{ time::{Duration, OffsetDateTime}, }; -use super::interface::{SessionError, SessionResult, SessionStorage}; +use crate::error::{SessionError, SessionResult}; + +use super::interface::SessionStorage; /** Storage provider for sessions backed by cookies. All session data is serialized to JSON diff --git a/src/storage/interface.rs b/src/storage/interface.rs index 72b80a7..2b43557 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -1,38 +1,8 @@ //! Shared interface for session storage -use std::fmt::Debug; - use rocket::{async_trait, http::CookieJar}; -/// Errors that can happen during session retrieval/handling -#[derive(Debug, thiserror::Error)] -pub enum SessionError { - /// There was no session cookie, or decryption of the cookie failed - #[error("No session cookie")] - NoSessionCookie, - /// Session wasn't found in storage - #[error("Session not found")] - NotFound, - /// Session was found but it was expired - #[error("Session expired")] - Expired, - /// Error serializing or deserializing the session data - #[error("Failed to serialize/deserialize session: {0}")] - Serialization(Box), - /// An unexpected error from the storage backend - #[error("Storage backend error: {0}")] - Backend(Box), - - #[cfg(feature = "redis_fred")] - #[error("fred.rs client error: {0}")] - RedisFredError(#[from] fred::error::Error), - - #[cfg(feature = "sqlx_postgres")] - #[error("Sqlx error: {0}")] - SqlxError(#[from] sqlx::Error), -} - -pub type SessionResult = Result; +use crate::{error::SessionResult, SessionIdentifier}; /// Trait representing a session backend storage. You can use your own session storage /// by implementing this trait. @@ -43,7 +13,7 @@ where { /// Load session data and TTL (time-to-live in seconds) from storage. If a TTL value is provided, /// it should be set upon retreiving the session. If session is already expired - /// or otherwise invalid, a [SessionError] should be returned instead. + /// or otherwise invalid, a [`SessionError`](crate::error::SessionError) should be returned instead. async fn load( &self, id: &str, @@ -71,6 +41,12 @@ where Ok(()) // Default no-op } + /// Storages that support indexing (by implementing [`SessionStorageIndexed`]) must + /// also implement this. Implementation should be trivial: `Some(self)` + fn as_indexed_storage(&self) -> Option<&dyn SessionStorageIndexed> { + None // Default not supported + } + /// Optional setup of resources that will be called on server startup async fn setup(&self) -> SessionResult<()> { Ok(()) // Default no-op @@ -82,65 +58,22 @@ where } } -/// Optional trait for session data types that can be grouped by an identifier. -/// This enables features like retrieving all sessions for a user or invalidating -/// all sessions when a user's password changes. -/// -/// The identifier should be stable for the lifetime of a session - it should not -/// change while the session is active. -/// -/// # Example -/// ```rust -/// use rocket_flex_session::storage::SessionIdentifier; -/// -/// #[derive(Clone)] -/// struct MySession { -/// user_id: String, -/// role: String, -/// } -/// -/// impl SessionIdentifier for MySession { -/// type Id = String; -/// -/// fn identifier(&self) -> Option { -/// Some(self.user_id.clone()) -/// } -/// } -/// ``` -pub trait SessionIdentifier { - /// The type of the identifier (e.g., user ID, account ID, etc.) - type Id: Send + Sync; - - /// Extract the identifier from the session data. - /// Returns `None` if the session doesn't have an identifier or - /// shouldn't be indexed. - fn identifier(&self) -> Option<&Self::Id>; -} - /// Extended trait for storage backends that support session indexing by identifier. /// This allows operations like finding all sessions for a user or bulk invalidation. /// -/// Not all storage backends support this - for example, cookie-based storage +/// Not all storage backends can support this - for example, cookie-based storage /// cannot implement this trait since cookies are only persisted on the client-side. -pub trait IndexedSessionStorage: SessionStorage +#[async_trait] +pub trait SessionStorageIndexed: SessionStorage where T: SessionIdentifier + Send + Sync, { /// Retrieve all session data for the given identifier. - fn get_sessions_by_identifier( - &self, - id: &T::Id, - ) -> impl std::future::Future>>; + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult>; /// Get all session IDs associated with the given identifier. - fn get_session_ids_by_identifier( - &self, - id: &T::Id, - ) -> impl std::future::Future>>; + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult>; /// Remove all sessions associated with the given identifier. - fn invalidate_sessions_by_identifier( - &self, - id: &T::Id, - ) -> impl std::future::Future>; + async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult<()>; } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index d3fc39a..70156f6 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -13,10 +13,13 @@ use rocket::{ tokio::{select, spawn, sync::oneshot}, }; -use super::interface::{ - IndexedSessionStorage, SessionError, SessionIdentifier, SessionResult, SessionStorage, +use crate::{ + error::{SessionError, SessionResult}, + SessionIdentifier, }; +use super::interface::{SessionStorage, SessionStorageIndexed}; + /// In-memory storage provider for sessions. This is designed mostly for local /// development, and not for production use. It uses the [retainer] crate to /// create an async cache. @@ -105,7 +108,7 @@ impl MemoryStorage { } } -/// Extended in-memory storage that supports session indexing by identifier. +/// In-memory storage that supports session indexing by identifier. /// This allows for operations like retrieving all sessions for a user or /// bulk invalidation of sessions. pub struct IndexedMemoryStorage @@ -194,6 +197,10 @@ where self.base_storage.delete(id, cookie_jar).await } + fn as_indexed_storage(&self) -> Option<&dyn SessionStorageIndexed> { + Some(self) + } + async fn setup(&self) -> SessionResult<()> { self.base_storage.setup().await } @@ -203,22 +210,23 @@ where } } -impl IndexedSessionStorage for IndexedMemoryStorage +#[async_trait] +impl SessionStorageIndexed for IndexedMemoryStorage where Self: SessionStorage, T: SessionIdentifier + Clone + Send + Sync, T::Id: ToString, { - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { let session_ids = { let index = self.identifier_index.lock().unwrap(); index.get(&id.to_string()).cloned().unwrap_or_default() }; - let mut sessions: Vec = Vec::new(); + let mut sessions: Vec<(String, T)> = Vec::new(); for session_id in session_ids { if let Some(data) = self.base_storage.cache().get(&session_id).await { - sessions.push(data.value().to_owned()); + sessions.push((session_id, data.value().to_owned())); } } diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 2b75ef1..a99b612 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -6,7 +6,9 @@ use fred::{ }; use rocket::{async_trait, http::CookieJar}; -use super::interface::{SessionError, SessionResult, SessionStorage}; +use crate::error::{SessionError, SessionResult}; + +use super::interface::SessionStorage; #[derive(Debug)] pub enum RedisType { diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index b9941dc..1dca8a2 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -4,7 +4,9 @@ use rocket::{async_trait, http::CookieJar}; use sqlx::{PgPool, Row}; use time::{Duration, OffsetDateTime}; -use super::interface::{SessionError, SessionResult, SessionStorage}; +use crate::error::{SessionError, SessionResult}; + +use super::interface::SessionStorage; /** Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx). Stores the session data as a string, so you'll need diff --git a/tests/indexed_session.rs b/tests/indexed_session.rs new file mode 100644 index 0000000..1533c77 --- /dev/null +++ b/tests/indexed_session.rs @@ -0,0 +1,286 @@ +use rocket::{ + get, launch, routes, + serde::{Deserialize, Serialize}, +}; +use rocket_flex_session::{ + storage::memory::IndexedMemoryStorage, RocketFlexSession, Session, SessionIdentifier, +}; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +struct UserSession { + user_id: String, + username: String, + login_time: u64, +} + +impl SessionIdentifier for UserSession { + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) + } +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +struct AdminSession { + admin_id: String, + role: String, + permissions: Vec, +} + +impl SessionIdentifier for AdminSession { + type Id = String; + + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.admin_id) + } +} + +// Routes for testing user sessions +#[get("/user/login//")] +async fn user_login( + mut session: Session<'_, UserSession>, + user_id: String, + username: String, +) -> String { + let user_session = UserSession { + user_id: user_id.clone(), + username: username.clone(), + login_time: 1234567890, + }; + + session.set(user_session); + format!("User {} logged in", username) +} + +#[get("/user/sessions")] +async fn get_user_sessions(session: Session<'_, UserSession>) -> String { + match session.get_all_sessions().await { + Ok(Some(sessions)) => { + format!("Found {} sessions for current user", sessions.len()) + } + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error getting sessions: {}", e), + } +} + +#[get("/user/sessions/")] +async fn get_sessions_for_user(session: Session<'_, UserSession>, user_id: String) -> String { + match session.get_sessions_by_identifier(&user_id).await { + Ok(sessions) => { + format!("Sessions for user {}: {:?}", user_id, sessions) + } + Err(e) => format!("Error getting sessions: {}", e), + } +} + +#[get("/user/invalidate-all")] +async fn invalidate_all_user_sessions(session: Session<'_, UserSession>) -> String { + match session.invalidate_all_sessions().await { + Ok(Some(())) => "All sessions for current user invalidated".to_string(), + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error invalidating sessions: {}", e), + } +} + +#[get("/user/invalidate-all/")] +async fn invalidate_sessions_for_user( + session: Session<'_, UserSession>, + user_id: String, +) -> String { + match session.invalidate_sessions_by_identifier(&user_id).await { + Ok(()) => format!("All sessions for user {} invalidated", user_id), + Err(e) => format!("Error invalidating sessions: {}", e), + } +} + +#[get("/user/session-ids")] +async fn get_user_session_ids(session: Session<'_, UserSession>) -> String { + match session.get_all_session_ids().await { + Ok(Some(session_ids)) => { + format!("Session IDs for current user: {:?}", session_ids) + } + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error getting session IDs: {}", e), + } +} + +#[get("/user/profile")] +async fn user_profile(session: Session<'_, UserSession>) -> String { + match session.get() { + Some(user_session) => { + format!( + "Profile for {}: logged in at {}", + user_session.username, user_session.login_time + ) + } + None => "No active session".to_string(), + } +} + +#[launch] +fn rocket() -> _ { + let user_storage = IndexedMemoryStorage::::default(); + + rocket::build() + .attach( + RocketFlexSession::::builder() + .storage(user_storage) + .build(), + ) + .mount( + "/", + routes![ + user_login, + get_user_sessions, + get_sessions_for_user, + invalidate_all_user_sessions, + invalidate_sessions_for_user, + get_user_session_ids, + user_profile, + ], + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use rocket::http::Status; + use rocket::local::blocking::Client; + + fn create_test_client() -> Client { + Client::tracked(rocket()).expect("valid rocket instance") + } + + #[test] + fn user_login_and_profile() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.into_string().unwrap(), "User alice logged in"); + + // Check profile + let response = client.get("/user/profile").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("Profile for alice")); + } + + #[test] + fn multiple_sessions_same_user() { + let client = create_test_client(); + + // First session for user1 + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Check that we can see current user's sessions + let response = client.get("/user/sessions").dispatch(); + assert_eq!(response.status(), Status::Ok); + // Note: This might show 0 or 1 sessions depending on whether the current + // session cookie is being tracked properly in the test + } + + #[test] + fn get_sessions_by_user_id() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Get sessions for specific user ID + let response = client.get("/user/sessions/user1").dispatch(); + assert_eq!(response.status(), Status::Ok); + let body = response.into_string().unwrap(); + println!("{body}"); + assert!(body.contains("Sessions for user user1")); + } + + #[test] + fn test_session_ids_retrieval() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Get session IDs + let response = client.get("/user/session-ids").dispatch(); + assert_eq!(response.status(), Status::Ok); + let body = response.into_string().unwrap(); + assert!(body.contains("Session IDs for current user")); + } + + #[test] + fn test_invalidate_sessions() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Invalidate all sessions for current user + let response = client.get("/user/invalidate-all").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("All sessions for current user invalidated")); + + // Profile should now show no session + let response = client.get("/user/profile").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.into_string().unwrap(), "No active session"); + } + + #[test] + fn test_invalidate_sessions_by_user_id() { + let client = create_test_client(); + + // Login user + let response = client.get("/user/login/user2/bob").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Invalidate sessions for specific user + let response = client.get("/user/invalidate-all/user2").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("All sessions for user user2 invalidated")); + } + + #[test] + fn test_no_session_scenarios() { + let client = create_test_client(); + + // Try to get sessions without being logged in + let response = client.get("/user/sessions").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("No current session")); + + // Try to get session IDs without being logged in + let response = client.get("/user/session-ids").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("No current session")); + + // Try to invalidate sessions without being logged in + let response = client.get("/user/invalidate-all").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("No current session")); + } +} diff --git a/tests/indexed_storage.rs b/tests/indexed_storage.rs index 8de78a4..441e2fd 100644 --- a/tests/indexed_storage.rs +++ b/tests/indexed_storage.rs @@ -1,7 +1,7 @@ use rocket::local::asynchronous::Client; -use rocket_flex_session::storage::{ - interface::{IndexedSessionStorage, SessionIdentifier, SessionStorage}, - memory::IndexedMemoryStorage, +use rocket_flex_session::{ + storage::{memory::IndexedMemoryStorage, SessionStorage, SessionStorageIndexed}, + SessionIdentifier, }; #[derive(Clone, Debug, PartialEq)] @@ -60,15 +60,21 @@ async fn indexed_memory_storage_basic_operations() { .await .unwrap(); assert_eq!(user1_sessions.len(), 2); - assert!(user1_sessions.contains(&session1)); - assert!(user1_sessions.contains(&session2)); + assert!(user1_sessions + .iter() + .any(|(id, data)| id == "sid1" && data == &session1)); + assert!(user1_sessions + .iter() + .any(|(id, data)| id == "sid2" && data == &session2)); let user2_sessions = storage .get_sessions_by_identifier(&"user2".to_string()) .await .unwrap(); assert_eq!(user2_sessions.len(), 1); - assert!(user2_sessions.contains(&session3)); + assert!(user2_sessions + .iter() + .any(|(id, data)| id == "sid3" && data == &session3)); // Test get_session_ids_by_identifier let user1_session_ids = storage @@ -137,7 +143,9 @@ async fn indexed_memory_storage_invalidate_by_identifier() { .await .unwrap(); assert_eq!(user2_sessions.len(), 1); - assert!(user2_sessions.contains(&session3)); + assert!(user2_sessions + .iter() + .any(|(id, data)| id == "sid3" && data == &session3)); storage.shutdown().await.unwrap(); } @@ -180,7 +188,9 @@ async fn indexed_memory_storage_delete_single_session() { .await .unwrap(); assert_eq!(remaining_sessions.len(), 1); - assert!(remaining_sessions.contains(&session2)); + assert!(remaining_sessions + .iter() + .any(|(id, data)| id == "sid2" && data == &session2)); storage.shutdown().await.unwrap(); } diff --git a/tests/storages.rs b/tests/storages.rs index 98c37f7..ba6494a 100644 --- a/tests/storages.rs +++ b/tests/storages.rs @@ -6,9 +6,9 @@ use std::{future::Future, pin::Pin}; use fred::prelude::{ClientLike, ReconnectPolicy}; use rocket::{http::Status, local::asynchronous::Client, tokio::time::sleep, Build, Rocket}; use rocket_flex_session::{ + error::SessionError, storage::{ cookie::CookieStorage, - interface::SessionError, redis::{RedisFredStorage, RedisType}, sqlx::SqlxPostgresStorage, }, From 912a85e88cb5090ed86db85b6c1f4449e71907ad Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 05:41:09 -0400 Subject: [PATCH 04/28] perf: remove unneeded Arc --- src/fairing.rs | 5 +++-- src/guard.rs | 15 ++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/fairing.rs b/src/fairing.rs index 9e4f8ac..55bdec1 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -1,6 +1,6 @@ use std::{ marker::{Send, Sync}, - sync::Arc, + sync::Mutex, }; use rocket::{fairing::Fairing, Build, Orbit, Request, Response, Rocket}; @@ -34,7 +34,8 @@ where async fn on_response<'r>(&self, req: &'r Request<'_>, _res: &mut Response<'r>) { // Get session data from request local cache, or generate a default empty one - let (session_inner, _): &LocalCachedSession = req.local_cache(|| (Arc::default(), None)); + let (session_inner, _): &LocalCachedSession = + req.local_cache(|| (Mutex::default(), None)); // Take inner session data let (updated, deleted) = session_inner.lock().unwrap().take_for_storage(); diff --git a/src/guard.rs b/src/guard.rs index e84ab7c..5a7fe51 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,7 +1,4 @@ -use std::{ - any::type_name, - sync::{Arc, Mutex}, -}; +use std::{any::type_name, sync::Mutex}; use rocket::{ http::{Cookie, CookieJar}, @@ -15,7 +12,7 @@ use crate::{ }; /// Type of the cached inner session data in Rocket's request local cache -pub(crate) type LocalCachedSession = (Arc>>, Option); +pub(crate) type LocalCachedSession = (Mutex>, Option); #[rocket::async_trait] impl<'r, T> FromRequest<'r> for Session<'r, T> @@ -46,7 +43,7 @@ where .await; Outcome::Success(Session::new( - cached_inner.as_ref(), + cached_inner, session_error.as_ref(), cookie_jar, &fairing.options, @@ -84,16 +81,16 @@ async fn get_session_data<'r, T: Send + Sync + Clone>( Ok((data, ttl)) => { rocket::debug!("Session found. Creating existing session..."); let session_inner = SessionInner::new_existing(id, data, ttl); - (Arc::new(Mutex::new(session_inner)), None) + (Mutex::new(session_inner), None) } Err(e) => { rocket::debug!("Error from session storage, creating empty session: {}", e); - (Arc::default(), Some(e)) + (Mutex::default(), Some(e)) } } } else { rocket::debug!("No valid session cookie found. Creating empty session..."); - (Arc::default(), Some(SessionError::NoSessionCookie)) + (Mutex::default(), Some(SessionError::NoSessionCookie)) } } From b22645a043dcf3012f58e8e4f5490690eb535a07 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 06:50:15 -0400 Subject: [PATCH 05/28] docs: update docs --- src/fairing.rs | 122 +++++++++++++++++++- src/lib.rs | 258 ++++++++++++++++++++++++------------------ src/options.rs | 4 +- src/session.rs | 8 +- src/session_index.rs | 2 +- src/storage.rs | 20 ++++ src/storage/memory.rs | 31 ++++- src/storage/redis.rs | 2 +- 8 files changed, 324 insertions(+), 123 deletions(-) diff --git a/src/fairing.rs b/src/fairing.rs index 55bdec1..7ac5aff 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -1,11 +1,129 @@ use std::{ marker::{Send, Sync}, - sync::Mutex, + sync::{Arc, Mutex}, }; use rocket::{fairing::Fairing, Build, Orbit, Request, Response, Rocket}; -use crate::{guard::LocalCachedSession, RocketFlexSession}; +use crate::{ + guard::LocalCachedSession, + storage::{memory::MemoryStorage, SessionStorage}, + RocketFlexSessionOptions, +}; + +/** +A Rocket fairing that enables sessions. + +# Type Parameters +* `T` - The type of your session data. Must be thread-safe and + implement Clone. The storage provider you use may have additional + trait bounds as well. + +# Example +```rust +use rocket_flex_session::{RocketFlexSession, storage::cookie::CookieStorage}; +use rocket::time::Duration; +use rocket::serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize)] +struct MySession { + user_id: String, + role: String, +} + +#[rocket::launch] +fn rocket() -> _ { + // Use default settings + let session_fairing = RocketFlexSession::::default(); + + // Or customize settings with the builder + let custom_session = RocketFlexSession::::builder() + .storage(CookieStorage::default()) // or a custom storage provider + .with_options(|opt| { + opt.cookie_name = "my_cookie".to_string(); + opt.path = "/app".to_string(); + opt.max_age = 7 * 24 * 60 * 60; // 7 days + }) + .build(); + + rocket::build() + .attach(session_fairing) + // ... other configuration ... +} +``` +*/ +#[derive(Clone)] +pub struct RocketFlexSession { + pub(crate) options: RocketFlexSessionOptions, + pub(crate) storage: Arc>, +} +impl RocketFlexSession +where + T: Send + Sync + Clone + 'static, +{ + /// Build a session configuration + pub fn builder() -> RocketFlexSessionBuilder { + RocketFlexSessionBuilder::default() + } +} +impl Default for RocketFlexSession +where + T: Send + Sync + Clone + 'static, +{ + fn default() -> Self { + Self { + options: Default::default(), + storage: Arc::new(MemoryStorage::default()), + } + } +} + +/// Builder to configure the [RocketFlexSession] fairing +pub struct RocketFlexSessionBuilder +where + T: Send + Sync + Clone + 'static, +{ + fairing: RocketFlexSession, +} +impl Default for RocketFlexSessionBuilder +where + T: Send + Sync + Clone + 'static, +{ + fn default() -> Self { + Self { + fairing: Default::default(), + } + } +} +impl RocketFlexSessionBuilder +where + T: Send + Sync + Clone + 'static, +{ + /// Set the session options via a closure. If you're using a cookie-based storage + /// provider, make sure to set the corresponding cookie settings + /// in the storage configuration as well. + pub fn with_options(&mut self, options_fn: OptionsFn) -> &mut Self + where + OptionsFn: FnOnce(&mut RocketFlexSessionOptions), + { + options_fn(&mut self.fairing.options); + self + } + + /// Set the session storage provider + pub fn storage(&mut self, storage: S) -> &mut Self + where + S: SessionStorage + 'static, + { + self.fairing.storage = Arc::new(storage); + self + } + + /// Build the fairing + pub fn build(&self) -> RocketFlexSession { + self.fairing.clone() + } +} #[rocket::async_trait] impl Fairing for RocketFlexSession diff --git a/src/lib.rs b/src/lib.rs index 66989b3..41c6779 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,9 @@ Simple, extensible session library for Rocket applications. call will be made to get the session data, and if the session is updated multiple times during the request, only one call will be made at the end of the request to save the session. - Multiple storage providers available, or you can - use your own session storage by implementing the [SessionStorage] trait. + use your own session storage by implementing the (`SessionStorage`)[crate::storage::SessionStorage] trait. +- Optional session indexing support for advanced features like multi-device login tracking, + bulk session invalidation, and security auditing. # Usage While technically not needed for development, it is highly recommended to @@ -120,145 +122,177 @@ fn login(mut session: Session) { } ``` -# Feature flags +## Session Indexing -These features can be enabled as shown -[in Cargo's documentation](https://doc.rust-lang.org/cargo/reference/features.html). +For use cases like multi-device login tracking or other security features, you can use a storage +provider that supports indexing, and then group sessions by an identifier (such as a user ID) using the [`SessionIdentifier`] trait: -| Name | Description | -|---------|----------------| -| `cookie` | A cookie-based session store. Data is serialized using serde_json and then encrypted into the value of a cookie. | -| `redis_fred` | A session store for Redis (and Redis-compatible databases), using the [fred.rs](https://docs.rs/crate/fred) crate. | -| `sqlx_postgres` | A session store using PostgreSQL via the [sqlx](https://docs.rs/crate/sqlx) crate. | -| `rocket_okapi` | Enables support for the [rocket_okapi](https://docs.rs/crate/rocket_okapi) crate if needed. | -*/ - -mod fairing; -mod guard; -mod options; -mod session; -mod session_index; -mod session_inner; - -pub mod error; -pub mod storage; -pub use options::SessionOptions; -pub use session::Session; -pub use session_index::SessionIdentifier; +```rust +use rocket::routes; +use rocket_flex_session::{Session, SessionIdentifier, RocketFlexSession}; +use rocket_flex_session::storage::memory::IndexedMemoryStorage; -use crate::storage::{memory::MemoryStorage, SessionStorage}; -use std::sync::Arc; +#[derive(Clone)] +struct UserSession { + user_id: String, + device_name: String, +} -/** -A Rocket fairing that enables sessions. +impl SessionIdentifier for UserSession { + type Id = String; -# Type Parameters -* `T` - The type of your session data. Must be thread-safe and - implement Clone. The storage provider you use may have additional - trait bounds as well. + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) // Group sessions by user_id + } +} -# Example -```rust -use rocket_flex_session::{RocketFlexSession, SessionOptions, storage::cookie::CookieStorage}; -use rocket::time::Duration; -use rocket::serde::{Deserialize, Serialize}; +#[rocket::get("/user/sessions")] +async fn get_all_user_sessions(session: Session<'_, UserSession>) -> String { + match session.get_all_sessions().await { + Ok(Some(sessions)) => format!("Found {} active sessions", sessions.len()), + Ok(None) => "No active session".to_string(), + Err(e) => format!("Error: {}", e), + } +} -#[derive(Clone, Serialize, Deserialize)] -struct MySession { - user_id: String, - role: String, +#[rocket::get("/user/logout-everywhere")] +async fn logout_everywhere(session: Session<'_, UserSession>) -> String { + match session.invalidate_all_sessions().await { + Ok(Some(())) => "Logged out from all devices".to_string(), + Ok(None) => "No active session".to_string(), + Err(e) => format!("Error: {}", e), + } } #[rocket::launch] fn rocket() -> _ { - // Use default settings - let session_fairing = RocketFlexSession::::default(); - - // Or customize settings with the builder - let custom_session = RocketFlexSession::::builder() - .storage(CookieStorage::default()) // or a custom storage provider - .with_options(|opt| { - opt.cookie_name = "my_cookie".to_string(); - opt.path = "/app".to_string(); - opt.max_age = 7 * 24 * 60 * 60; // 7 days - }) - .build(); - rocket::build() - .attach(session_fairing) - // ... other configuration ... + .attach( + RocketFlexSession::::builder() + .storage(IndexedMemoryStorage::default()) + .build() + ) + .mount("/", routes![get_all_user_sessions, logout_everywhere]) } ``` -*/ -#[derive(Clone)] -pub struct RocketFlexSession { - pub(crate) options: SessionOptions, - pub(crate) storage: Arc>, -} -impl RocketFlexSession + +# Storage Providers + +This crate supports multiple storage backends with different capabilities: + +## Available Storage Providers + +| Storage | Feature Flag | Indexing Support | Use Case | +|---------|-------------|------------------|----------| +| [`storage::memory::MemoryStorage`] | Built-in | ❌ | Development, testing | +| [`storage::memory::IndexedMemoryStorage`] | Built-in | ✅ | Development with indexing features | +| [`storage::cookie::CookieStorage`] | `cookie` | ❌ | Client-side storage, stateless servers | +| [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | Production, distributed systems | +| [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅* | Production, existing database | + +*Support planned - see [Custom Storage](#custom-storage) section for implementation details. + +## Custom Storage + +To implement a custom storage provider, implement the [`SessionStorage`](crate::storage::SessionStorage) trait: + +```rust +use rocket_flex_session::{error::SessionResult, storage::SessionStorage}; +use rocket::{async_trait, http::CookieJar}; + +pub struct MyCustomStorage {} + +#[async_trait] +impl SessionStorage for MyCustomStorage where T: Send + Sync + Clone + 'static, { - /// Build a session configuration - pub fn builder() -> RocketFlexSessionBuilder { - RocketFlexSessionBuilder::default() + async fn load(&self, id: &str, ttl: Option, cookie_jar: &CookieJar) -> SessionResult<(T, u32)> { + // Load session from your storage + todo!() } -} -impl Default for RocketFlexSession -where - T: Send + Sync + Clone + 'static, -{ - fn default() -> Self { - Self { - options: Default::default(), - storage: Arc::new(MemoryStorage::default()), - } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + // Save session to your storage + todo!() } -} -/// Builder to configure the [RocketFlexSession] fairing -pub struct RocketFlexSessionBuilder -where - T: Send + Sync + Clone + 'static, -{ - fairing: RocketFlexSession, + async fn delete(&self, id: &str, cookie_jar: &CookieJar) -> SessionResult<()> { + // Delete session from your storage + todo!() + } } -impl Default for RocketFlexSessionBuilder +``` + +### Adding Indexing Support + +To support session indexing, also implement [`SessionStorageIndexed`](crate::storage::SessionStorageIndexed) and add the `as_indexed_storage` method +to the [`SessionStorage`](crate::storage::SessionStorage) trait: + + +```rust,ignore +use rocket_flex_session::{error::SessionResult, storage::{SessionStorage, SessionStorageIndexed, SessionIdentifier}}; + +struct MyCustomStorage; + +#[async_trait] +impl SessionStorageIndexed for MyCustomStorage where - T: Send + Sync + Clone + 'static, + T: SessionIdentifier + Send + Sync + Clone + 'static, { - fn default() -> Self { - Self { - fairing: Default::default(), - } + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + // Return all (session_id, session_data) pairs for the identifier + todo!() } + // etc... } -impl RocketFlexSessionBuilder + +// Make sure to also add this to the `SessionStorage` trait to enable indexing support +#[async_trait] +impl SessionStorage for MyCustomStorage where T: Send + Sync + Clone + 'static, { - /// Set the session options via a closure. If you're using a cookie-based storage - /// provider, make sure to set the corresponding cookie settings - /// in the storage configuration as well. - pub fn with_options(&mut self, options_fn: OptionsFn) -> &mut Self - where - OptionsFn: FnOnce(&mut SessionOptions), - { - options_fn(&mut self.fairing.options); - self - } - - /// Set the session storage provider - pub fn storage(&mut self, storage: S) -> &mut Self - where - S: SessionStorage + 'static, - { - self.fairing.storage = Arc::new(storage); - self - } + // ... other methods ... - /// Build the fairing - pub fn build(&self) -> RocketFlexSession { - self.fairing.clone() + fn as_indexed_storage(&self) -> Option<&dyn SessionStorageIndexed> { + Some(self) // Enable indexing support } } +``` + +### Implementation Tips + +1. **Thread Safety**: All storage implementations must be `Send + Sync` +2. **Trait bounds**: Add additional trait bounds to the session data type as needed +3. **Error Handling**: Use [`error::SessionError::Backend`] for custom errors +4. **TTL Handling**: Respect the TTL parameters in `load` and `save` for session expiration +5. **Indexing Consistency**: Keep identifier indexes in sync with session data +6. **Cleanup**: Implement proper cleanup in `shutdown()` if needed + +# Feature flags + +These features can be enabled as shown +[in Cargo's documentation](https://doc.rust-lang.org/cargo/reference/features.html). + +| Name | Description | +|---------|----------------| +| `cookie` | A cookie-based session store. Data is serialized using serde_json and then encrypted into the value of a cookie. | +| `redis_fred` | A session store for Redis (and Redis-compatible databases), using the [fred.rs](https://docs.rs/crate/fred) crate. | +| `sqlx_postgres` | A session store using PostgreSQL via the [sqlx](https://docs.rs/crate/sqlx) crate. | +| `rocket_okapi` | Enables support for the [rocket_okapi](https://docs.rs/crate/rocket_okapi) crate if needed. | +*/ + +mod fairing; +mod guard; +mod options; +mod session; +mod session_index; +mod session_inner; + +pub mod error; +pub mod storage; +pub use fairing::{RocketFlexSession, RocketFlexSessionBuilder}; +pub use options::RocketFlexSessionOptions; +pub use session::Session; +pub use session_index::SessionIdentifier; diff --git a/src/options.rs b/src/options.rs index 5bb7ab0..53204aa 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,6 +1,6 @@ /// Options for configuring the session. #[derive(Clone, Debug)] -pub struct SessionOptions { +pub struct RocketFlexSessionOptions { /// The name of the cookie used to store the session ID (default: `"rocket"`) pub cookie_name: String, /// The session cookie's `Domain` attribute (default: `None`) @@ -26,7 +26,7 @@ pub struct SessionOptions { pub ttl: Option, } -impl Default for SessionOptions { +impl Default for RocketFlexSessionOptions { fn default() -> Self { Self { cookie_name: "rocket".to_owned(), diff --git a/src/session.rs b/src/session.rs index 978f094..951edee 100644 --- a/src/session.rs +++ b/src/session.rs @@ -11,7 +11,7 @@ use std::{ }; use crate::{ - error::SessionError, options::SessionOptions, session_inner::SessionInner, + error::SessionError, options::RocketFlexSessionOptions, session_inner::SessionInner, storage::SessionStorage, }; @@ -55,7 +55,7 @@ where /// Rocket's cookie jar for managing cookies cookie_jar: &'a CookieJar<'a>, /// User's session options - options: &'a SessionOptions, + options: &'a RocketFlexSessionOptions, /// Configured storage provider for sessions pub(crate) storage: &'a dyn SessionStorage, } @@ -78,7 +78,7 @@ where inner: &'a Mutex>, error: Option<&'a SessionError>, cookie_jar: &'a CookieJar<'a>, - options: &'a SessionOptions, + options: &'a RocketFlexSessionOptions, storage: &'a dyn SessionStorage, ) -> Self { Self { @@ -264,7 +264,7 @@ where } /// Create the session cookie -fn create_session_cookie(id: &str, options: &SessionOptions) -> Cookie<'static> { +fn create_session_cookie(id: &str, options: &RocketFlexSessionOptions) -> Cookie<'static> { let mut cookie = Cookie::build((options.cookie_name.to_owned(), id.to_owned())) .http_only(options.http_only) .max_age(Duration::seconds(options.max_age.into())) diff --git a/src/session_index.rs b/src/session_index.rs index b78af34..ac856ad 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -9,7 +9,7 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// /// # Example /// ```rust -/// use rocket_flex_session::storage::SessionIdentifier; +/// use rocket_flex_session::SessionIdentifier; /// /// #[derive(Clone)] /// struct MySession { diff --git a/src/storage.rs b/src/storage.rs index afd62d5..5f1b8cb 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,4 +1,24 @@ //! Storage implementations for sessions +//! +//! This module provides various storage backends for session data, with optional +//! support for session indexing by identifier. +//! +//! ## Session Indexing +//! +//! Some storage backends support indexing sessions by an identifier (like user ID). +//! This enables advanced features such as: +//! +//! - Finding all active sessions for a user +//! - Bulk invalidation of sessions (e.g., "log out everywhere") +//! - Security auditing and monitoring +//! +//! To use indexing, your session type must implement [`crate::SessionIdentifier`] and you +//! must use a storage backend that implements [`SessionStorageIndexed`]. +//! +//! ## Custom Storage +//! +//! Implement [`SessionStorage`] to create custom storage backends. For indexing +//! support, also implement [`SessionStorageIndexed`]. mod interface; pub use interface::*; diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 70156f6..d7e1cb5 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -23,6 +23,8 @@ use super::interface::{SessionStorage, SessionStorageIndexed}; /// In-memory storage provider for sessions. This is designed mostly for local /// development, and not for production use. It uses the [retainer] crate to /// create an async cache. +/// +/// For session indexing support, see [`IndexedMemoryStorage`]. pub struct MemoryStorage { shutdown_tx: Mutex>>, pub(crate) cache: Arc>, @@ -108,9 +110,36 @@ impl MemoryStorage { } } -/// In-memory storage that supports session indexing by identifier. +/// Extended in-memory storage that supports session indexing by identifier. /// This allows for operations like retrieving all sessions for a user or /// bulk invalidation of sessions. +/// +/// Unlike [`MemoryStorage`], this implementation maintains an index mapping +/// identifiers to session IDs, enabling efficient lookups and bulk operations. +/// +/// # Example +/// ```rust +/// use rocket_flex_session::storage::memory::IndexedMemoryStorage; +/// use rocket_flex_session::{SessionIdentifier, RocketFlexSession}; +/// +/// #[derive(Clone)] +/// struct UserSession { +/// user_id: String, +/// data: String, +/// } +/// +/// impl SessionIdentifier for UserSession { +/// type Id = String; +/// fn identifier(&self) -> Option<&Self::Id> { +/// Some(&self.user_id) +/// } +/// } +/// +/// let storage = IndexedMemoryStorage::::default(); +/// let fairing = RocketFlexSession::builder() +/// .storage(storage) +/// .build(); +/// ``` pub struct IndexedMemoryStorage where T: SessionIdentifier, diff --git a/src/storage/redis.rs b/src/storage/redis.rs index a99b612..7e2e292 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -26,7 +26,7 @@ inverse `TryFrom for Value`, in order to dictate how the data will be co ```rust use fred::prelude::{Builder, ClientLike, Config, Value}; -use rocket_flex_session::storage::{interface::SessionError, redis::{RedisFredStorage, RedisType}}; +use rocket_flex_session::{error::SessionError, storage::{redis::{RedisFredStorage, RedisType}}}; async fn setup_storage() -> RedisFredStorage { // Setup and initialize a fred.rs Redis pool. From 9b768813b683636184379d58235c29d48f313686 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 07:09:27 -0400 Subject: [PATCH 06/28] perf: unnecessary cloning for hashmap data --- src/session.rs | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/src/session.rs b/src/session.rs index 951edee..f84ae8a 100644 --- a/src/session.rs +++ b/src/session.rs @@ -129,7 +129,7 @@ where response } - /// Set/update the session data. Will create a new active session if needed. + /// Set/update the session data. Will create a new active session if there isn't one. pub fn set(&mut self, new_data: T) { self.get_inner_lock() .set_data(new_data, self.get_default_ttl()); @@ -198,7 +198,7 @@ where fn update_cookies(&self) { let inner = self.get_inner_lock(); let Some(id) = inner.get_id() else { - rocket::info!("Cookies not updated: no active session"); + rocket::warn!("Cookies not updated: no active session"); return; }; @@ -235,31 +235,42 @@ where .and_then(|h| h.get(key).cloned()) } - /// Set the value of a key in the session data. Will create - /// a new session if needed. - pub fn set_key(&mut self, key: K, value: V) { - let mut new_data = self + /// Get the value of a key in the session data via a reference + pub fn tap_key(&self, key: &Q, f: F) -> R + where + Q: ?Sized + Eq + Hash, + K: std::borrow::Borrow, + F: FnOnce(Option<&V>) -> R, + { + f(self .get_inner_lock() .get_current_data() - .cloned() - .unwrap_or_default(); - new_data.insert(key, value); - self.set(new_data); + .and_then(|d| d.get(key))) + } + + /// Set the value of a key in the session data. Will create a new session if there isn't one. + pub fn set_key(&mut self, key: K, value: V) { + self.get_inner_lock().tap_data_mut( + |data| { + data.get_or_insert_default().insert(key, value); + }, + self.get_default_ttl(), + ); + self.update_cookies(); } - /// Set multiple keys and values in the session data. Will create - /// a new session if needed. + /// Set multiple keys and values in the session data. Will create a new session if there isn't one. pub fn set_keys(&mut self, kv_iter: I) where I: IntoIterator, { - let mut new_data = self - .get_inner_lock() - .get_current_data() - .cloned() - .unwrap_or_default(); - new_data.extend(kv_iter); - self.set(new_data); + self.get_inner_lock().tap_data_mut( + |data| { + data.get_or_insert_default().extend(kv_iter); + }, + self.get_default_ttl(), + ); + self.update_cookies(); } } From 68accfcca1ef3f5d9bd2e506013c0ff625aa7c0d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 07:33:55 -0400 Subject: [PATCH 07/28] feat: add identifier name --- src/session_index.rs | 6 +++++- tests/indexed_session.rs | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/session_index.rs b/src/session_index.rs index ac856ad..5cab35e 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -18,6 +18,7 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// } /// /// impl SessionIdentifier for MySession { +/// const NAME: &str = "user_id"; /// type Id = String; /// /// fn identifier(&self) -> Option<&Self::Id> { @@ -26,7 +27,10 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// } /// ``` pub trait SessionIdentifier { - /// The type of the identifier (e.g., user ID, account ID, etc.) + /// The name of the identifier (default: "user_id") + const NAME: &str = "user_id"; + + /// The type of the identifier type Id: Send + Sync + Clone; /// Extract the identifier from the session data. diff --git a/tests/indexed_session.rs b/tests/indexed_session.rs index 1533c77..727c4b5 100644 --- a/tests/indexed_session.rs +++ b/tests/indexed_session.rs @@ -14,6 +14,7 @@ struct UserSession { } impl SessionIdentifier for UserSession { + const NAME: &str = "user_id"; type Id = String; fn identifier(&self) -> Option<&Self::Id> { @@ -29,6 +30,7 @@ struct AdminSession { } impl SessionIdentifier for AdminSession { + const NAME: &str = "admin_id"; type Id = String; fn identifier(&self) -> Option<&Self::Id> { From 5e44d49c0d05ed5da1bfe64f02e5b133fb1e7237 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Thu, 4 Sep 2025 23:32:01 -0400 Subject: [PATCH 08/28] feat: add indexing support for sqlx postgres --- README.md | 1 - src/session_index.rs | 4 +- src/session_inner.rs | 5 +- src/storage/memory.rs | 14 +- src/storage/sqlx.rs | 113 +++++++++++++--- tests/basic.rs | 2 +- tests/common/mod.rs | 35 +++++ ...{indexed_session.rs => session_indexed.rs} | 0 tests/storages.rs | 52 ++----- ...indexed_storage.rs => storages_indexed.rs} | 128 +++++++++++------- 10 files changed, 237 insertions(+), 117 deletions(-) create mode 100644 tests/common/mod.rs rename tests/{indexed_session.rs => session_indexed.rs} (100%) rename tests/{indexed_storage.rs => storages_indexed.rs} (66%) diff --git a/README.md b/README.md index c62633e..50d605f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ Add to your `Cargo.toml`: ```toml [dependencies] -... rocket = "0.5" rocket-flex-session = { version = "0.1" } ``` diff --git a/src/session_index.rs b/src/session_index.rs index 5cab35e..91c0afc 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -27,7 +27,7 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// } /// ``` pub trait SessionIdentifier { - /// The name of the identifier (default: "user_id") + /// The name of the identifier (default: `"user_id"`), that may be used as a field/key name by the storage backend. const NAME: &str = "user_id"; /// The type of the identifier @@ -113,7 +113,7 @@ where fn get_identifier(&self) -> Option { let identifier = { let inner = self.get_inner_lock(); - inner.get_current_identifier() + inner.get_current_identifier().cloned() }; identifier } diff --git a/src/session_inner.rs b/src/session_inner.rs index d092b86..a6cf405 100644 --- a/src/session_inner.rs +++ b/src/session_inner.rs @@ -158,8 +158,7 @@ impl SessionInner where T: SessionIdentifier + Clone, { - pub(crate) fn get_current_identifier(&self) -> Option { - self.get_current_data() - .and_then(|data| data.identifier().cloned()) + pub(crate) fn get_current_identifier(&self) -> Option<&T::Id> { + self.get_current_data().and_then(|data| data.identifier()) } } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index d7e1cb5..fc49046 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -27,7 +27,7 @@ use super::interface::{SessionStorage, SessionStorageIndexed}; /// For session indexing support, see [`IndexedMemoryStorage`]. pub struct MemoryStorage { shutdown_tx: Mutex>>, - pub(crate) cache: Arc>, + cache: Arc>, } impl Default for MemoryStorage { @@ -62,10 +62,8 @@ where ) .await; } - Ok(( - data.to_owned(), - ttl.unwrap_or(data.expiration().remaining().unwrap().as_secs() as u32), - )) + let ttl = ttl.unwrap_or(data.expiration().remaining().unwrap().as_secs() as u32); + Ok((data.to_owned(), ttl)) } async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { @@ -114,8 +112,8 @@ impl MemoryStorage { /// This allows for operations like retrieving all sessions for a user or /// bulk invalidation of sessions. /// -/// Unlike [`MemoryStorage`], this implementation maintains an index mapping -/// identifiers to session IDs, enabling efficient lookups and bulk operations. +/// You must implement the [`SessionIdentifier`] trait for your session type, +/// and the [`SessionIdentifier::Id`] type must implement [`ToString`]. /// /// # Example /// ```rust @@ -152,7 +150,7 @@ where impl Default for IndexedMemoryStorage where T: SessionIdentifier, - T::Id: ToString, + ::Id: ToString, { fn default() -> Self { Self { diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 1dca8a2..928dd6a 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -4,19 +4,29 @@ use rocket::{async_trait, http::CookieJar}; use sqlx::{PgPool, Row}; use time::{Duration, OffsetDateTime}; -use crate::error::{SessionError, SessionResult}; +use crate::{ + error::{SessionError, SessionResult}, + storage::SessionStorageIndexed, + SessionIdentifier, +}; use super::interface::SessionStorage; /** -Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx). Stores the session data as a string, so you'll need -to implement `TryFrom for String` and `TryFrom for YourSession` -for your session data type. Expects a table to already exist with the following columns: +Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx). + +Stores the session data as a string, so you'll need to implement `ToString` (or Display) +and `TryFrom` for your session data type. This storage providers supports session +indexing, so you'll also need to implement [`SessionIdentifier`](crate::SessionIdentifier), +and its [`Id`](crate::SessionIdentifier::Id) must be a [type supported by sqlx](https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html). +Expects a table to already exist with the following columns: + | Name | Type | |------|---------| -| id | text PRIMARY KEY | -| data | text NOT NULL (or jsonb if using JSON) | -| expires | timestamptz NOT NULL | +| id | `text` PRIMARY KEY | +| data | `text` NOT NULL (or `jsonb` if using JSON) | +| `` | `` (this should match the [`SessionIdentifier`](crate::SessionIdentifier) | +| expires | `timestamptz` NOT NULL | */ pub struct SqlxPostgresStorage { pool: PgPool, @@ -35,9 +45,10 @@ impl SqlxPostgresStorage { #[async_trait] impl SessionStorage for SqlxPostgresStorage where - T: TryFrom + TryInto + Clone + Send + Sync + 'static, + T: SessionIdentifier + TryFrom + ToString + Clone + Send + Sync + 'static, + ::Id: + for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, >::Error: std::error::Error + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, { async fn load( &self, @@ -87,24 +98,21 @@ where } async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - let raw_str: String = data - .try_into() - .map_err(|e| SessionError::Serialization(Box::new(e)))?; - let expires = OffsetDateTime::now_utc() + Duration::seconds(ttl.into()); - sqlx::query(&format!( r#" - INSERT INTO "{}" (id, data, expires) - VALUES ($1, $2, $3) + INSERT INTO "{}" (id, {}, data, expires) + VALUES ($1, $2, $3, $4) ON CONFLICT (id) DO UPDATE SET data = EXCLUDED.data, expires = EXCLUDED.expires "#, - self.table_name + self.table_name, + T::NAME )) .bind(id) - .bind(raw_str) - .bind(expires) + .bind(data.identifier()) + .bind(data.to_string()) + .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) .execute(&self.pool) .await?; @@ -120,3 +128,70 @@ where Ok(()) } } + +#[async_trait] +impl SessionStorageIndexed for SqlxPostgresStorage +where + T: SessionIdentifier + TryFrom + ToString + Clone + Send + Sync + 'static, + ::Id: + for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let rows = sqlx::query(&format!( + r#" + SELECT id, data FROM "{}" + WHERE {} = $1 AND expires > CURRENT_TIMESTAMP"#, + &self.table_name, + T::NAME + )) + .bind(id) + .fetch_all(&self.pool) + .await?; + + let parsed_rows = rows + .into_iter() + .filter_map(|row| { + let id: String = row.try_get(0).ok()?; + let raw_data: String = row.try_get(1).ok()?; + let data = T::try_from(raw_data).ok()?; + Some((id, data)) + }) + .collect(); + Ok(parsed_rows) + } + + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let rows = sqlx::query(&format!( + r#" + SELECT id FROM "{}" + WHERE {} = $1 AND expires > CURRENT_TIMESTAMP"#, + &self.table_name, + T::NAME + )) + .bind(id) + .fetch_all(&self.pool) + .await?; + + let parsed_rows = rows + .into_iter() + .filter_map(|row| row.try_get(0).ok()) + .collect(); + Ok(parsed_rows) + } + + async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult<()> { + let _rows = sqlx::query(&format!( + r#" + DELETE FROM "{}" + WHERE {} = $1"#, + &self.table_name, + T::NAME + )) + .bind(id) + .execute(&self.pool) + .await?; + + Ok(()) + } +} diff --git a/tests/basic.rs b/tests/basic.rs index 38169f8..c6c651a 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -41,7 +41,7 @@ fn delete_session(mut session: Session) -> &'static str { #[get("/get_hash_session/")] fn get_hash_session(session: Session>, key: &str) -> String { match session.get_key(key) { - Some(value) => value.clone(), + Some(value) => value, None => "No value".to_string(), } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..b3895ab --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1,35 @@ +use sqlx::{Connection, PgPool}; + +pub const POSTGRES_URL: &str = "postgres://postgres:postgres@localhost"; + +/// Setup a test Postgres database +pub async fn setup_postgres(base_url: &str) -> (PgPool, String) { + let db_name = format!( + "test_{}", + (0..6) + .map(|_| (b'a' + (rand::random::() % 26)) as char) + .collect::() + ); + let mut cxn = sqlx::PgConnection::connect(base_url).await.unwrap(); + sqlx::query(&format!("CREATE DATABASE {}", db_name)) + .execute(&mut cxn) + .await + .expect("Should create test database"); + let _ = cxn.close().await; + + let db_url = format!("{}/{}", base_url, db_name); + let pool = sqlx::PgPool::connect(&db_url).await.unwrap(); + sqlx::query( + r#"CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + data TEXT NOT NULL, + user_id TEXT, + expires TIMESTAMPTZ NOT NULL + )"#, + ) + .execute(&pool) + .await + .expect("Should create sessions table"); + + (pool, db_name) +} diff --git a/tests/indexed_session.rs b/tests/session_indexed.rs similarity index 100% rename from tests/indexed_session.rs rename to tests/session_indexed.rs diff --git a/tests/storages.rs b/tests/storages.rs index ba6494a..741b123 100644 --- a/tests/storages.rs +++ b/tests/storages.rs @@ -1,3 +1,5 @@ +mod common; + #[macro_use] extern crate rocket; @@ -12,13 +14,13 @@ use rocket_flex_session::{ redis::{RedisFredStorage, RedisType}, sqlx::SqlxPostgresStorage, }, - RocketFlexSession, Session, + RocketFlexSession, Session, SessionIdentifier, }; use serde::{Deserialize, Serialize}; -use sqlx::{Connection, PgPool}; +use sqlx::Connection; use test_case::test_case; -const POSTGRES_URL: &str = "postgres://postgres:postgres@localhost"; +use crate::common::{setup_postgres, POSTGRES_URL}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] struct SessionData { @@ -30,9 +32,9 @@ impl TryFrom for SessionData { Ok(Self { user_id: value }) } } -impl From for String { - fn from(value: SessionData) -> Self { - value.user_id +impl std::fmt::Display for SessionData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.user_id) } } impl TryFrom for SessionData { @@ -47,6 +49,13 @@ impl From for fred::types::Value { Self::String(value.user_id.into()) } } +impl SessionIdentifier for SessionData { + const NAME: &str = "user_id"; + type Id = String; + fn identifier(&self) -> Option<&Self::Id> { + Some(&self.user_id) + } +} #[get("/get_session")] fn get_session(session: Session) -> String { @@ -75,37 +84,6 @@ fn expire_session(mut session: Session) { session.set_ttl(1); } -/// Setup a test Postgres database -async fn setup_postgres(base_url: &str) -> (PgPool, String) { - let db_name = format!( - "test_{}", - (0..6) - .map(|_| (b'a' + (rand::random::() % 26)) as char) - .collect::() - ); - let mut cxn = sqlx::PgConnection::connect(base_url).await.unwrap(); - sqlx::query(&format!("CREATE DATABASE {}", db_name)) - .execute(&mut cxn) - .await - .expect("Should create test database"); - let _ = cxn.close().await; - - let db_url = format!("{}/{}", base_url, db_name); - let pool = sqlx::PgPool::connect(&db_url).await.unwrap(); - sqlx::query( - r#"CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - data TEXT NOT NULL, - expires TIMESTAMPTZ NOT NULL - )"#, - ) - .execute(&pool) - .await - .expect("Should create sessions table"); - - (pool, db_name) -} - async fn create_rocket( storage_case: &str, ) -> (Rocket, Option>>>) { diff --git a/tests/indexed_storage.rs b/tests/storages_indexed.rs similarity index 66% rename from tests/indexed_storage.rs rename to tests/storages_indexed.rs index 441e2fd..d919c82 100644 --- a/tests/indexed_storage.rs +++ b/tests/storages_indexed.rs @@ -1,15 +1,22 @@ +mod common; + +use std::{future::Future, pin::Pin}; + use rocket::local::asynchronous::Client; use rocket_flex_session::{ - storage::{memory::IndexedMemoryStorage, SessionStorage, SessionStorageIndexed}, + storage::{memory::IndexedMemoryStorage, sqlx::SqlxPostgresStorage, SessionStorageIndexed}, SessionIdentifier, }; +use sqlx::Connection; +use test_case::test_case; + +use crate::common::{setup_postgres, POSTGRES_URL}; #[derive(Clone, Debug, PartialEq)] struct TestSession { user_id: String, data: String, } - impl SessionIdentifier for TestSession { type Id = String; @@ -17,23 +24,61 @@ impl SessionIdentifier for TestSession { Some(&self.user_id) } } - -#[derive(Clone, Debug, PartialEq)] -struct SessionWithoutId { - data: String, +impl ToString for TestSession { + fn to_string(&self) -> String { + format!("{}:{}", self.user_id, self.data) + } +} +impl TryFrom for TestSession { + type Error = std::io::Error; + + fn try_from(value: String) -> Result { + let (user_id, data) = value.split_once(':').ok_or(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid session format", + ))?; + Ok(TestSession { + user_id: user_id.to_string(), + data: data.to_string(), + }) + } } -impl SessionIdentifier for SessionWithoutId { - type Id = String; - - fn identifier(&self) -> Option<&Self::Id> { - None // This session type doesn't have an identifier +async fn create_storage( + storage_case: &str, +) -> ( + Box>, + Option>>>, +) { + match storage_case { + "memory" => { + let storage = IndexedMemoryStorage::::default(); + (Box::new(storage), None) + } + "sqlx" => { + let (pool, db_name) = setup_postgres(POSTGRES_URL).await; + let storage = SqlxPostgresStorage::new(pool.clone(), "sessions"); + + let cleanup_task: Pin>> = Box::pin(async move { + pool.close().await; + drop(pool); + let mut cxn = sqlx::PgConnection::connect(POSTGRES_URL).await.unwrap(); + sqlx::query(&format!("DROP DATABASE {} WITH (FORCE)", db_name)) + .execute(&mut cxn) + .await + .expect("Should drop test database"); + }); + (Box::new(storage), Some(cleanup_task)) + } + _ => unimplemented!(), } } +#[test_case("memory")] +#[test_case("sqlx")] #[rocket::async_test] -async fn indexed_memory_storage_basic_operations() { - let storage = IndexedMemoryStorage::::default(); +async fn basic_operations(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; storage.setup().await.unwrap(); let session1 = TestSession { @@ -86,11 +131,16 @@ async fn indexed_memory_storage_basic_operations() { assert!(user1_session_ids.contains(&"sid2".to_string())); storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } } +#[test_case("memory")] +#[test_case("sqlx")] #[rocket::async_test] -async fn indexed_memory_storage_invalidate_by_identifier() { - let storage = IndexedMemoryStorage::::default(); +async fn invalidate_by_identifier(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; storage.setup().await.unwrap(); let session1 = TestSession { @@ -148,12 +198,17 @@ async fn indexed_memory_storage_invalidate_by_identifier() { .any(|(id, data)| id == "sid3" && data == &session3)); storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } } +#[test_case("memory")] +#[test_case("sqlx")] #[rocket::async_test] -async fn indexed_memory_storage_delete_single_session() { +async fn delete_single_session(storage_case: &str) { let client = Client::tracked(rocket::build()).await.unwrap(); - let storage = IndexedMemoryStorage::::default(); + let (storage, cleanup_task) = create_storage(storage_case).await; storage.setup().await.unwrap(); let session1 = TestSession { @@ -193,38 +248,16 @@ async fn indexed_memory_storage_delete_single_session() { .any(|(id, data)| id == "sid2" && data == &session2)); storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } } +#[test_case("memory")] +#[test_case("sqlx")] #[rocket::async_test] -async fn indexed_memory_storage_session_without_identifier() { - let client = Client::tracked(rocket::build()).await.unwrap(); - let storage = IndexedMemoryStorage::::default(); - storage.setup().await.unwrap(); - - let session = SessionWithoutId { - data: "test_data".to_string(), - }; - - // Save session (should not be indexed) - storage.save("sid1", session.clone(), 3600).await.unwrap(); - - // Try to get sessions by identifier (should return empty) - let sessions = storage - .get_sessions_by_identifier(&"any_id".to_string()) - .await - .unwrap(); - assert_eq!(sessions.len(), 0); - - // Regular session operations should still work - let (loaded_session, _ttl) = storage.load("sid1", None, &client.cookies()).await.unwrap(); - assert_eq!(loaded_session, session); - - storage.shutdown().await.unwrap(); -} - -#[rocket::async_test] -async fn indexed_memory_storage_nonexistent_identifier() { - let storage = IndexedMemoryStorage::::default(); +async fn nonexistent_identifier(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; storage.setup().await.unwrap(); // Try to get sessions for non-existent identifier @@ -248,4 +281,7 @@ async fn indexed_memory_storage_nonexistent_identifier() { .unwrap(); storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } } From a6b97c7bbdca692d98cebfa34ce15064cf076ae3 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 5 Sep 2025 01:20:13 -0400 Subject: [PATCH 09/28] fix: hashmap support and doc fixes --- src/lib.rs | 21 ++++++++++----------- src/session.rs | 2 +- src/storage.rs | 3 +-- src/storage/redis.rs | 43 ++++++++++++++++++++++--------------------- src/storage/sqlx.rs | 8 +++----- tests/storages.rs | 7 +++++++ 6 files changed, 44 insertions(+), 40 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 41c6779..558fc56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ Simple, extensible session library for Rocket applications. call will be made to get the session data, and if the session is updated multiple times during the request, only one call will be made at the end of the request to save the session. - Multiple storage providers available, or you can - use your own session storage by implementing the (`SessionStorage`)[crate::storage::SessionStorage] trait. + use your own session storage by implementing the [`SessionStorage`](crate::storage::SessionStorage) trait. - Optional session indexing support for advanced features like multi-device login tracking, bulk session invalidation, and security auditing. @@ -104,8 +104,8 @@ For more info and examples of this powerful pattern, please see Rocket's documen ## HashMap session data -Instead of a custom struct, you can use a [HashMap](std::collections::HashMap) as your Session data type. This is -particularly useful if you expect your session data structure to be inconsistent and/or change frequently. +Instead of a custom struct, you can use a [HashMap](std::collections::HashMap) as your Session data type if the +storage provider supports it. This is particularly useful if you expect your session data structure to be inconsistent and/or change frequently. When using a HashMap, there are [some additional helper functions](file:///Users/farshad/Projects/pg-user-manager/api/target/doc/rocket_flex_session/struct.Session.html#method.get_key) to read and set keys. @@ -182,15 +182,14 @@ This crate supports multiple storage backends with different capabilities: ## Available Storage Providers -| Storage | Feature Flag | Indexing Support | Use Case | -|---------|-------------|------------------|----------| -| [`storage::memory::MemoryStorage`] | Built-in | ❌ | Development, testing | -| [`storage::memory::IndexedMemoryStorage`] | Built-in | ✅ | Development with indexing features | -| [`storage::cookie::CookieStorage`] | `cookie` | ❌ | Client-side storage, stateless servers | -| [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | Production, distributed systems | -| [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅* | Production, existing database | +| Storage | Feature Flag | Indexing support | HashMap support | Use Cases | +|---------|-------------|------------------|----------|----------| +| [`storage::memory::MemoryStorage`] | Built-in | ❌ | ✅ | Development, testing | +| [`storage::memory::IndexedMemoryStorage`] | Built-in | ✅ | ✅ | Development with indexing features | +| [`storage::cookie::CookieStorage`] | `cookie` | ❌ | ✅ | Client-side storage, stateless servers | +| [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | ✅ | Production, distributed systems | +| [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅ | ❌ | Production, existing database | -*Support planned - see [Custom Storage](#custom-storage) section for implementation details. ## Custom Storage diff --git a/src/session.rs b/src/session.rs index f84ae8a..1edf51b 100644 --- a/src/session.rs +++ b/src/session.rs @@ -235,7 +235,7 @@ where .and_then(|h| h.get(key).cloned()) } - /// Get the value of a key in the session data via a reference + /// Get the value of a key in the session data via a closure pub fn tap_key(&self, key: &Q, f: F) -> R where Q: ?Sized + Eq + Hash, diff --git a/src/storage.rs b/src/storage.rs index 5f1b8cb..61b19e3 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -12,8 +12,7 @@ //! - Bulk invalidation of sessions (e.g., "log out everywhere") //! - Security auditing and monitoring //! -//! To use indexing, your session type must implement [`crate::SessionIdentifier`] and you -//! must use a storage backend that implements [`SessionStorageIndexed`]. +//! To use indexing, your session type must implement [`crate::SessionIdentifier`]. //! //! ## Custom Storage //! diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 7e2e292..9e1ff2e 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -1,7 +1,7 @@ //! Session storage with Redis (and Redis-compatible databases) use fred::{ - prelude::{HashesInterface, KeysInterface, Pool, Value}, + prelude::{FromValue, HashesInterface, KeysInterface, Pool, Value}, types::Expiration, }; use rocket::{async_trait, http::CookieJar}; @@ -17,15 +17,19 @@ pub enum RedisType { } /** -Session storage with Redis (and Redis-compatible databases) using the [fred.rs](https://docs.rs/fred) crate. -You can store the data as a Redis string or hash. Your session data type must implement `TryFrom` -using the fred.rs [Value](https://docs.rs/fred/latest/fred/types/enum.Value.html) type, as well as the -inverse `TryFrom for Value`, in order to dictate how the data will be converted to/from the Redis data type. +Redis session storage using the [fred.rs](https://docs.rs/fred) crate. + +You can store the data as a Redis string or hash. Your session data type must implement [`FromValue`](https://docs.rs/fred/latest/fred/types/trait.FromValue.html) +from the fred.rs crate, as well as the inverse `From` or `TryFrom` for [`Value`](https://docs.rs/fred/latest/fred/types/enum.Value.html) in order +to dictate how the data will be converted to/from the Redis data type. - For `RedisType::String`, convert to/from `Value::String` - For `RedisType::Hash`, convert to/from `Value::Map` +💡 Common hashmap types like `HashMap` are automatically supported - make sure to use `RedisType::Hash` +when constructing the storage to ensure they are properly converted and stored as Redis hashes. + ```rust -use fred::prelude::{Builder, ClientLike, Config, Value}; +use fred::prelude::{Builder, ClientLike, Config, FromValue, Value}; use rocket_flex_session::{error::SessionError, storage::{redis::{RedisFredStorage, RedisType}}}; async fn setup_storage() -> RedisFredStorage { @@ -35,6 +39,8 @@ async fn setup_storage() -> RedisFredStorage { .build_pool(4) .expect("Should build Redis pool"); redis_pool.init().await.expect("Should initialize Redis pool"); + + // Construct the storage let storage = RedisFredStorage::new( redis_pool, RedisType::String, // or RedisType::Hash @@ -48,19 +54,16 @@ struct MySessionData { user_id: String, } -// Implement `TryFrom` to convert from the Redis value to your session data type -impl TryFrom for MySessionData { - type Error = SessionError; // or use your own error type - fn try_from(value: Value) -> Result { - match value { - Value::String(id) => Ok(MySessionData { - user_id: id.to_string(), - }), - _ => Err(SessionError::NotFound), - } +// Implement `FromValue` to convert from the Redis value to your session data type +impl FromValue for MySessionData { + fn from_value(value: Value) -> Result { + let data: String = value.convert()?; // fred.rs provides several conversion methods on the Value type + Ok(MySessionData { + user_id: data, + }) } } -// You can use From or TryFrom for the inverse conversion +// Implement the inverse conversion impl From for Value { fn from(data: MySessionData) -> Self { Value::String(data.user_id.into()) @@ -90,8 +93,7 @@ impl RedisFredStorage { #[async_trait] impl SessionStorage for RedisFredStorage where - T: TryFrom + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, + T: FromValue + TryInto + Clone + Send + Sync + 'static, >::Error: std::error::Error + Send + Sync + 'static, { async fn load( @@ -119,8 +121,7 @@ where }; let found_value = value.ok_or(SessionError::NotFound)?; - let data = - T::try_from(found_value).map_err(|e| SessionError::Serialization(Box::new(e)))?; + let data = T::from_value(found_value)?; Ok((data, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) } diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 928dd6a..3f21079 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -13,11 +13,9 @@ use crate::{ use super::interface::SessionStorage; /** -Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx). +Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx) that stores session data as a string, and supports session indexing. -Stores the session data as a string, so you'll need to implement `ToString` (or Display) -and `TryFrom` for your session data type. This storage providers supports session -indexing, so you'll also need to implement [`SessionIdentifier`](crate::SessionIdentifier), +You'll need to implement `ToString` (or Display) and `TryFrom` for your session data type. You'll also need to implement [`SessionIdentifier`], and its [`Id`](crate::SessionIdentifier::Id) must be a [type supported by sqlx](https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html). Expects a table to already exist with the following columns: @@ -25,7 +23,7 @@ Expects a table to already exist with the following columns: |------|---------| | id | `text` PRIMARY KEY | | data | `text` NOT NULL (or `jsonb` if using JSON) | -| `` | `` (this should match the [`SessionIdentifier`](crate::SessionIdentifier) | +| `` | `` (the name and type should match the [`SessionIdentifier`] impl) | | expires | `timestamptz` NOT NULL | */ pub struct SqlxPostgresStorage { diff --git a/tests/storages.rs b/tests/storages.rs index 741b123..5e48216 100644 --- a/tests/storages.rs +++ b/tests/storages.rs @@ -32,6 +32,13 @@ impl TryFrom for SessionData { Ok(Self { user_id: value }) } } +impl fred::types::FromValue for SessionData { + fn from_value(value: fred::prelude::Value) -> Result { + Ok(Self { + user_id: value.convert()?, + }) + } +} impl std::fmt::Display for SessionData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.user_id) From 0af7b55cac09207f56a6e7d9a72718558b005c48 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 5 Sep 2025 01:25:28 -0400 Subject: [PATCH 10/28] renaming for consistency --- src/lib.rs | 6 +++--- src/storage/memory.rs | 16 ++++++++-------- src/storage/redis.rs | 5 +++-- tests/session_indexed.rs | 4 ++-- tests/storages_indexed.rs | 4 ++-- 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 558fc56..7a2be4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,7 +130,7 @@ provider that supports indexing, and then group sessions by an identifier (such ```rust use rocket::routes; use rocket_flex_session::{Session, SessionIdentifier, RocketFlexSession}; -use rocket_flex_session::storage::memory::IndexedMemoryStorage; +use rocket_flex_session::storage::memory::MemoryStorageIndexed; #[derive(Clone)] struct UserSession { @@ -169,7 +169,7 @@ fn rocket() -> _ { rocket::build() .attach( RocketFlexSession::::builder() - .storage(IndexedMemoryStorage::default()) + .storage(MemoryStorageIndexed::default()) .build() ) .mount("/", routes![get_all_user_sessions, logout_everywhere]) @@ -185,7 +185,7 @@ This crate supports multiple storage backends with different capabilities: | Storage | Feature Flag | Indexing support | HashMap support | Use Cases | |---------|-------------|------------------|----------|----------| | [`storage::memory::MemoryStorage`] | Built-in | ❌ | ✅ | Development, testing | -| [`storage::memory::IndexedMemoryStorage`] | Built-in | ✅ | ✅ | Development with indexing features | +| [`storage::memory::MemoryStorageIndexed`] | Built-in | ✅ | ✅ | Development with indexing features | | [`storage::cookie::CookieStorage`] | `cookie` | ❌ | ✅ | Client-side storage, stateless servers | | [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | ✅ | Production, distributed systems | | [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅ | ❌ | Production, existing database | diff --git a/src/storage/memory.rs b/src/storage/memory.rs index fc49046..7c9b532 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -24,7 +24,7 @@ use super::interface::{SessionStorage, SessionStorageIndexed}; /// development, and not for production use. It uses the [retainer] crate to /// create an async cache. /// -/// For session indexing support, see [`IndexedMemoryStorage`]. +/// For session indexing support, see [`MemoryStorageIndexed`]. pub struct MemoryStorage { shutdown_tx: Mutex>>, cache: Arc>, @@ -117,7 +117,7 @@ impl MemoryStorage { /// /// # Example /// ```rust -/// use rocket_flex_session::storage::memory::IndexedMemoryStorage; +/// use rocket_flex_session::storage::memory::MemoryStorageIndexed; /// use rocket_flex_session::{SessionIdentifier, RocketFlexSession}; /// /// #[derive(Clone)] @@ -133,12 +133,12 @@ impl MemoryStorage { /// } /// } /// -/// let storage = IndexedMemoryStorage::::default(); +/// let storage = MemoryStorageIndexed::::default(); /// let fairing = RocketFlexSession::builder() /// .storage(storage) /// .build(); /// ``` -pub struct IndexedMemoryStorage +pub struct MemoryStorageIndexed where T: SessionIdentifier, { @@ -147,7 +147,7 @@ where identifier_index: Arc>>>, } -impl Default for IndexedMemoryStorage +impl Default for MemoryStorageIndexed where T: SessionIdentifier, ::Id: ToString, @@ -160,7 +160,7 @@ where } } -impl IndexedMemoryStorage +impl MemoryStorageIndexed where T: SessionIdentifier, T::Id: ToString, @@ -192,7 +192,7 @@ where } #[async_trait] -impl SessionStorage for IndexedMemoryStorage +impl SessionStorage for MemoryStorageIndexed where T: SessionIdentifier + Clone + Send + Sync + 'static, T::Id: ToString, @@ -238,7 +238,7 @@ where } #[async_trait] -impl SessionStorageIndexed for IndexedMemoryStorage +impl SessionStorageIndexed for MemoryStorageIndexed where Self: SessionStorage, T: SessionIdentifier + Clone + Send + Sync, diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 9e1ff2e..75f4208 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -22,8 +22,8 @@ Redis session storage using the [fred.rs](https://docs.rs/fred) crate. You can store the data as a Redis string or hash. Your session data type must implement [`FromValue`](https://docs.rs/fred/latest/fred/types/trait.FromValue.html) from the fred.rs crate, as well as the inverse `From` or `TryFrom` for [`Value`](https://docs.rs/fred/latest/fred/types/enum.Value.html) in order to dictate how the data will be converted to/from the Redis data type. -- For `RedisType::String`, convert to/from `Value::String` -- For `RedisType::Hash`, convert to/from `Value::Map` +- For Redis string types, convert to/from `Value::String` +- For Redis hash types, convert to/from `Value::Map` 💡 Common hashmap types like `HashMap` are automatically supported - make sure to use `RedisType::Hash` when constructing the storage to ensure they are properly converted and stored as Redis hashes. @@ -50,6 +50,7 @@ async fn setup_storage() -> RedisFredStorage { storage } +// If using a custom struct, implement the following... struct MySessionData { user_id: String, } diff --git a/tests/session_indexed.rs b/tests/session_indexed.rs index 727c4b5..ddd12b4 100644 --- a/tests/session_indexed.rs +++ b/tests/session_indexed.rs @@ -3,7 +3,7 @@ use rocket::{ serde::{Deserialize, Serialize}, }; use rocket_flex_session::{ - storage::memory::IndexedMemoryStorage, RocketFlexSession, Session, SessionIdentifier, + storage::memory::MemoryStorageIndexed, RocketFlexSession, Session, SessionIdentifier, }; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -122,7 +122,7 @@ async fn user_profile(session: Session<'_, UserSession>) -> String { #[launch] fn rocket() -> _ { - let user_storage = IndexedMemoryStorage::::default(); + let user_storage = MemoryStorageIndexed::::default(); rocket::build() .attach( diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index d919c82..2c08e3b 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -4,7 +4,7 @@ use std::{future::Future, pin::Pin}; use rocket::local::asynchronous::Client; use rocket_flex_session::{ - storage::{memory::IndexedMemoryStorage, sqlx::SqlxPostgresStorage, SessionStorageIndexed}, + storage::{memory::MemoryStorageIndexed, sqlx::SqlxPostgresStorage, SessionStorageIndexed}, SessionIdentifier, }; use sqlx::Connection; @@ -52,7 +52,7 @@ async fn create_storage( ) { match storage_case { "memory" => { - let storage = IndexedMemoryStorage::::default(); + let storage = MemoryStorageIndexed::::default(); (Box::new(storage), None) } "sqlx" => { From fbf4086952fcbabf2102eeff03940c39c0f3e997 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 5 Sep 2025 02:30:36 -0400 Subject: [PATCH 11/28] storage: clean up expired sessions for sqlx postgres --- src/error.rs | 3 + src/storage/sqlx.rs | 120 +++++++++++++++++++++++++++++++------- tests/common/mod.rs | 10 ++++ tests/storages.rs | 33 +++++------ tests/storages_indexed.rs | 18 ++---- 5 files changed, 131 insertions(+), 53 deletions(-) diff --git a/src/error.rs b/src/error.rs index 3160247..3096d76 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,6 +26,9 @@ pub enum SessionError { /// used when implementing a custom session storage. #[error("Storage backend error: {0}")] Backend(Box), + /// Error occurred while setting up or tearing down the session storage + #[error("Error during storage setup or teardown: {0}")] + SetupTeardown(String), #[cfg(feature = "redis_fred")] #[error("fred.rs client error: {0}")] diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 3f21079..851e93e 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -1,6 +1,14 @@ //! Session storage in PostgreSQL via sqlx -use rocket::{async_trait, http::CookieJar}; +use rocket::{ + async_trait, + http::CookieJar, + tokio::{ + self, + sync::{oneshot, Mutex}, + time::interval, + }, +}; use sqlx::{PgPool, Row}; use time::{Duration, OffsetDateTime}; @@ -23,23 +31,42 @@ Expects a table to already exist with the following columns: |------|---------| | id | `text` PRIMARY KEY | | data | `text` NOT NULL (or `jsonb` if using JSON) | -| `` | `` (the name and type should match the [`SessionIdentifier`] impl) | +| `` | `` (the name and type should match your [`SessionIdentifier`] impl) | | expires | `timestamptz` NOT NULL | */ pub struct SqlxPostgresStorage { pool: PgPool, table_name: String, + cleanup_interval: Option, + shutdown_tx: Mutex>>, } impl SqlxPostgresStorage { - pub fn new(pool: PgPool, table_name: &str) -> SqlxPostgresStorage { + /// Creates a new [`SqlxPostgresStorage`]. + /// + /// Parameters: + /// - `pool`: An initialized Postgres connection pool. + /// - `table_name`: The name of the table to use for storing sessions. + /// - `cleanup_interval`: Interval to check for and clean up expired sessions. If `None`, + /// expired sessions won't be cleaned up automatically. + pub fn new( + pool: PgPool, + table_name: &str, + cleanup_interval: Option, + ) -> SqlxPostgresStorage { Self { pool, table_name: table_name.to_owned(), + cleanup_interval, + shutdown_tx: Mutex::default(), } } } +const ID_COLUMN: &str = "id"; +const DATA_COLUMN: &str = "data"; +const EXPIRES_COLUMN: &str = "expires"; + #[async_trait] impl SessionStorage for SqlxPostgresStorage where @@ -58,10 +85,10 @@ where Some(new_ttl) => { sqlx::query(&format!( r#" - UPDATE "{}" SET expires = $1 - WHERE id = $2 AND expires > CURRENT_TIMESTAMP - RETURNING data, expires"#, - &self.table_name + UPDATE "{}" SET {EXPIRES_COLUMN} = $1 + WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP + RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}"#, + &self.table_name, )) .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) .bind(id) @@ -71,9 +98,9 @@ where None => { sqlx::query(&format!( r#" - SELECT data, expires FROM "{}" - WHERE id = $1 AND expires > CURRENT_TIMESTAMP"#, - &self.table_name + SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM "{}" + WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP"#, + &self.table_name, )) .bind(id) .fetch_optional(&self.pool) @@ -83,8 +110,8 @@ where let (raw_str, expires) = match row { Some(row) => { - let data: String = row.try_get("data")?; - let expires: OffsetDateTime = row.try_get("expires")?; + let data: String = row.try_get(DATA_COLUMN)?; + let expires: OffsetDateTime = row.try_get(EXPIRES_COLUMN)?; (data, expires) } None => return Err(SessionError::NotFound), @@ -98,11 +125,11 @@ where async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { sqlx::query(&format!( r#" - INSERT INTO "{}" (id, {}, data, expires) + INSERT INTO "{}" ({ID_COLUMN}, {}, {DATA_COLUMN}, {EXPIRES_COLUMN}) VALUES ($1, $2, $3, $4) - ON CONFLICT (id) DO UPDATE SET - data = EXCLUDED.data, - expires = EXCLUDED.expires + ON CONFLICT ({ID_COLUMN}) DO UPDATE SET + {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, + {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN} "#, self.table_name, T::NAME @@ -118,13 +145,66 @@ where } async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - sqlx::query(&format!("DELETE FROM {} WHERE id = $1", &self.table_name)) - .bind(id) - .execute(&self.pool) - .await?; + sqlx::query(&format!( + "DELETE FROM {} WHERE {ID_COLUMN} = $1", + &self.table_name + )) + .bind(id) + .execute(&self.pool) + .await?; Ok(()) } + + async fn setup(&self) -> SessionResult<()> { + let Some(cleanup_interval) = self.cleanup_interval else { + return Ok(()); + }; + let (tx, rx) = oneshot::channel(); + let pool = self.pool.clone(); + let table_name = self.table_name.clone(); + tokio::spawn(async move { + rocket::info!("Starting session cleanup monitor"); + let mut interval = interval(cleanup_interval); + tokio::select! { + _ = async { + loop { + interval.tick().await; + rocket::debug!("Cleaning up expired sessions"); + if let Err(e) = cleanup_expired_sessions(&table_name, &pool).await { + rocket::error!("Error deleting expired sessions: {e}"); + } + } + } => (), + _ = rx => { + rocket::info!("Session cleanup monitor shutdown"); + } + } + }); + self.shutdown_tx.lock().await.replace(tx); + + Ok(()) + } + + async fn shutdown(&self) -> SessionResult<()> { + if let Some(tx) = self.shutdown_tx.lock().await.take() { + tx.send(()).map_err(|_| { + SessionError::SetupTeardown("Failed to send shutdown signal".to_string()) + })?; + } + Ok(()) + } +} + +async fn cleanup_expired_sessions(table_name: &str, pool: &PgPool) -> Result { + rocket::debug!("Cleaning up expired sessions"); + let rows = sqlx::query(&format!( + "DELETE FROM {table_name} WHERE {EXPIRES_COLUMN} < $1" + )) + .bind(OffsetDateTime::now_utc()) + .execute(pool) + .await?; + Ok(rows.rows_affected()) } #[async_trait] diff --git a/tests/common/mod.rs b/tests/common/mod.rs index b3895ab..9a14e54 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -33,3 +33,13 @@ pub async fn setup_postgres(base_url: &str) -> (PgPool, String) { (pool, db_name) } + +pub async fn teardown_postgres(pool: sqlx::Pool, db_name: String) { + pool.close().await; + drop(pool); + let mut cxn = sqlx::PgConnection::connect(POSTGRES_URL).await.unwrap(); + sqlx::query(&format!("DROP DATABASE {} WITH (FORCE)", db_name)) + .execute(&mut cxn) + .await + .expect("Should drop test database"); +} diff --git a/tests/storages.rs b/tests/storages.rs index 5e48216..5b24cf6 100644 --- a/tests/storages.rs +++ b/tests/storages.rs @@ -6,7 +6,10 @@ extern crate rocket; use std::{future::Future, pin::Pin}; use fred::prelude::{ClientLike, ReconnectPolicy}; -use rocket::{http::Status, local::asynchronous::Client, tokio::time::sleep, Build, Rocket}; +use rocket::{ + futures::FutureExt, http::Status, local::asynchronous::Client, tokio::time::sleep, Build, + Rocket, +}; use rocket_flex_session::{ error::SessionError, storage::{ @@ -17,10 +20,9 @@ use rocket_flex_session::{ RocketFlexSession, Session, SessionIdentifier, }; use serde::{Deserialize, Serialize}; -use sqlx::Connection; use test_case::test_case; -use crate::common::{setup_postgres, POSTGRES_URL}; +use crate::common::{setup_postgres, teardown_postgres, POSTGRES_URL}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] struct SessionData { @@ -93,7 +95,10 @@ fn expire_session(mut session: Session) { async fn create_rocket( storage_case: &str, -) -> (Rocket, Option>>>) { +) -> ( + Rocket, + Option + Send>>>, +) { let (fairing, cleanup_task) = match storage_case { "cookie" => ( RocketFlexSession::::builder() @@ -114,29 +119,19 @@ async fn create_rocket( let fairing = RocketFlexSession::::builder() .storage(storage) .build(); - - let cleanup_task: Pin>> = Box::pin(async move { + let cleanup_task = async move { pool.quit().await.ok(); - drop(pool); - }); + } + .boxed(); (fairing, Some(cleanup_task)) } "sqlx" => { let (pool, db_name) = setup_postgres(POSTGRES_URL).await; - let storage = SqlxPostgresStorage::new(pool.clone(), "sessions"); + let storage = SqlxPostgresStorage::new(pool.clone(), "sessions", None); let fairing = RocketFlexSession::::builder() .storage(storage) .build(); - - let cleanup_task: Pin>> = Box::pin(async move { - pool.close().await; - drop(pool); - let mut cxn = sqlx::PgConnection::connect(POSTGRES_URL).await.unwrap(); - sqlx::query(&format!("DROP DATABASE {} WITH (FORCE)", db_name)) - .execute(&mut cxn) - .await - .expect("Should drop test database"); - }); + let cleanup_task = teardown_postgres(pool, db_name).boxed(); (fairing, Some(cleanup_task)) } _ => unimplemented!(), diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index 2c08e3b..b6bbe0c 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -2,15 +2,14 @@ mod common; use std::{future::Future, pin::Pin}; -use rocket::local::asynchronous::Client; +use rocket::{futures::FutureExt, local::asynchronous::Client}; use rocket_flex_session::{ storage::{memory::MemoryStorageIndexed, sqlx::SqlxPostgresStorage, SessionStorageIndexed}, SessionIdentifier, }; -use sqlx::Connection; use test_case::test_case; -use crate::common::{setup_postgres, POSTGRES_URL}; +use crate::common::{setup_postgres, teardown_postgres, POSTGRES_URL}; #[derive(Clone, Debug, PartialEq)] struct TestSession { @@ -57,17 +56,8 @@ async fn create_storage( } "sqlx" => { let (pool, db_name) = setup_postgres(POSTGRES_URL).await; - let storage = SqlxPostgresStorage::new(pool.clone(), "sessions"); - - let cleanup_task: Pin>> = Box::pin(async move { - pool.close().await; - drop(pool); - let mut cxn = sqlx::PgConnection::connect(POSTGRES_URL).await.unwrap(); - sqlx::query(&format!("DROP DATABASE {} WITH (FORCE)", db_name)) - .execute(&mut cxn) - .await - .expect("Should drop test database"); - }); + let storage = SqlxPostgresStorage::new(pool.clone(), "sessions", None); + let cleanup_task = teardown_postgres(pool, db_name).boxed(); (Box::new(storage), Some(cleanup_task)) } _ => unimplemented!(), From abad2f55b36d6ceb77713a55ebebbcfffb1ada7d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Fri, 5 Sep 2025 03:12:49 -0400 Subject: [PATCH 12/28] storage: return number of invalidated sessions --- src/fairing.rs | 2 +- src/session.rs | 2 +- src/session_index.rs | 12 ++++++------ src/storage/interface.rs | 6 +++--- src/storage/memory.rs | 8 ++++---- src/storage/sqlx.rs | 6 +++--- tests/session_indexed.rs | 27 +++++++++++++-------------- tests/storages_indexed.rs | 38 ++++++++++++++++++++++---------------- 8 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/fairing.rs b/src/fairing.rs index 7ac5aff..7da5851 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -33,7 +33,7 @@ struct MySession { #[rocket::launch] fn rocket() -> _ { - // Use default settings + // Use default settings with in-memory storage let session_fairing = RocketFlexSession::::default(); // Or customize settings with the builder diff --git a/src/session.rs b/src/session.rs index 1edf51b..8d9e9f6 100644 --- a/src/session.rs +++ b/src/session.rs @@ -18,7 +18,7 @@ use crate::{ /** Represents the current session state. When used as a request guard, it will attempt to retrieve the session. The request guard will always succeed - if a -valid session wasn't found, `session.get()` will return `None` indicating an +valid session wasn't found, the data functions will return `None` indicating an inactive session. # Type Parameters diff --git a/src/session_index.rs b/src/session_index.rs index 91c0afc..751851c 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -68,18 +68,18 @@ where Ok(Some(session_ids)) } - /// Invalidate all sessions with the same identifier as the current session. + /// Invalidate all sessions with the same identifier as the current session, returning the number of sessions invalidated. /// Returns `None` if there's no session or the session isn't indexed. - pub async fn invalidate_all_sessions(&self) -> Result, SessionError> { + pub async fn invalidate_all_sessions(&self) -> Result, SessionError> { let Some(identifier) = self.get_identifier() else { return Ok(None); }; let storage = self.get_indexed_storage()?; - storage + let num_sessions = storage .invalidate_sessions_by_identifier(&identifier) .await?; - Ok(Some(())) + Ok(Some(num_sessions)) } /// Get all session IDs and data for a specific identifier. @@ -100,11 +100,11 @@ where storage.get_session_ids_by_identifier(identifier).await } - /// Invalidate all sessions for a specific identifier. + /// Invalidate all sessions for a specific identifier, returning the number of sessions invalidated. pub async fn invalidate_sessions_by_identifier( &self, identifier: &T::Id, - ) -> Result<(), SessionError> { + ) -> Result { let storage = self.get_indexed_storage()?; storage.invalidate_sessions_by_identifier(identifier).await } diff --git a/src/storage/interface.rs b/src/storage/interface.rs index 2b43557..04a3d97 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -68,12 +68,12 @@ pub trait SessionStorageIndexed: SessionStorage where T: SessionIdentifier + Send + Sync, { - /// Retrieve all session data for the given identifier. + /// Retrieve all session IDs and data for the given identifier. async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult>; /// Get all session IDs associated with the given identifier. async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult>; - /// Remove all sessions associated with the given identifier. - async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult<()>; + /// Remove all sessions associated with the given identifier. Returns the number of sessions removed. + async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult; } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 7c9b532..d93dc9d 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -270,7 +270,7 @@ where Ok(session_ids.into_iter().collect()) } - async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult<()> { + async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult { let id_str = id.to_string(); let session_ids = { let mut index = self.identifier_index.lock().unwrap(); @@ -278,10 +278,10 @@ where }; // Remove all sessions from cache - for session_id in session_ids { - self.base_storage.cache().remove(&session_id).await; + for session_id in &session_ids { + self.base_storage.cache().remove(session_id).await; } - Ok(()) + Ok(session_ids.len() as u64) } } diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 851e93e..dffd67d 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -258,8 +258,8 @@ where Ok(parsed_rows) } - async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult<()> { - let _rows = sqlx::query(&format!( + async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult { + let rows = sqlx::query(&format!( r#" DELETE FROM "{}" WHERE {} = $1"#, @@ -270,6 +270,6 @@ where .execute(&self.pool) .await?; - Ok(()) + Ok(rows.rows_affected()) } } diff --git a/tests/session_indexed.rs b/tests/session_indexed.rs index ddd12b4..f312d5e 100644 --- a/tests/session_indexed.rs +++ b/tests/session_indexed.rs @@ -1,6 +1,7 @@ use rocket::{ - get, launch, routes, + get, routes, serde::{Deserialize, Serialize}, + Build, Rocket, }; use rocket_flex_session::{ storage::memory::MemoryStorageIndexed, RocketFlexSession, Session, SessionIdentifier, @@ -62,7 +63,7 @@ async fn get_user_sessions(session: Session<'_, UserSession>) -> String { format!("Found {} sessions for current user", sessions.len()) } Ok(None) => "No current session".to_string(), - Err(e) => format!("Error getting sessions: {}", e), + Err(e) => format!("Error getting sessions: {e}"), } } @@ -70,18 +71,18 @@ async fn get_user_sessions(session: Session<'_, UserSession>) -> String { async fn get_sessions_for_user(session: Session<'_, UserSession>, user_id: String) -> String { match session.get_sessions_by_identifier(&user_id).await { Ok(sessions) => { - format!("Sessions for user {}: {:?}", user_id, sessions) + format!("Sessions for user {user_id}: {:?}", sessions) } - Err(e) => format!("Error getting sessions: {}", e), + Err(e) => format!("Error getting sessions: {e}"), } } #[get("/user/invalidate-all")] async fn invalidate_all_user_sessions(session: Session<'_, UserSession>) -> String { match session.invalidate_all_sessions().await { - Ok(Some(())) => "All sessions for current user invalidated".to_string(), + Ok(Some(n)) => format!("{n} session(s) for current user invalidated."), Ok(None) => "No current session".to_string(), - Err(e) => format!("Error invalidating sessions: {}", e), + Err(e) => format!("Error invalidating sessions: {e}"), } } @@ -91,8 +92,8 @@ async fn invalidate_sessions_for_user( user_id: String, ) -> String { match session.invalidate_sessions_by_identifier(&user_id).await { - Ok(()) => format!("All sessions for user {} invalidated", user_id), - Err(e) => format!("Error invalidating sessions: {}", e), + Ok(n) => format!("{n} session(s) for user {user_id} invalidated"), + Err(e) => format!("Error invalidating sessions: {e}"), } } @@ -103,7 +104,7 @@ async fn get_user_session_ids(session: Session<'_, UserSession>) -> String { format!("Session IDs for current user: {:?}", session_ids) } Ok(None) => "No current session".to_string(), - Err(e) => format!("Error getting session IDs: {}", e), + Err(e) => format!("Error getting session IDs: {e}"), } } @@ -120,8 +121,7 @@ async fn user_profile(session: Session<'_, UserSession>) -> String { } } -#[launch] -fn rocket() -> _ { +fn rocket() -> Rocket { let user_storage = MemoryStorageIndexed::::default(); rocket::build() @@ -199,7 +199,6 @@ mod tests { let response = client.get("/user/sessions/user1").dispatch(); assert_eq!(response.status(), Status::Ok); let body = response.into_string().unwrap(); - println!("{body}"); assert!(body.contains("Sessions for user user1")); } @@ -232,7 +231,7 @@ mod tests { assert!(response .into_string() .unwrap() - .contains("All sessions for current user invalidated")); + .contains("1 session(s) for current user invalidated")); // Profile should now show no session let response = client.get("/user/profile").dispatch(); @@ -254,7 +253,7 @@ mod tests { assert!(response .into_string() .unwrap() - .contains("All sessions for user user2 invalidated")); + .contains("1 session(s) for user user2 invalidated")); } #[test] diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index b6bbe0c..2fa1c4d 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -64,8 +64,8 @@ async fn create_storage( } } -#[test_case("memory")] -#[test_case("sqlx")] +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] #[rocket::async_test] async fn basic_operations(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; @@ -126,8 +126,8 @@ async fn basic_operations(storage_case: &str) { } } -#[test_case("memory")] -#[test_case("sqlx")] +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] #[rocket::async_test] async fn invalidate_by_identifier(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; @@ -162,10 +162,13 @@ async fn invalidate_by_identifier(storage_case: &str) { ); // Invalidate all sessions for user1 - storage - .invalidate_sessions_by_identifier(&"user1".to_string()) - .await - .unwrap(); + assert_eq!( + storage + .invalidate_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(), + 2 + ); // Verify user1 sessions are gone assert_eq!( @@ -193,8 +196,8 @@ async fn invalidate_by_identifier(storage_case: &str) { } } -#[test_case("memory")] -#[test_case("sqlx")] +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] #[rocket::async_test] async fn delete_single_session(storage_case: &str) { let client = Client::tracked(rocket::build()).await.unwrap(); @@ -243,8 +246,8 @@ async fn delete_single_session(storage_case: &str) { } } -#[test_case("memory")] -#[test_case("sqlx")] +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] #[rocket::async_test] async fn nonexistent_identifier(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; @@ -265,10 +268,13 @@ async fn nonexistent_identifier(storage_case: &str) { assert_eq!(session_ids.len(), 0); // Try to invalidate sessions for non-existent identifier (should not error) - storage - .invalidate_sessions_by_identifier(&"nonexistent".to_string()) - .await - .unwrap(); + assert_eq!( + storage + .invalidate_sessions_by_identifier(&"nonexistent".to_string()) + .await + .unwrap(), + 0 + ); storage.shutdown().await.unwrap(); if let Some(task) = cleanup_task { From a35de58a7611034116d5171d50e41f99cda7511a Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 6 Sep 2025 03:21:54 -0400 Subject: [PATCH 13/28] tests for indexed storage operations --- src/guard.rs | 19 ++--- src/session.rs | 17 +++-- src/session_index.rs | 28 ++++---- src/storage/interface.rs | 9 ++- src/storage/memory.rs | 32 +++++++-- src/storage/sqlx.rs | 141 ++++++++++++++++++-------------------- tests/session_indexed.rs | 48 ++++++++++++- tests/storages_indexed.rs | 66 ++++++++++++++++-- 8 files changed, 246 insertions(+), 114 deletions(-) diff --git a/src/guard.rs b/src/guard.rs index 5a7fe51..54ba437 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,7 +1,7 @@ use std::{any::type_name, sync::Mutex}; use rocket::{ - http::{Cookie, CookieJar}, + http::CookieJar, request::{FromRequest, Outcome}, Request, }; @@ -26,17 +26,17 @@ where let fairing = get_fairing::(req.rocket()); let cookie_jar = req.cookies(); + // Use rocket's local cache so that the session data is only fetched once per request let (cached_inner, session_error): &LocalCachedSession = req .local_cache_async(async { - let session_cookie = cookie_jar.get_private(&fairing.options.cookie_name); - get_session_data( - session_cookie, + fetch_session_data( + cookie_jar, + &fairing.options.cookie_name, fairing .options .rolling .then(|| fairing.options.ttl.unwrap_or(fairing.options.max_age)), fairing.storage.as_ref(), - cookie_jar, ) .await }) @@ -66,14 +66,15 @@ where }) } -/// Get session data from storage +/// Fetch session data from storage #[inline(always)] -async fn get_session_data<'r, T: Send + Sync + Clone>( - session_cookie: Option>, +async fn fetch_session_data<'r, T: Send + Sync + Clone>( + cookie_jar: &'r CookieJar<'_>, + cookie_name: &str, rolling_ttl: Option, storage: &'r dyn SessionStorage, - cookie_jar: &'r CookieJar<'_>, ) -> LocalCachedSession { + let session_cookie = cookie_jar.get_private(cookie_name); if let Some(cookie) = session_cookie { let id = cookie.value(); rocket::debug!("Got session id '{}' from cookie. Retrieving session...", id); diff --git a/src/session.rs b/src/session.rs index 8d9e9f6..312632c 100644 --- a/src/session.rs +++ b/src/session.rs @@ -251,9 +251,7 @@ where /// Set the value of a key in the session data. Will create a new session if there isn't one. pub fn set_key(&mut self, key: K, value: V) { self.get_inner_lock().tap_data_mut( - |data| { - data.get_or_insert_default().insert(key, value); - }, + |data| data.get_or_insert_default().insert(key, value), self.get_default_ttl(), ); self.update_cookies(); @@ -265,9 +263,16 @@ where I: IntoIterator, { self.get_inner_lock().tap_data_mut( - |data| { - data.get_or_insert_default().extend(kv_iter); - }, + |data| data.get_or_insert_default().extend(kv_iter), + self.get_default_ttl(), + ); + self.update_cookies(); + } + + /// Remove a key from the session data. + pub fn remove_key(&mut self, key: K) { + self.get_inner_lock().tap_data_mut( + |data| data.get_or_insert_default().remove(&key), self.get_default_ttl(), ); self.update_cookies(); diff --git a/src/session_index.rs b/src/session_index.rs index 751851c..2ca9302 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -68,15 +68,21 @@ where Ok(Some(session_ids)) } - /// Invalidate all sessions with the same identifier as the current session, returning the number of sessions invalidated. - /// Returns `None` if there's no session or the session isn't indexed. - pub async fn invalidate_all_sessions(&self) -> Result, SessionError> { - let Some(identifier) = self.get_identifier() else { + /// Invalidate all sessions with the same identifier as the current session, optionally keeping the current session active. + /// Returns the number of sessions invalidated, or `None` if there's no session or the session isn't indexed. + pub async fn invalidate_all_sessions( + &self, + keep_current: bool, + ) -> Result, SessionError> { + let Some((session_id, identifier)) = self.id().zip(self.get_identifier()) else { return Ok(None); }; let storage = self.get_indexed_storage()?; let num_sessions = storage - .invalidate_sessions_by_identifier(&identifier) + .invalidate_sessions_by_identifier( + &identifier, + keep_current.then_some(session_id.as_str()), + ) .await?; Ok(Some(num_sessions)) @@ -106,16 +112,14 @@ where identifier: &T::Id, ) -> Result { let storage = self.get_indexed_storage()?; - storage.invalidate_sessions_by_identifier(identifier).await + storage + .invalidate_sessions_by_identifier(identifier, None) + .await } - /// Get the current session's identifier + /// Get the current session's identifier, if there is one. fn get_identifier(&self) -> Option { - let identifier = { - let inner = self.get_inner_lock(); - inner.get_current_identifier().cloned() - }; - identifier + self.get_inner_lock().get_current_identifier().cloned() } /// Try to cast the storage as an indexed storage diff --git a/src/storage/interface.rs b/src/storage/interface.rs index 04a3d97..e5ec677 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -74,6 +74,11 @@ where /// Get all session IDs associated with the given identifier. async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult>; - /// Remove all sessions associated with the given identifier. Returns the number of sessions removed. - async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult; + /// Remove all sessions associated with the given identifier, optionally excluding one session ID. + /// Returns the number of sessions removed. + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult; } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index d93dc9d..17868e4 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -270,18 +270,38 @@ where Ok(session_ids.into_iter().collect()) } - async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult { + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { let id_str = id.to_string(); - let session_ids = { - let mut index = self.identifier_index.lock().unwrap(); - index.remove(&id_str).unwrap_or_default() + let mut session_ids_to_remove = { + let index = self.identifier_index.lock().unwrap(); + index.get(&id_str).cloned().unwrap_or_default() }; + if let Some(session_id) = excluded_session_id { + session_ids_to_remove.retain(|id| id != session_id); + } // Remove all sessions from cache - for session_id in &session_ids { + for session_id in &session_ids_to_remove { self.base_storage.cache().remove(session_id).await; } - Ok(session_ids.len() as u64) + // Remove all sessions from index + { + let mut index = self.identifier_index.lock().unwrap(); + if let Some(session_set) = index.get_mut(&id_str) { + for session_id in &session_ids_to_remove { + session_set.remove(session_id); + } + if session_set.is_empty() { + index.remove(&id_str); + } + } + } + + Ok(session_ids_to_remove.len() as u64) } } diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index dffd67d..2c2fe62 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -83,28 +83,28 @@ where ) -> SessionResult<(T, u32)> { let row = match ttl { Some(new_ttl) => { - sqlx::query(&format!( - r#" - UPDATE "{}" SET {EXPIRES_COLUMN} = $1 - WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP - RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}"#, + let sql = format!( + "UPDATE \"{}\" SET {EXPIRES_COLUMN} = $1 \ + WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP \ + RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}", &self.table_name, - )) - .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) - .bind(id) - .fetch_optional(&self.pool) - .await? + ); + sqlx::query(&sql) + .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) + .bind(id) + .fetch_optional(&self.pool) + .await? } None => { - sqlx::query(&format!( - r#" - SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM "{}" - WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP"#, + let sql = format!( + "SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ + WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", &self.table_name, - )) - .bind(id) - .fetch_optional(&self.pool) - .await? + ); + sqlx::query(&sql) + .bind(id) + .fetch_optional(&self.pool) + .await? } }; @@ -123,35 +123,29 @@ where } async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - sqlx::query(&format!( - r#" - INSERT INTO "{}" ({ID_COLUMN}, {}, {DATA_COLUMN}, {EXPIRES_COLUMN}) - VALUES ($1, $2, $3, $4) - ON CONFLICT ({ID_COLUMN}) DO UPDATE SET - {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, - {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN} - "#, + let sql = format!( + "INSERT INTO \"{}\" ({ID_COLUMN}, {}, {DATA_COLUMN}, {EXPIRES_COLUMN}) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT ({ID_COLUMN}) DO UPDATE SET \ + {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \ + {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}", self.table_name, T::NAME - )) - .bind(id) - .bind(data.identifier()) - .bind(data.to_string()) - .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) - .execute(&self.pool) - .await?; + ); + sqlx::query(&sql) + .bind(id) + .bind(data.identifier()) + .bind(data.to_string()) + .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) + .execute(&self.pool) + .await?; Ok(()) } async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - sqlx::query(&format!( - "DELETE FROM {} WHERE {ID_COLUMN} = $1", - &self.table_name - )) - .bind(id) - .execute(&self.pool) - .await?; + let sql = format!("DELETE FROM {} WHERE {ID_COLUMN} = $1", &self.table_name); + sqlx::query(&sql).bind(id).execute(&self.pool).await?; Ok(()) } @@ -198,12 +192,11 @@ where async fn cleanup_expired_sessions(table_name: &str, pool: &PgPool) -> Result { rocket::debug!("Cleaning up expired sessions"); - let rows = sqlx::query(&format!( - "DELETE FROM {table_name} WHERE {EXPIRES_COLUMN} < $1" - )) - .bind(OffsetDateTime::now_utc()) - .execute(pool) - .await?; + let sql = format!("DELETE FROM \"{table_name}\" WHERE {EXPIRES_COLUMN} < $1"); + let rows = sqlx::query(&sql) + .bind(OffsetDateTime::now_utc()) + .execute(pool) + .await?; Ok(rows.rows_affected()) } @@ -216,17 +209,13 @@ where >::Error: std::error::Error + Send + Sync + 'static, { async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { - let rows = sqlx::query(&format!( - r#" - SELECT id, data FROM "{}" - WHERE {} = $1 AND expires > CURRENT_TIMESTAMP"#, + let sql = format!( + "SELECT {ID_COLUMN}, {DATA_COLUMN} FROM \"{}\" \ + WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", &self.table_name, T::NAME - )) - .bind(id) - .fetch_all(&self.pool) - .await?; - + ); + let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; let parsed_rows = rows .into_iter() .filter_map(|row| { @@ -236,39 +225,45 @@ where Some((id, data)) }) .collect(); + Ok(parsed_rows) } async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { - let rows = sqlx::query(&format!( - r#" - SELECT id FROM "{}" - WHERE {} = $1 AND expires > CURRENT_TIMESTAMP"#, + let sql = format!( + "SELECT {ID_COLUMN} FROM \"{}\" \ + WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", &self.table_name, T::NAME - )) - .bind(id) - .fetch_all(&self.pool) - .await?; - + ); + let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; let parsed_rows = rows .into_iter() .filter_map(|row| row.try_get(0).ok()) .collect(); + Ok(parsed_rows) } - async fn invalidate_sessions_by_identifier(&self, id: &T::Id) -> SessionResult { - let rows = sqlx::query(&format!( - r#" - DELETE FROM "{}" - WHERE {} = $1"#, + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let mut sql = format!( + "DELETE FROM \"{}\" WHERE {} = $1", &self.table_name, T::NAME - )) - .bind(id) - .execute(&self.pool) - .await?; + ); + if excluded_session_id.is_some() { + sql.push_str(&format!(" AND {ID_COLUMN} != $2")); + } + + let mut query = sqlx::query(&sql).bind(id); + if let Some(excluded_id) = excluded_session_id { + query = query.bind(excluded_id); + } + let rows = query.execute(&self.pool).await?; Ok(rows.rows_affected()) } diff --git a/tests/session_indexed.rs b/tests/session_indexed.rs index f312d5e..9ec05c7 100644 --- a/tests/session_indexed.rs +++ b/tests/session_indexed.rs @@ -79,7 +79,16 @@ async fn get_sessions_for_user(session: Session<'_, UserSession>, user_id: Strin #[get("/user/invalidate-all")] async fn invalidate_all_user_sessions(session: Session<'_, UserSession>) -> String { - match session.invalidate_all_sessions().await { + match session.invalidate_all_sessions(false).await { + Ok(Some(n)) => format!("{n} session(s) for current user invalidated."), + Ok(None) => "No current session".to_string(), + Err(e) => format!("Error invalidating sessions: {e}"), + } +} + +#[get("/user/invalidate-other")] +async fn invalidate_other_user_sessions(session: Session<'_, UserSession>) -> String { + match session.invalidate_all_sessions(true).await { Ok(Some(n)) => format!("{n} session(s) for current user invalidated."), Ok(None) => "No current session".to_string(), Err(e) => format!("Error invalidating sessions: {e}"), @@ -137,6 +146,7 @@ fn rocket() -> Rocket { get_user_sessions, get_sessions_for_user, invalidate_all_user_sessions, + invalidate_other_user_sessions, invalidate_sessions_for_user, get_user_session_ids, user_profile, @@ -239,6 +249,42 @@ mod tests { assert_eq!(response.into_string().unwrap(), "No active session"); } + #[test] + fn test_invalidate_other_sessions() { + let client = create_test_client(); + + let response = client.get("/user/login/user1/alice").dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Create two more sessions for the same user, simulating a different device + let response = client + .get("/user/login/user1/alice") + .private_cookie("rocket") // empty cookie + .dispatch(); + assert_eq!(response.status(), Status::Ok); + let response = client + .get("/user/login/user1/alice") + .private_cookie("rocket") + .dispatch(); + assert_eq!(response.status(), Status::Ok); + + // Invalidate all sessions except the current one + let response = client.get("/user/invalidate-other").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!( + response.into_string().unwrap(), + "2 session(s) for current user invalidated." + ); + + // Profile should still show active session + let response = client.get("/user/profile").dispatch(); + assert_eq!(response.status(), Status::Ok); + assert!(response + .into_string() + .unwrap() + .contains("Profile for alice")); + } + #[test] fn test_invalidate_sessions_by_user_id() { let client = create_test_client(); diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index 2fa1c4d..bec17b4 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -164,7 +164,7 @@ async fn invalidate_by_identifier(storage_case: &str) { // Invalidate all sessions for user1 assert_eq!( storage - .invalidate_sessions_by_identifier(&"user1".to_string()) + .invalidate_sessions_by_identifier(&"user1".to_string(), None) .await .unwrap(), 2 @@ -186,9 +186,65 @@ async fn invalidate_by_identifier(storage_case: &str) { .await .unwrap(); assert_eq!(user2_sessions.len(), 1); - assert!(user2_sessions - .iter() - .any(|(id, data)| id == "sid3" && data == &session3)); + assert_eq!(user2_sessions[0], ("sid3".to_string(), session3)); + + storage.shutdown().await.unwrap(); + if let Some(task) = cleanup_task { + task.await + } +} + +#[test_case("memory"; "Memory")] +#[test_case("sqlx"; "Sqlx Postgres")] +#[rocket::async_test] +async fn invalidate_all_but_one_by_identifier(storage_case: &str) { + let (storage, cleanup_task) = create_storage(storage_case).await; + storage.setup().await.unwrap(); + + let session1 = TestSession { + user_id: "user1".to_string(), + data: "session1_data".to_string(), + }; + let session2 = TestSession { + user_id: "user1".to_string(), + data: "session2_data".to_string(), + }; + let session3 = TestSession { + user_id: "user1".to_string(), + data: "session3_data".to_string(), + }; + + // Save sessions + storage.save("sid1", session1, 3600).await.unwrap(); + storage.save("sid2", session2, 3600).await.unwrap(); + storage.save("sid3", session3.clone(), 3600).await.unwrap(); + + // Verify sessions exist + assert_eq!( + storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap() + .len(), + 3 + ); + + // Invalidate all sessions for user1 except the last one + assert_eq!( + storage + .invalidate_sessions_by_identifier(&"user1".to_string(), Some("sid3")) + .await + .unwrap(), + 2 + ); + + // Verify the last user1 session still exists + let user1_sessions = storage + .get_sessions_by_identifier(&"user1".to_string()) + .await + .unwrap(); + assert_eq!(user1_sessions.len(), 1); + assert_eq!(user1_sessions[0], ("sid3".to_string(), session3)); storage.shutdown().await.unwrap(); if let Some(task) = cleanup_task { @@ -270,7 +326,7 @@ async fn nonexistent_identifier(storage_case: &str) { // Try to invalidate sessions for non-existent identifier (should not error) assert_eq!( storage - .invalidate_sessions_by_identifier(&"nonexistent".to_string()) + .invalidate_sessions_by_identifier(&"nonexistent".to_string(), None) .await .unwrap(), 0 From b5dd901bf39841f309339d9f162839cdbef5ae79 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sat, 6 Sep 2025 03:50:33 -0400 Subject: [PATCH 14/28] fix docs --- src/lib.rs | 4 ++-- src/storage/redis.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7a2be4f..5784e66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -157,8 +157,8 @@ async fn get_all_user_sessions(session: Session<'_, UserSession>) -> String { #[rocket::get("/user/logout-everywhere")] async fn logout_everywhere(session: Session<'_, UserSession>) -> String { - match session.invalidate_all_sessions().await { - Ok(Some(())) => "Logged out from all devices".to_string(), + match session.invalidate_all_sessions(false).await { + Ok(Some(n)) => format!("Logged out from {n} sessions"), Ok(None) => "No active session".to_string(), Err(e) => format!("Error: {}", e), } diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 75f4208..ae958e3 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -50,11 +50,10 @@ async fn setup_storage() -> RedisFredStorage { storage } -// If using a custom struct, implement the following... +// If using a custom struct for your session data, implement the following... struct MySessionData { user_id: String, } - // Implement `FromValue` to convert from the Redis value to your session data type impl FromValue for MySessionData { fn from_value(value: Value) -> Result { From 6a7bc957e86a2af14b400934eaf750380806d900 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 01:08:43 -0400 Subject: [PATCH 15/28] add fred.rs with indexing support --- Cargo.toml | 1 + src/session_index.rs | 2 +- src/storage/interface.rs | 8 +- src/storage/memory.rs | 11 +- src/storage/redis.rs | 217 ++++++++++++++++++++--- src/storage/sqlx.rs | 25 ++- tests/common/mod.rs | 37 +++- tests/session_indexed.rs | 4 +- tests/{storages.rs => storages_basic.rs} | 38 ++-- tests/storages_indexed.rs | 48 ++++- 10 files changed, 313 insertions(+), 78 deletions(-) rename tests/{storages.rs => storages_basic.rs} (85%) diff --git a/Cargo.toml b/Cargo.toml index e0bdd17..7184002 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ rustdoc-args = ["--cfg", "docsrs"] fred = { version = "10.1", optional = true, default-features = false, features = [ "i-keys", "i-hashes", + "i-sets", ] } rand = "0.8" retainer = "0.3" diff --git a/src/session_index.rs b/src/session_index.rs index 2ca9302..9242d56 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -28,7 +28,7 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// ``` pub trait SessionIdentifier { /// The name of the identifier (default: `"user_id"`), that may be used as a field/key name by the storage backend. - const NAME: &str = "user_id"; + const IDENTIFIER: &str = "user_id"; /// The type of the identifier type Id: Send + Sync + Clone; diff --git a/src/storage/interface.rs b/src/storage/interface.rs index e5ec677..f85305d 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -68,14 +68,14 @@ pub trait SessionStorageIndexed: SessionStorage where T: SessionIdentifier + Send + Sync, { - /// Retrieve all session IDs and data for the given identifier. + /// Retrieve all tracked session IDs and data for the given identifier. async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult>; - /// Get all session IDs associated with the given identifier. + /// Get all tracked session IDs associated with the given identifier. async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult>; - /// Remove all sessions associated with the given identifier, optionally excluding one session ID. - /// Returns the number of sessions removed. + /// Invalidate all tracked sessions associated with the given identifier, optionally excluding one session ID. + /// Returns the number of sessions invalidated. async fn invalidate_sessions_by_identifier( &self, id: &T::Id, diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 17868e4..74e18d2 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -101,13 +101,6 @@ where } } -impl MemoryStorage { - /// Get access to the underlying cache for indexed operations - pub(crate) fn cache(&self) -> &Cache { - &self.cache - } -} - /// Extended in-memory storage that supports session indexing by identifier. /// This allows for operations like retrieving all sessions for a user or /// bulk invalidation of sessions. @@ -252,7 +245,7 @@ where let mut sessions: Vec<(String, T)> = Vec::new(); for session_id in session_ids { - if let Some(data) = self.base_storage.cache().get(&session_id).await { + if let Some(data) = self.base_storage.cache.get(&session_id).await { sessions.push((session_id, data.value().to_owned())); } } @@ -286,7 +279,7 @@ where // Remove all sessions from cache for session_id in &session_ids_to_remove { - self.base_storage.cache().remove(session_id).await; + self.base_storage.cache.remove(session_id).await; } // Remove all sessions from index diff --git a/src/storage/redis.rs b/src/storage/redis.rs index ae958e3..a541013 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -1,15 +1,20 @@ //! Session storage with Redis (and Redis-compatible databases) use fred::{ - prelude::{FromValue, HashesInterface, KeysInterface, Pool, Value}, + prelude::{FromValue, HashesInterface, KeysInterface, Pool, SetsInterface, Value}, types::Expiration, }; use rocket::{async_trait, http::CookieJar}; -use crate::error::{SessionError, SessionResult}; +use crate::{ + error::{SessionError, SessionResult}, + storage::SessionStorageIndexed, + SessionIdentifier, +}; use super::interface::SessionStorage; +/// The Redis type to use for the session data #[derive(Debug)] pub enum RedisType { String, @@ -76,6 +81,7 @@ pub struct RedisFredStorage { prefix: String, redis_type: RedisType, } + impl RedisFredStorage { pub fn new(pool: Pool, redis_type: RedisType, key_prefix: &str) -> Self { Self { @@ -88,20 +94,16 @@ impl RedisFredStorage { fn key(&self, id: &str) -> String { format!("{}{id}", self.prefix) } -} -#[async_trait] -impl SessionStorage for RedisFredStorage -where - T: FromValue + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, -{ - async fn load( - &self, - id: &str, - ttl: Option, - _cookie_jar: &CookieJar, - ) -> SessionResult<(T, u32)> { + fn session_index_key(&self, identifier_name: &str, identifier: &impl ToString) -> String { + format!( + "{}{identifier_name}:{}", + self.prefix, + identifier.to_string() + ) + } + + async fn fetch_session(&self, id: &str, ttl: Option) -> SessionResult<(Value, u32)> { let key = self.key(id); let pipeline = self.pool.next().pipeline(); let _: () = match self.redis_type { @@ -121,16 +123,11 @@ where }; let found_value = value.ok_or(SessionError::NotFound)?; - let data = T::from_value(found_value)?; - - Ok((data, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) + Ok((found_value, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) } - async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + async fn save_session(&self, id: &str, value: Value, ttl: u32) -> SessionResult<()> { let key = self.key(id); - let value: Value = data - .try_into() - .map_err(|e| SessionError::Serialization(Box::new(e)))?; let _: () = match self.redis_type { RedisType::String => { self.pool @@ -152,9 +149,185 @@ where }; Ok(()) } +} + +#[async_trait] +impl SessionStorage for RedisFredStorage +where + T: FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let (value, ttl) = self.fetch_session(id, ttl).await?; + let data = T::from_value(value)?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + let value: Value = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + self.save_session(id, value, ttl).await?; + Ok(()) + } async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { let _: u8 = self.pool.del(self.key(id)).await?; Ok(()) } } + +/// Redis session storage using the [fred.rs](https://docs.rs/fred) crate. This is a wrapper around +/// [`RedisFredStorage`] that adds support for indexing sessions by an identifier (e.g. `user_id`). +/// +/// In addition to the requirements for [`RedisFredStorage`], your session data type must +/// implement [`SessionIdentifier`], and its [Id](`SessionIdentifier::Id`) type +/// must implement [`ToString`]. Sessions are tracked in Redis sets, with a key format of +/// `:`. e.g.: `sess:user_id:1` +pub struct RedisFredStorageIndexed { + base_storage: RedisFredStorage, +} + +impl RedisFredStorageIndexed { + pub fn new(base_storage: RedisFredStorage) -> Self { + Self { base_storage } + } +} + +#[async_trait] +impl SessionStorage for RedisFredStorageIndexed +where + T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, + ::Id: ToString, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let (value, ttl) = self.base_storage.fetch_session(id, ttl).await?; + let data = T::from_value(value)?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + if let Some(identifier) = data.identifier() { + let session_idx_key = self + .base_storage + .session_index_key(T::IDENTIFIER, identifier); + let pipeline = self.base_storage.pool.next().pipeline(); + let _: () = pipeline.sadd(&session_idx_key, id).await?; + let _: () = pipeline.expire(&session_idx_key, ttl.into(), None).await?; + let _: () = pipeline.all().await?; + } + + let value: Value = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + self.base_storage.save_session(id, value, ttl).await + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let (value, _) = self.base_storage.fetch_session(id, None).await?; + let data = T::from_value(value)?; + + let pipeline = self.base_storage.pool.next().pipeline(); + let _: () = pipeline.del(self.base_storage.key(id)).await?; + if let Some(identifier) = data.identifier() { + let session_idx_key = self + .base_storage + .session_index_key(T::IDENTIFIER, identifier); + let _: () = pipeline.srem(&session_idx_key, id).await?; + } + Ok(pipeline.all().await?) + } +} + +#[async_trait] +impl SessionStorageIndexed for RedisFredStorageIndexed +where + T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, + ::Id: ToString, +{ + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); + let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; + + let session_value_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.key(&session_id); + let _: () = match self.base_storage.redis_type { + RedisType::String => session_value_pipeline.get(&session_key).await?, + RedisType::Hash => session_value_pipeline.hgetall(&session_key).await?, + }; + } + let session_values: Vec> = session_value_pipeline.all().await?; + + let sessions = session_values + .into_iter() + .enumerate() + .filter_map(|(idx, value)| { + value.and_then(|value| { + let session_id = session_ids.get(idx)?.clone(); + let data = T::from_value(value).ok()?; + Some((session_id, data)) + }) + }) + .collect(); + Ok(sessions) + } + + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); + let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; + + let session_exist_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.key(&session_id); + let _: () = session_exist_pipeline.exists(&session_key).await?; + } + let session_exist_results: Vec = session_exist_pipeline.all().await?; + + let existing_sessions = session_ids + .into_iter() + .enumerate() + .filter_map(|(idx, id)| session_exist_results.get(idx)?.then_some(id)) + .collect(); + Ok(existing_sessions) + } + + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); + let mut session_ids: Vec = + self.base_storage.pool.smembers(&session_index_key).await?; + if let Some(excluded_id) = excluded_session_id { + session_ids.retain(|id| id != excluded_id); + } + if session_ids.is_empty() { + return Ok(0); + } + + let session_keys: Vec<_> = session_ids + .iter() + .map(|id| self.base_storage.key(id)) + .collect(); + let delete_pipeline = self.base_storage.pool.next().pipeline(); + let _: () = delete_pipeline.del(session_keys).await?; + let _: () = delete_pipeline.srem(session_index_key, session_ids).await?; + let (del_num, _srem_num): (u64, u64) = delete_pipeline.all().await?; + + Ok(del_num) + } +} diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 2c2fe62..2e7956b 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -1,4 +1,4 @@ -//! Session storage in PostgreSQL via sqlx +//! Session storage via sqlx use rocket::{ async_trait, @@ -130,7 +130,7 @@ where {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \ {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}", self.table_name, - T::NAME + T::IDENTIFIER ); sqlx::query(&sql) .bind(id) @@ -154,24 +154,23 @@ where let Some(cleanup_interval) = self.cleanup_interval else { return Ok(()); }; - let (tx, rx) = oneshot::channel(); + let (tx, mut rx) = oneshot::channel(); let pool = self.pool.clone(); let table_name = self.table_name.clone(); tokio::spawn(async move { rocket::info!("Starting session cleanup monitor"); let mut interval = interval(cleanup_interval); - tokio::select! { - _ = async { - loop { - interval.tick().await; + loop { + tokio::select! { + _ = interval.tick() => { rocket::debug!("Cleaning up expired sessions"); if let Err(e) = cleanup_expired_sessions(&table_name, &pool).await { rocket::error!("Error deleting expired sessions: {e}"); } } - } => (), - _ = rx => { - rocket::info!("Session cleanup monitor shutdown"); + _ = &mut rx => { + rocket::info!("Session cleanup monitor shutdown"); + } } } }); @@ -213,7 +212,7 @@ where "SELECT {ID_COLUMN}, {DATA_COLUMN} FROM \"{}\" \ WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", &self.table_name, - T::NAME + T::IDENTIFIER ); let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; let parsed_rows = rows @@ -234,7 +233,7 @@ where "SELECT {ID_COLUMN} FROM \"{}\" \ WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", &self.table_name, - T::NAME + T::IDENTIFIER ); let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; let parsed_rows = rows @@ -253,7 +252,7 @@ where let mut sql = format!( "DELETE FROM \"{}\" WHERE {} = $1", &self.table_name, - T::NAME + T::IDENTIFIER ); if excluded_session_id.is_some() { sql.push_str(&format!(" AND {ID_COLUMN} != $2")); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 9a14e54..d2413f3 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,15 +1,17 @@ +use fred::prelude::{ClientLike, KeysInterface, ReconnectPolicy}; use sqlx::{Connection, PgPool}; pub const POSTGRES_URL: &str = "postgres://postgres:postgres@localhost"; +fn random_string(n: usize) -> String { + (0..n) + .map(|_| (b'a' + (rand::random::() % 26)) as char) + .collect() +} + /// Setup a test Postgres database pub async fn setup_postgres(base_url: &str) -> (PgPool, String) { - let db_name = format!( - "test_{}", - (0..6) - .map(|_| (b'a' + (rand::random::() % 26)) as char) - .collect::() - ); + let db_name = format!("test_{}", random_string(6)); let mut cxn = sqlx::PgConnection::connect(base_url).await.unwrap(); sqlx::query(&format!("CREATE DATABASE {}", db_name)) .execute(&mut cxn) @@ -43,3 +45,26 @@ pub async fn teardown_postgres(pool: sqlx::Pool, db_name: String .await .expect("Should drop test database"); } + +pub async fn setup_redis_fred() -> (fred::prelude::Pool, String) { + let pool = fred::prelude::Builder::default_centralized() + .set_policy(ReconnectPolicy::new_linear(3, 5, 1)) + .with_performance_config(|c| c.default_command_timeout = std::time::Duration::from_secs(5)) + .build_pool(3) + .expect("Should build Redis pool"); + pool.init().await.expect("Should initialize Redis pool"); + let prefix = format!("test_{}:sess:", random_string(6)); + + (pool, prefix) +} + +pub async fn teardown_redis_fred(pool: fred::prelude::Pool, prefix: String) { + let (_cursor, keys): (String, Vec) = pool + .scan_page("0", format!("{prefix}*"), Some(50), None) + .await + .expect("Should scan keys"); + if !keys.is_empty() { + let _: () = pool.del(keys).await.expect("Should delete keys"); + } + pool.quit().await.expect("Should quit Redis pool"); +} diff --git a/tests/session_indexed.rs b/tests/session_indexed.rs index 9ec05c7..ef1b131 100644 --- a/tests/session_indexed.rs +++ b/tests/session_indexed.rs @@ -15,7 +15,7 @@ struct UserSession { } impl SessionIdentifier for UserSession { - const NAME: &str = "user_id"; + const IDENTIFIER: &str = "user_id"; type Id = String; fn identifier(&self) -> Option<&Self::Id> { @@ -31,7 +31,7 @@ struct AdminSession { } impl SessionIdentifier for AdminSession { - const NAME: &str = "admin_id"; + const IDENTIFIER: &str = "admin_id"; type Id = String; fn identifier(&self) -> Option<&Self::Id> { diff --git a/tests/storages.rs b/tests/storages_basic.rs similarity index 85% rename from tests/storages.rs rename to tests/storages_basic.rs index 5b24cf6..4a14028 100644 --- a/tests/storages.rs +++ b/tests/storages_basic.rs @@ -5,7 +5,6 @@ extern crate rocket; use std::{future::Future, pin::Pin}; -use fred::prelude::{ClientLike, ReconnectPolicy}; use rocket::{ futures::FutureExt, http::Status, local::asynchronous::Client, tokio::time::sleep, Build, Rocket, @@ -14,7 +13,7 @@ use rocket_flex_session::{ error::SessionError, storage::{ cookie::CookieStorage, - redis::{RedisFredStorage, RedisType}, + redis::{RedisFredStorage, RedisFredStorageIndexed, RedisType}, sqlx::SqlxPostgresStorage, }, RocketFlexSession, Session, SessionIdentifier, @@ -22,7 +21,9 @@ use rocket_flex_session::{ use serde::{Deserialize, Serialize}; use test_case::test_case; -use crate::common::{setup_postgres, teardown_postgres, POSTGRES_URL}; +use crate::common::{ + setup_postgres, setup_redis_fred, teardown_postgres, teardown_redis_fred, POSTGRES_URL, +}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] struct SessionData { @@ -59,7 +60,7 @@ impl From for fred::types::Value { } } impl SessionIdentifier for SessionData { - const NAME: &str = "user_id"; + const IDENTIFIER: &str = "user_id"; type Id = String; fn identifier(&self) -> Option<&Self::Id> { Some(&self.user_id) @@ -107,22 +108,22 @@ async fn create_rocket( None, ), "redis" => { - let pool = fred::prelude::Builder::default_centralized() - .set_policy(ReconnectPolicy::new_linear(3, 5, 1)) - .with_performance_config(|c| { - c.default_command_timeout = std::time::Duration::from_secs(5) - }) - .build_pool(3) - .expect("Should build Redis pool"); - pool.init().await.expect("Should initialize Redis pool"); - let storage = RedisFredStorage::new(pool.clone(), RedisType::String, "sess:"); + let (pool, prefix) = setup_redis_fred().await; + let storage = RedisFredStorage::new(pool.clone(), RedisType::String, &prefix); + let fairing = RocketFlexSession::::builder() + .storage(storage) + .build(); + let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); + (fairing, Some(cleanup_task)) + } + "redis_indexed" => { + let (pool, prefix) = setup_redis_fred().await; + let base_storage = RedisFredStorage::new(pool.clone(), RedisType::String, &prefix); + let storage = RedisFredStorageIndexed::new(base_storage); let fairing = RocketFlexSession::::builder() .storage(storage) .build(); - let cleanup_task = async move { - pool.quit().await.ok(); - } - .boxed(); + let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); (fairing, Some(cleanup_task)) } "sqlx" => { @@ -146,7 +147,8 @@ async fn create_rocket( } #[test_case("cookie"; "Cookie")] -#[test_case("redis"; "Fred Redis")] +#[test_case("redis"; "Redis Fred")] +#[test_case("redis_indexed"; "Redis Fred Indexed")] #[test_case("sqlx"; "Sqlx Postgres")] #[rocket::async_test] async fn test_storages(storage_case: &str) { diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index bec17b4..f838da8 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -1,15 +1,22 @@ mod common; -use std::{future::Future, pin::Pin}; +use std::{collections::HashMap, future::Future, pin::Pin}; use rocket::{futures::FutureExt, local::asynchronous::Client}; use rocket_flex_session::{ - storage::{memory::MemoryStorageIndexed, sqlx::SqlxPostgresStorage, SessionStorageIndexed}, + storage::{ + memory::MemoryStorageIndexed, + redis::{RedisFredStorage, RedisFredStorageIndexed, RedisType}, + sqlx::SqlxPostgresStorage, + SessionStorageIndexed, + }, SessionIdentifier, }; use test_case::test_case; -use crate::common::{setup_postgres, teardown_postgres, POSTGRES_URL}; +use crate::common::{ + setup_postgres, setup_redis_fred, teardown_postgres, teardown_redis_fred, POSTGRES_URL, +}; #[derive(Clone, Debug, PartialEq)] struct TestSession { @@ -23,6 +30,8 @@ impl SessionIdentifier for TestSession { Some(&self.user_id) } } + +// Impls for Sqlx impl ToString for TestSession { fn to_string(&self) -> String { format!("{}:{}", self.user_id, self.data) @@ -43,6 +52,27 @@ impl TryFrom for TestSession { } } +// Impls for fred.rs Redis +const USER_ID_KEY: fred::prelude::Key = fred::types::Key::from_static_str("user_id"); +const DATA_KEY: fred::prelude::Key = fred::types::Key::from_static_str("data"); +impl fred::types::FromValue for TestSession { + fn from_value(value: fred::prelude::Value) -> Result { + let mut map = value.into_map()?; + Ok(Self { + user_id: map.remove(&USER_ID_KEY).unwrap().convert()?, + data: map.remove(&DATA_KEY).unwrap().convert()?, + }) + } +} +impl From for fred::types::Value { + fn from(value: TestSession) -> Self { + let hash: HashMap = + HashMap::from([(USER_ID_KEY, value.user_id), (DATA_KEY, value.data)]); + let fred_map = fred::types::Map::try_from(hash).unwrap(); + fred::types::Value::Map(fred_map) + } +} + async fn create_storage( storage_case: &str, ) -> ( @@ -54,6 +84,13 @@ async fn create_storage( let storage = MemoryStorageIndexed::::default(); (Box::new(storage), None) } + "redis" => { + let (pool, prefix) = setup_redis_fred().await; + let base_storage = RedisFredStorage::new(pool.clone(), RedisType::Hash, &prefix); + let storage = RedisFredStorageIndexed::new(base_storage); + let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); + (Box::new(storage), Some(cleanup_task)) + } "sqlx" => { let (pool, db_name) = setup_postgres(POSTGRES_URL).await; let storage = SqlxPostgresStorage::new(pool.clone(), "sessions", None); @@ -66,6 +103,7 @@ async fn create_storage( #[test_case("memory"; "Memory")] #[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] #[rocket::async_test] async fn basic_operations(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; @@ -128,6 +166,7 @@ async fn basic_operations(storage_case: &str) { #[test_case("memory"; "Memory")] #[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] #[rocket::async_test] async fn invalidate_by_identifier(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; @@ -196,6 +235,7 @@ async fn invalidate_by_identifier(storage_case: &str) { #[test_case("memory"; "Memory")] #[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] #[rocket::async_test] async fn invalidate_all_but_one_by_identifier(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; @@ -254,6 +294,7 @@ async fn invalidate_all_but_one_by_identifier(storage_case: &str) { #[test_case("memory"; "Memory")] #[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] #[rocket::async_test] async fn delete_single_session(storage_case: &str) { let client = Client::tracked(rocket::build()).await.unwrap(); @@ -304,6 +345,7 @@ async fn delete_single_session(storage_case: &str) { #[test_case("memory"; "Memory")] #[test_case("sqlx"; "Sqlx Postgres")] +#[test_case("redis"; "Redis Fred")] #[rocket::async_test] async fn nonexistent_identifier(storage_case: &str) { let (storage, cleanup_task) = create_storage(storage_case).await; From ef4f1531bd4e8a4bc01b4cd97477fc432d36f768 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 01:30:36 -0400 Subject: [PATCH 16/28] organize fred.rs module --- src/storage/redis.rs | 257 +-------------------------- src/storage/redis/base.rs | 79 ++++++++ src/storage/redis/storage.rs | 40 +++++ src/storage/redis/storage_indexed.rs | 149 ++++++++++++++++ 4 files changed, 272 insertions(+), 253 deletions(-) create mode 100644 src/storage/redis/base.rs create mode 100644 src/storage/redis/storage.rs create mode 100644 src/storage/redis/storage_indexed.rs diff --git a/src/storage/redis.rs b/src/storage/redis.rs index a541013..1dfb8d8 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -1,18 +1,8 @@ //! Session storage with Redis (and Redis-compatible databases) -use fred::{ - prelude::{FromValue, HashesInterface, KeysInterface, Pool, SetsInterface, Value}, - types::Expiration, -}; -use rocket::{async_trait, http::CookieJar}; - -use crate::{ - error::{SessionError, SessionResult}, - storage::SessionStorageIndexed, - SessionIdentifier, -}; - -use super::interface::SessionStorage; +mod base; +mod storage; +mod storage_indexed; /// The Redis type to use for the session data #[derive(Debug)] @@ -77,111 +67,11 @@ impl From for Value { ``` */ pub struct RedisFredStorage { - pool: Pool, + pool: fred::prelude::Pool, prefix: String, redis_type: RedisType, } -impl RedisFredStorage { - pub fn new(pool: Pool, redis_type: RedisType, key_prefix: &str) -> Self { - Self { - pool, - prefix: key_prefix.to_owned(), - redis_type, - } - } - - fn key(&self, id: &str) -> String { - format!("{}{id}", self.prefix) - } - - fn session_index_key(&self, identifier_name: &str, identifier: &impl ToString) -> String { - format!( - "{}{identifier_name}:{}", - self.prefix, - identifier.to_string() - ) - } - - async fn fetch_session(&self, id: &str, ttl: Option) -> SessionResult<(Value, u32)> { - let key = self.key(id); - let pipeline = self.pool.next().pipeline(); - let _: () = match self.redis_type { - RedisType::String => pipeline.get(&key).await?, - RedisType::Hash => pipeline.hgetall(&key).await?, - }; - let _: () = pipeline.ttl(&key).await?; - - let (value, orig_ttl): (Option, i64) = match ttl { - None => pipeline.all().await?, - Some(new_ttl) => { - let _: () = pipeline.expire(&key, new_ttl.into(), None).await?; - let (value, orig_ttl, _expire_result): (Option, i64, Option) = - pipeline.all().await?; - (value, orig_ttl) - } - }; - - let found_value = value.ok_or(SessionError::NotFound)?; - Ok((found_value, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) - } - - async fn save_session(&self, id: &str, value: Value, ttl: u32) -> SessionResult<()> { - let key = self.key(id); - let _: () = match self.redis_type { - RedisType::String => { - self.pool - .set(&key, value, Some(Expiration::EX(ttl.into())), None, false) - .await? - } - RedisType::Hash => { - let Value::Map(map) = value else { - return Err(SessionError::Serialization(Box::new(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Converted Redis value wasn't a Map: {:?}", value), - )))); - }; - let pipeline = self.pool.next().pipeline(); - let _: () = pipeline.hset(&key, map).await?; - let _: () = pipeline.expire(&key, ttl.into(), None).await?; - pipeline.all().await? - } - }; - Ok(()) - } -} - -#[async_trait] -impl SessionStorage for RedisFredStorage -where - T: FromValue + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, -{ - async fn load( - &self, - id: &str, - ttl: Option, - _cookie_jar: &CookieJar, - ) -> SessionResult<(T, u32)> { - let (value, ttl) = self.fetch_session(id, ttl).await?; - let data = T::from_value(value)?; - Ok((data, ttl)) - } - - async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - let value: Value = data - .try_into() - .map_err(|e| SessionError::Serialization(Box::new(e)))?; - self.save_session(id, value, ttl).await?; - Ok(()) - } - - async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - let _: u8 = self.pool.del(self.key(id)).await?; - Ok(()) - } -} - /// Redis session storage using the [fred.rs](https://docs.rs/fred) crate. This is a wrapper around /// [`RedisFredStorage`] that adds support for indexing sessions by an identifier (e.g. `user_id`). /// @@ -192,142 +82,3 @@ where pub struct RedisFredStorageIndexed { base_storage: RedisFredStorage, } - -impl RedisFredStorageIndexed { - pub fn new(base_storage: RedisFredStorage) -> Self { - Self { base_storage } - } -} - -#[async_trait] -impl SessionStorage for RedisFredStorageIndexed -where - T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, - ::Id: ToString, -{ - async fn load( - &self, - id: &str, - ttl: Option, - _cookie_jar: &CookieJar, - ) -> SessionResult<(T, u32)> { - let (value, ttl) = self.base_storage.fetch_session(id, ttl).await?; - let data = T::from_value(value)?; - Ok((data, ttl)) - } - - async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - if let Some(identifier) = data.identifier() { - let session_idx_key = self - .base_storage - .session_index_key(T::IDENTIFIER, identifier); - let pipeline = self.base_storage.pool.next().pipeline(); - let _: () = pipeline.sadd(&session_idx_key, id).await?; - let _: () = pipeline.expire(&session_idx_key, ttl.into(), None).await?; - let _: () = pipeline.all().await?; - } - - let value: Value = data - .try_into() - .map_err(|e| SessionError::Serialization(Box::new(e)))?; - self.base_storage.save_session(id, value, ttl).await - } - - async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - let (value, _) = self.base_storage.fetch_session(id, None).await?; - let data = T::from_value(value)?; - - let pipeline = self.base_storage.pool.next().pipeline(); - let _: () = pipeline.del(self.base_storage.key(id)).await?; - if let Some(identifier) = data.identifier() { - let session_idx_key = self - .base_storage - .session_index_key(T::IDENTIFIER, identifier); - let _: () = pipeline.srem(&session_idx_key, id).await?; - } - Ok(pipeline.all().await?) - } -} - -#[async_trait] -impl SessionStorageIndexed for RedisFredStorageIndexed -where - T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, - >::Error: std::error::Error + Send + Sync + 'static, - ::Id: ToString, -{ - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { - let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); - let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; - - let session_value_pipeline = self.base_storage.pool.next().pipeline(); - for session_id in &session_ids { - let session_key = self.base_storage.key(&session_id); - let _: () = match self.base_storage.redis_type { - RedisType::String => session_value_pipeline.get(&session_key).await?, - RedisType::Hash => session_value_pipeline.hgetall(&session_key).await?, - }; - } - let session_values: Vec> = session_value_pipeline.all().await?; - - let sessions = session_values - .into_iter() - .enumerate() - .filter_map(|(idx, value)| { - value.and_then(|value| { - let session_id = session_ids.get(idx)?.clone(); - let data = T::from_value(value).ok()?; - Some((session_id, data)) - }) - }) - .collect(); - Ok(sessions) - } - - async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { - let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); - let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; - - let session_exist_pipeline = self.base_storage.pool.next().pipeline(); - for session_id in &session_ids { - let session_key = self.base_storage.key(&session_id); - let _: () = session_exist_pipeline.exists(&session_key).await?; - } - let session_exist_results: Vec = session_exist_pipeline.all().await?; - - let existing_sessions = session_ids - .into_iter() - .enumerate() - .filter_map(|(idx, id)| session_exist_results.get(idx)?.then_some(id)) - .collect(); - Ok(existing_sessions) - } - - async fn invalidate_sessions_by_identifier( - &self, - id: &T::Id, - excluded_session_id: Option<&str>, - ) -> SessionResult { - let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); - let mut session_ids: Vec = - self.base_storage.pool.smembers(&session_index_key).await?; - if let Some(excluded_id) = excluded_session_id { - session_ids.retain(|id| id != excluded_id); - } - if session_ids.is_empty() { - return Ok(0); - } - - let session_keys: Vec<_> = session_ids - .iter() - .map(|id| self.base_storage.key(id)) - .collect(); - let delete_pipeline = self.base_storage.pool.next().pipeline(); - let _: () = delete_pipeline.del(session_keys).await?; - let _: () = delete_pipeline.srem(session_index_key, session_ids).await?; - let (del_num, _srem_num): (u64, u64) = delete_pipeline.all().await?; - - Ok(del_num) - } -} diff --git a/src/storage/redis/base.rs b/src/storage/redis/base.rs new file mode 100644 index 0000000..587d723 --- /dev/null +++ b/src/storage/redis/base.rs @@ -0,0 +1,79 @@ +use fred::{ + prelude::{HashesInterface, KeysInterface, Pool, Value}, + types::Expiration, +}; + +use crate::error::{SessionError, SessionResult}; + +use super::{RedisFredStorage, RedisType}; + +impl RedisFredStorage { + pub fn new(pool: Pool, redis_type: RedisType, key_prefix: &str) -> Self { + Self { + pool, + prefix: key_prefix.to_owned(), + redis_type, + } + } + + pub(super) fn key(&self, id: &str) -> String { + format!("{}{id}", self.prefix) + } + + pub(super) fn session_index_key( + &self, + identifier_name: &str, + identifier: &impl ToString, + ) -> String { + format!( + "{}{identifier_name}:{}", + self.prefix, + identifier.to_string() + ) + } + + pub(super) async fn fetch_session( + &self, + id: &str, + ttl: Option, + ) -> SessionResult<(Value, u32)> { + let key = self.key(id); + let pipeline = self.pool.next().pipeline(); + let _: () = match self.redis_type { + RedisType::String => pipeline.get(&key).await?, + RedisType::Hash => pipeline.hgetall(&key).await?, + }; + let _: () = pipeline.ttl(&key).await?; + + let (value, orig_ttl): (Option, i64) = match ttl { + None => pipeline.all().await?, + Some(new_ttl) => { + let _: () = pipeline.expire(&key, new_ttl.into(), None).await?; + let (value, orig_ttl, _expire_result): (Option, i64, Option) = + pipeline.all().await?; + (value, orig_ttl) + } + }; + + let found_value = value.ok_or(SessionError::NotFound)?; + Ok((found_value, ttl.unwrap_or(orig_ttl.try_into().unwrap_or(0)))) + } + + pub(super) async fn save_session(&self, id: &str, value: Value, ttl: u32) -> SessionResult<()> { + let key = self.key(id); + let _: () = match self.redis_type { + RedisType::String => { + self.pool + .set(&key, value, Some(Expiration::EX(ttl.into())), None, false) + .await? + } + RedisType::Hash => { + let pipeline = self.pool.next().pipeline(); + let _: () = pipeline.hset(&key, value.into_map()?).await?; + let _: () = pipeline.expire(&key, ttl.into(), None).await?; + pipeline.all().await? + } + }; + Ok(()) + } +} diff --git a/src/storage/redis/storage.rs b/src/storage/redis/storage.rs new file mode 100644 index 0000000..b262caf --- /dev/null +++ b/src/storage/redis/storage.rs @@ -0,0 +1,40 @@ +use fred::prelude::{FromValue, KeysInterface, Value}; +use rocket::http::CookieJar; + +use crate::{ + error::{SessionError, SessionResult}, + storage::SessionStorage, +}; + +use super::RedisFredStorage; + +#[rocket::async_trait] +impl SessionStorage for RedisFredStorage +where + T: FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let (value, ttl) = self.fetch_session(id, ttl).await?; + let data = T::from_value(value)?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + let value: Value = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + self.save_session(id, value, ttl).await?; + Ok(()) + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let _: u8 = self.pool.del(self.key(id)).await?; + Ok(()) + } +} diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs new file mode 100644 index 0000000..308cd7a --- /dev/null +++ b/src/storage/redis/storage_indexed.rs @@ -0,0 +1,149 @@ +use fred::prelude::{FromValue, HashesInterface, KeysInterface, SetsInterface, Value}; +use rocket::http::CookieJar; + +use crate::{ + error::{SessionError, SessionResult}, + storage::{SessionStorage, SessionStorageIndexed}, + SessionIdentifier, +}; + +use super::{RedisFredStorage, RedisFredStorageIndexed}; + +impl RedisFredStorageIndexed { + pub fn new(base_storage: RedisFredStorage) -> Self { + Self { base_storage } + } +} + +#[rocket::async_trait] +impl SessionStorage for RedisFredStorageIndexed +where + T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, + ::Id: ToString, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let (value, ttl) = self.base_storage.fetch_session(id, ttl).await?; + let data = T::from_value(value)?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + if let Some(identifier) = data.identifier() { + let session_idx_key = self + .base_storage + .session_index_key(T::IDENTIFIER, identifier); + let pipeline = self.base_storage.pool.next().pipeline(); + let _: () = pipeline.sadd(&session_idx_key, id).await?; + let _: () = pipeline.expire(&session_idx_key, ttl.into(), None).await?; + let _: () = pipeline.all().await?; + } + + let value: Value = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + self.base_storage.save_session(id, value, ttl).await + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let (value, _) = self.base_storage.fetch_session(id, None).await?; + let data = T::from_value(value)?; + + let pipeline = self.base_storage.pool.next().pipeline(); + let _: () = pipeline.del(self.base_storage.key(id)).await?; + if let Some(identifier) = data.identifier() { + let session_idx_key = self + .base_storage + .session_index_key(T::IDENTIFIER, identifier); + let _: () = pipeline.srem(&session_idx_key, id).await?; + } + Ok(pipeline.all().await?) + } +} + +#[rocket::async_trait] +impl SessionStorageIndexed for RedisFredStorageIndexed +where + T: SessionIdentifier + FromValue + TryInto + Clone + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, + ::Id: ToString, +{ + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); + let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; + + let session_value_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.key(&session_id); + let _: () = match self.base_storage.redis_type { + super::RedisType::String => session_value_pipeline.get(&session_key).await?, + super::RedisType::Hash => session_value_pipeline.hgetall(&session_key).await?, + }; + } + let session_values: Vec> = session_value_pipeline.all().await?; + + let sessions = session_values + .into_iter() + .enumerate() + .filter_map(|(idx, value)| { + value.and_then(|value| { + let session_id = session_ids.get(idx)?.clone(); + let data = T::from_value(value).ok()?; + Some((session_id, data)) + }) + }) + .collect(); + Ok(sessions) + } + + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); + let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; + + let session_exist_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.key(&session_id); + let _: () = session_exist_pipeline.exists(&session_key).await?; + } + let session_exist_results: Vec = session_exist_pipeline.all().await?; + + let existing_sessions = session_ids + .into_iter() + .enumerate() + .filter_map(|(idx, id)| session_exist_results.get(idx)?.then_some(id)) + .collect(); + Ok(existing_sessions) + } + + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); + let mut session_ids: Vec = + self.base_storage.pool.smembers(&session_index_key).await?; + if let Some(excluded_id) = excluded_session_id { + session_ids.retain(|id| id != excluded_id); + } + if session_ids.is_empty() { + return Ok(0); + } + + let session_keys: Vec<_> = session_ids + .iter() + .map(|id| self.base_storage.key(id)) + .collect(); + let delete_pipeline = self.base_storage.pool.next().pipeline(); + let _: () = delete_pipeline.del(session_keys).await?; + let _: () = delete_pipeline.srem(session_index_key, session_ids).await?; + let (del_num, _srem_num): (u64, u64) = delete_pipeline.all().await?; + + Ok(del_num) + } +} From f2f880a0a2175f8ed51937d62d96749b1eab245d Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 01:31:23 -0400 Subject: [PATCH 17/28] rename for clarity --- src/storage/redis/base.rs | 6 +++--- src/storage/redis/storage.rs | 2 +- src/storage/redis/storage_indexed.rs | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/storage/redis/base.rs b/src/storage/redis/base.rs index 587d723..be564e1 100644 --- a/src/storage/redis/base.rs +++ b/src/storage/redis/base.rs @@ -16,7 +16,7 @@ impl RedisFredStorage { } } - pub(super) fn key(&self, id: &str) -> String { + pub(super) fn session_key(&self, id: &str) -> String { format!("{}{id}", self.prefix) } @@ -37,7 +37,7 @@ impl RedisFredStorage { id: &str, ttl: Option, ) -> SessionResult<(Value, u32)> { - let key = self.key(id); + let key = self.session_key(id); let pipeline = self.pool.next().pipeline(); let _: () = match self.redis_type { RedisType::String => pipeline.get(&key).await?, @@ -60,7 +60,7 @@ impl RedisFredStorage { } pub(super) async fn save_session(&self, id: &str, value: Value, ttl: u32) -> SessionResult<()> { - let key = self.key(id); + let key = self.session_key(id); let _: () = match self.redis_type { RedisType::String => { self.pool diff --git a/src/storage/redis/storage.rs b/src/storage/redis/storage.rs index b262caf..240aad4 100644 --- a/src/storage/redis/storage.rs +++ b/src/storage/redis/storage.rs @@ -34,7 +34,7 @@ where } async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - let _: u8 = self.pool.del(self.key(id)).await?; + let _: u8 = self.pool.del(self.session_key(id)).await?; Ok(()) } } diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index 308cd7a..f5453fe 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -55,7 +55,7 @@ where let data = T::from_value(value)?; let pipeline = self.base_storage.pool.next().pipeline(); - let _: () = pipeline.del(self.base_storage.key(id)).await?; + let _: () = pipeline.del(self.base_storage.session_key(id)).await?; if let Some(identifier) = data.identifier() { let session_idx_key = self .base_storage @@ -79,7 +79,7 @@ where let session_value_pipeline = self.base_storage.pool.next().pipeline(); for session_id in &session_ids { - let session_key = self.base_storage.key(&session_id); + let session_key = self.base_storage.session_key(&session_id); let _: () = match self.base_storage.redis_type { super::RedisType::String => session_value_pipeline.get(&session_key).await?, super::RedisType::Hash => session_value_pipeline.hgetall(&session_key).await?, @@ -107,7 +107,7 @@ where let session_exist_pipeline = self.base_storage.pool.next().pipeline(); for session_id in &session_ids { - let session_key = self.base_storage.key(&session_id); + let session_key = self.base_storage.session_key(&session_id); let _: () = session_exist_pipeline.exists(&session_key).await?; } let session_exist_results: Vec = session_exist_pipeline.all().await?; @@ -137,7 +137,7 @@ where let session_keys: Vec<_> = session_ids .iter() - .map(|id| self.base_storage.key(id)) + .map(|id| self.base_storage.session_key(id)) .collect(); let delete_pipeline = self.base_storage.pool.next().pipeline(); let _: () = delete_pipeline.del(session_keys).await?; From be019cb37fadc148fd8c77ad8c9cca7e1a2e5abf Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 02:05:48 -0400 Subject: [PATCH 18/28] add stable (and customizable) expiration for fred.rs session index --- src/fairing.rs | 6 +++--- src/storage/redis.rs | 1 + src/storage/redis/base.rs | 5 +++++ src/storage/redis/storage_indexed.rs | 23 ++++++++++++++++++----- tests/storages_basic.rs | 2 +- tests/storages_indexed.rs | 2 +- 6 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/fairing.rs b/src/fairing.rs index 7da5851..181980b 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -162,7 +162,7 @@ where if let Some(deleted_id) = deleted { let delete_result = self.storage.delete(&deleted_id, req.cookies()).await; if let Err(e) = delete_result { - rocket::error!("Error while deleting session '{}': {}", deleted_id, e); + rocket::warn!("Error while deleting session '{deleted_id}': {e}"); } } @@ -170,7 +170,7 @@ where if let Some((id, pending_data, ttl)) = updated { let save_result = self.storage.save(&id, pending_data, ttl).await; if let Err(e) = save_result { - rocket::error!("Error while saving session '{}': {}", &id, e); + rocket::error!("Error while saving session '{id}': {e}"); } } } @@ -178,7 +178,7 @@ where async fn on_shutdown(&self, _rocket: &Rocket) { rocket::debug!("Shutting down session resources..."); if let Err(e) = self.storage.shutdown().await { - rocket::warn!("Error during session storage shutdown: {}", e); + rocket::warn!("Error during session storage shutdown: {e}"); } } } diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 1dfb8d8..4cc3a41 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -81,4 +81,5 @@ pub struct RedisFredStorage { /// `:`. e.g.: `sess:user_id:1` pub struct RedisFredStorageIndexed { base_storage: RedisFredStorage, + index_ttl: u32, } diff --git a/src/storage/redis/base.rs b/src/storage/redis/base.rs index be564e1..35a92f3 100644 --- a/src/storage/redis/base.rs +++ b/src/storage/redis/base.rs @@ -8,6 +8,11 @@ use crate::error::{SessionError, SessionResult}; use super::{RedisFredStorage, RedisType}; impl RedisFredStorage { + /// Create the storage instance. + /// # Parameters + /// * `pool` - The initialized fred.rs connection pool. + /// * `redis_type` - The Redis data type to use for storing sessions. + /// * `key_prefix` - The prefix to use for session keys. (e.g. "sess:") pub fn new(pool: Pool, redis_type: RedisType, key_prefix: &str) -> Self { Self { pool, diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index f5453fe..45f22b2 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -9,9 +9,20 @@ use crate::{ use super::{RedisFredStorage, RedisFredStorageIndexed}; +const DEFAULT_INDEX_TTL: u32 = 60 * 60 * 24 * 7 * 2; // 2 weeks + impl RedisFredStorageIndexed { - pub fn new(base_storage: RedisFredStorage) -> Self { - Self { base_storage } + /// Create the indexed storage. + /// + /// # Parameters: + /// - `base_storage`: The base storage to use for session data. + /// - `index_ttl`: The TTL for the session index - should match + /// your longest expected session duration (default: 2 weeks). + pub fn new(base_storage: RedisFredStorage, index_ttl: Option) -> Self { + Self { + base_storage, + index_ttl: index_ttl.unwrap_or(DEFAULT_INDEX_TTL), + } } } @@ -35,12 +46,14 @@ where async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { if let Some(identifier) = data.identifier() { - let session_idx_key = self + let session_index_key = self .base_storage .session_index_key(T::IDENTIFIER, identifier); let pipeline = self.base_storage.pool.next().pipeline(); - let _: () = pipeline.sadd(&session_idx_key, id).await?; - let _: () = pipeline.expire(&session_idx_key, ttl.into(), None).await?; + let _: () = pipeline.sadd(&session_index_key, id).await?; + let _: () = pipeline + .expire(&session_index_key, self.index_ttl.into(), None) + .await?; let _: () = pipeline.all().await?; } diff --git a/tests/storages_basic.rs b/tests/storages_basic.rs index 4a14028..f3f6b0b 100644 --- a/tests/storages_basic.rs +++ b/tests/storages_basic.rs @@ -119,7 +119,7 @@ async fn create_rocket( "redis_indexed" => { let (pool, prefix) = setup_redis_fred().await; let base_storage = RedisFredStorage::new(pool.clone(), RedisType::String, &prefix); - let storage = RedisFredStorageIndexed::new(base_storage); + let storage = RedisFredStorageIndexed::new(base_storage, None); let fairing = RocketFlexSession::::builder() .storage(storage) .build(); diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index f838da8..81cf4e7 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -87,7 +87,7 @@ async fn create_storage( "redis" => { let (pool, prefix) = setup_redis_fred().await; let base_storage = RedisFredStorage::new(pool.clone(), RedisType::Hash, &prefix); - let storage = RedisFredStorageIndexed::new(base_storage); + let storage = RedisFredStorageIndexed::new(base_storage, None); let cleanup_task = teardown_redis_fred(pool, prefix).boxed(); (Box::new(storage), Some(cleanup_task)) } From 13aed8bb9c3b42e266c2118d430a9309aad1e5c8 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 03:07:31 -0400 Subject: [PATCH 19/28] auto clean up stale sessions in fred.rs session index --- src/storage/redis/base.rs | 12 ----- src/storage/redis/storage_indexed.rs | 72 ++++++++++++++++------------ 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/storage/redis/base.rs b/src/storage/redis/base.rs index 35a92f3..3cbf214 100644 --- a/src/storage/redis/base.rs +++ b/src/storage/redis/base.rs @@ -25,18 +25,6 @@ impl RedisFredStorage { format!("{}{id}", self.prefix) } - pub(super) fn session_index_key( - &self, - identifier_name: &str, - identifier: &impl ToString, - ) -> String { - format!( - "{}{identifier_name}:{}", - self.prefix, - identifier.to_string() - ) - } - pub(super) async fn fetch_session( &self, id: &str, diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index 45f22b2..fc324af 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -24,6 +24,14 @@ impl RedisFredStorageIndexed { index_ttl: index_ttl.unwrap_or(DEFAULT_INDEX_TTL), } } + + fn session_index_key(&self, identifier_name: &str, identifier: &impl ToString) -> String { + format!( + "{}{identifier_name}:{}", + self.base_storage.prefix, + identifier.to_string() + ) + } } #[rocket::async_trait] @@ -46,13 +54,11 @@ where async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { if let Some(identifier) = data.identifier() { - let session_index_key = self - .base_storage - .session_index_key(T::IDENTIFIER, identifier); + let index_key = self.session_index_key(T::IDENTIFIER, identifier); let pipeline = self.base_storage.pool.next().pipeline(); - let _: () = pipeline.sadd(&session_index_key, id).await?; + let _: () = pipeline.sadd(&index_key, id).await?; let _: () = pipeline - .expire(&session_index_key, self.index_ttl.into(), None) + .expire(&index_key, self.index_ttl.into(), None) .await?; let _: () = pipeline.all().await?; } @@ -70,9 +76,7 @@ where let pipeline = self.base_storage.pool.next().pipeline(); let _: () = pipeline.del(self.base_storage.session_key(id)).await?; if let Some(identifier) = data.identifier() { - let session_idx_key = self - .base_storage - .session_index_key(T::IDENTIFIER, identifier); + let session_idx_key = self.session_index_key(T::IDENTIFIER, identifier); let _: () = pipeline.srem(&session_idx_key, id).await?; } Ok(pipeline.all().await?) @@ -87,8 +91,8 @@ where ::Id: ToString, { async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { - let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); - let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; + let index_key = self.session_index_key(T::IDENTIFIER, id); + let session_ids: Vec = self.base_storage.pool.smembers(&index_key).await?; let session_value_pipeline = self.base_storage.pool.next().pipeline(); for session_id in &session_ids { @@ -100,23 +104,26 @@ where } let session_values: Vec> = session_value_pipeline.all().await?; - let sessions = session_values + let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids + .into_iter() + .zip(session_values.into_iter()) + .map(|(id, value)| (id, value.and_then(|v| T::from_value(v).ok()))) + .partition(|(_, data)| data.is_some()); + if !stale_sessions.is_empty() { + let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); + let _: () = self.base_storage.pool.srem(&index_key, stale_ids).await?; + } + + let sessions = existing_sessions .into_iter() - .enumerate() - .filter_map(|(idx, value)| { - value.and_then(|value| { - let session_id = session_ids.get(idx)?.clone(); - let data = T::from_value(value).ok()?; - Some((session_id, data)) - }) - }) + .map(|(session_id, data)| (session_id.to_owned(), data.expect("already checked"))) .collect(); Ok(sessions) } async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { - let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); - let session_ids: Vec = self.base_storage.pool.smembers(&session_index_key).await?; + let index_key = self.session_index_key(T::IDENTIFIER, id); + let session_ids: Vec = self.base_storage.pool.smembers(&index_key).await?; let session_exist_pipeline = self.base_storage.pool.next().pipeline(); for session_id in &session_ids { @@ -125,12 +132,17 @@ where } let session_exist_results: Vec = session_exist_pipeline.all().await?; - let existing_sessions = session_ids + let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids .into_iter() - .enumerate() - .filter_map(|(idx, id)| session_exist_results.get(idx)?.then_some(id)) - .collect(); - Ok(existing_sessions) + .zip(session_exist_results.into_iter()) + .partition(|(_, exists)| *exists); + if !stale_sessions.is_empty() { + let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); + let _: () = self.base_storage.pool.srem(&index_key, stale_ids).await?; + } + + let sessions = existing_sessions.into_iter().map(|(id, _)| id).collect(); + Ok(sessions) } async fn invalidate_sessions_by_identifier( @@ -138,9 +150,9 @@ where id: &T::Id, excluded_session_id: Option<&str>, ) -> SessionResult { - let session_index_key = self.base_storage.session_index_key(T::IDENTIFIER, id); - let mut session_ids: Vec = - self.base_storage.pool.smembers(&session_index_key).await?; + let index_key = self.session_index_key(T::IDENTIFIER, id); + let mut session_ids: Vec = self.base_storage.pool.smembers(&index_key).await?; + if let Some(excluded_id) = excluded_session_id { session_ids.retain(|id| id != excluded_id); } @@ -154,7 +166,7 @@ where .collect(); let delete_pipeline = self.base_storage.pool.next().pipeline(); let _: () = delete_pipeline.del(session_keys).await?; - let _: () = delete_pipeline.srem(session_index_key, session_ids).await?; + let _: () = delete_pipeline.srem(index_key, session_ids).await?; let (del_num, _srem_num): (u64, u64) = delete_pipeline.all().await?; Ok(del_num) From e5211690d690e6aa645ea5b9a5c2efe3ca256544 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 03:10:15 -0400 Subject: [PATCH 20/28] Update storage_indexed.rs --- src/storage/redis/storage_indexed.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index fc324af..000898f 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -116,7 +116,7 @@ where let sessions = existing_sessions .into_iter() - .map(|(session_id, data)| (session_id.to_owned(), data.expect("already checked"))) + .map(|(id, data)| (id, data.expect("already checked by partition"))) .collect(); Ok(sessions) } From 5d35a230029e611ac700e2a250aba9b796efa405 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 03:22:11 -0400 Subject: [PATCH 21/28] Update storage_indexed.rs --- src/storage/redis/storage_indexed.rs | 32 ++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index 000898f..4d8dcc9 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -32,6 +32,24 @@ impl RedisFredStorageIndexed { identifier.to_string() ) } + + async fn fetch_session_index( + &self, + identifier_name: &str, + identifier: &impl ToString, + ) -> SessionResult<(Vec, String)> { + let index_key = self.session_index_key(identifier_name, identifier); + let session_ids = self.base_storage.pool.smembers(&index_key).await?; + Ok((session_ids, index_key)) + } + + async fn cleanup_session_index( + &self, + index_key: &str, + stale_ids: Vec, + ) -> SessionResult<()> { + Ok(self.base_storage.pool.srem(index_key, stale_ids).await?) + } } #[rocket::async_trait] @@ -91,8 +109,7 @@ where ::Id: ToString, { async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { - let index_key = self.session_index_key(T::IDENTIFIER, id); - let session_ids: Vec = self.base_storage.pool.smembers(&index_key).await?; + let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; let session_value_pipeline = self.base_storage.pool.next().pipeline(); for session_id in &session_ids { @@ -111,7 +128,7 @@ where .partition(|(_, data)| data.is_some()); if !stale_sessions.is_empty() { let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); - let _: () = self.base_storage.pool.srem(&index_key, stale_ids).await?; + self.cleanup_session_index(&index_key, stale_ids).await?; } let sessions = existing_sessions @@ -122,8 +139,7 @@ where } async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { - let index_key = self.session_index_key(T::IDENTIFIER, id); - let session_ids: Vec = self.base_storage.pool.smembers(&index_key).await?; + let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; let session_exist_pipeline = self.base_storage.pool.next().pipeline(); for session_id in &session_ids { @@ -138,7 +154,7 @@ where .partition(|(_, exists)| *exists); if !stale_sessions.is_empty() { let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); - let _: () = self.base_storage.pool.srem(&index_key, stale_ids).await?; + self.cleanup_session_index(&index_key, stale_ids).await?; } let sessions = existing_sessions.into_iter().map(|(id, _)| id).collect(); @@ -150,9 +166,7 @@ where id: &T::Id, excluded_session_id: Option<&str>, ) -> SessionResult { - let index_key = self.session_index_key(T::IDENTIFIER, id); - let mut session_ids: Vec = self.base_storage.pool.smembers(&index_key).await?; - + let (mut session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; if let Some(excluded_id) = excluded_session_id { session_ids.retain(|id| id != excluded_id); } From 720f6f4ce000aa730166b6a55753c1ff2a9ab087 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 03:25:37 -0400 Subject: [PATCH 22/28] fix docs --- src/session_index.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/session_index.rs b/src/session_index.rs index 9242d56..cae3348 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -18,7 +18,7 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// } /// /// impl SessionIdentifier for MySession { -/// const NAME: &str = "user_id"; +/// const IDENTIFIER: &str = "user_id"; /// type Id = String; /// /// fn identifier(&self) -> Option<&Self::Id> { From 0aeb83fea962329a79e6eae28f107d71006646ba Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 04:38:45 -0400 Subject: [PATCH 23/28] also return ttl when retrieving all user sessions --- src/session.rs | 2 +- src/session_index.rs | 23 ++++++++------- src/storage/interface.rs | 8 ++--- src/storage/memory.rs | 7 +++-- src/storage/redis/storage_indexed.rs | 23 +++++++++++---- src/storage/sqlx.rs | 44 +++++++++++++++++++--------- tests/storages_indexed.rs | 14 +++++---- 7 files changed, 76 insertions(+), 45 deletions(-) diff --git a/src/session.rs b/src/session.rs index 312632c..e1e2f36 100644 --- a/src/session.rs +++ b/src/session.rs @@ -155,7 +155,7 @@ where OffsetDateTime::now_utc().saturating_add(Duration::seconds(self.ttl().into())) } - /// Delete the session. + /// Delete the current session. pub fn delete(&mut self) { // Delete inner session data let mut inner = self.get_inner_lock(); diff --git a/src/session_index.rs b/src/session_index.rs index cae3348..f8c2d6c 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -44,9 +44,10 @@ impl<'a, T> Session<'a, T> where T: SessionIdentifier + Send + Sync + Clone, { - /// Get all session IDs and data for the same identifier as the current session. - /// Returns `None` if there's no session or the session isn't indexed. - pub async fn get_all_sessions(&self) -> Result>, SessionError> { + /// Get all active sessions for the same user/identifier as the current session. + /// Returns the session ID, data, and TTL (in seconds) for each session. + /// Returns `None` if there's no current session or the session isn't indexed. + pub async fn get_all_sessions(&self) -> Result>, SessionError> { let Some(identifier) = self.get_identifier() else { return Ok(None); }; @@ -56,8 +57,8 @@ where Ok(Some(sessions)) } - /// Get all session IDs for the same identifier as the current session. - /// Returns `None` if there's no session or the session isn't indexed. + /// Get all active session IDs for the same user/identifier as the current session. + /// Returns `None` if there's no current session or the session isn't indexed. pub async fn get_all_session_ids(&self) -> Result>, SessionError> { let Some(identifier) = self.get_identifier() else { return Ok(None); @@ -68,8 +69,8 @@ where Ok(Some(session_ids)) } - /// Invalidate all sessions with the same identifier as the current session, optionally keeping the current session active. - /// Returns the number of sessions invalidated, or `None` if there's no session or the session isn't indexed. + /// Invalidate all sessions with the same user/identifier as the current session, optionally keeping the current session active. + /// Returns the number of sessions invalidated, or `None` if there's no current session or the session isn't indexed. pub async fn invalidate_all_sessions( &self, keep_current: bool, @@ -88,16 +89,16 @@ where Ok(Some(num_sessions)) } - /// Get all session IDs and data for a specific identifier. + /// Get all session IDs, data, and TTL (in seconds) for a specific user/identifier. pub async fn get_sessions_by_identifier( &self, identifier: &T::Id, - ) -> Result, SessionError> { + ) -> Result, SessionError> { let storage = self.get_indexed_storage()?; storage.get_sessions_by_identifier(identifier).await } - /// Get all session IDs for a specific identifier. + /// Get all session IDs for a specific user/identifier. pub async fn get_session_ids_by_identifier( &self, identifier: &T::Id, @@ -106,7 +107,7 @@ where storage.get_session_ids_by_identifier(identifier).await } - /// Invalidate all sessions for a specific identifier, returning the number of sessions invalidated. + /// Invalidate all sessions for a specific user/identifier, returning the number of sessions invalidated. pub async fn invalidate_sessions_by_identifier( &self, identifier: &T::Id, diff --git a/src/storage/interface.rs b/src/storage/interface.rs index f85305d..51788e5 100644 --- a/src/storage/interface.rs +++ b/src/storage/interface.rs @@ -68,12 +68,12 @@ pub trait SessionStorageIndexed: SessionStorage where T: SessionIdentifier + Send + Sync, { - /// Retrieve all tracked session IDs and data for the given identifier. - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult>; - - /// Get all tracked session IDs associated with the given identifier. + /// Retrieve all tracked session IDs associated with the given identifier. async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult>; + /// Retrieve all tracked session IDs, data, and TTL for the given identifier. + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult>; + /// Invalidate all tracked sessions associated with the given identifier, optionally excluding one session ID. /// Returns the number of sessions invalidated. async fn invalidate_sessions_by_identifier( diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 74e18d2..7d4be50 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -237,16 +237,17 @@ where T: SessionIdentifier + Clone + Send + Sync, T::Id: ToString, { - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { let session_ids = { let index = self.identifier_index.lock().unwrap(); index.get(&id.to_string()).cloned().unwrap_or_default() }; - let mut sessions: Vec<(String, T)> = Vec::new(); + let mut sessions: Vec<(String, T, u32)> = Vec::new(); for session_id in session_ids { if let Some(data) = self.base_storage.cache.get(&session_id).await { - sessions.push((session_id, data.value().to_owned())); + let secs = data.expiration().remaining().unwrap().as_secs(); + sessions.push((session_id, data.value().to_owned(), secs as u32)); } } diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index 4d8dcc9..494f405 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -108,7 +108,7 @@ where >::Error: std::error::Error + Send + Sync + 'static, ::Id: ToString, { - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; let session_value_pipeline = self.base_storage.pool.next().pipeline(); @@ -118,14 +118,22 @@ where super::RedisType::String => session_value_pipeline.get(&session_key).await?, super::RedisType::Hash => session_value_pipeline.hgetall(&session_key).await?, }; + let _: () = session_value_pipeline.ttl(&session_key).await?; } - let session_values: Vec> = session_value_pipeline.all().await?; + let mut raw_values_and_ttls: Vec> = session_value_pipeline.all().await?; let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids .into_iter() - .zip(session_values.into_iter()) - .map(|(id, value)| (id, value.and_then(|v| T::from_value(v).ok()))) - .partition(|(_, data)| data.is_some()); + .zip(raw_values_and_ttls.chunks_exact_mut(2)) + .map(|(id, raw)| { + let data_and_ttl = raw[0].take().and_then(|val| { + let data = T::from_value(val).ok()?; + let ttl = raw[1].as_ref().and_then(Value::as_i64)?; + Some((data, ttl)) + }); + (id, data_and_ttl) + }) + .partition(|(_, data_and_ttl)| data_and_ttl.is_some()); if !stale_sessions.is_empty() { let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); self.cleanup_session_index(&index_key, stale_ids).await?; @@ -133,7 +141,10 @@ where let sessions = existing_sessions .into_iter() - .map(|(id, data)| (id, data.expect("already checked by partition"))) + .map(|(id, data_and_ttl)| { + let (data, ttl) = data_and_ttl.expect("already checked by partition"); + (id, data, ttl.try_into().unwrap_or(0)) + }) .collect(); Ok(sessions) } diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 2e7956b..9b1e88e 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -9,7 +9,7 @@ use rocket::{ time::interval, }, }; -use sqlx::{PgPool, Row}; +use sqlx::{postgres::PgRow, PgPool, Row}; use time::{Duration, OffsetDateTime}; use crate::{ @@ -61,6 +61,23 @@ impl SqlxPostgresStorage { shutdown_tx: Mutex::default(), } } + + fn id_from_row(&self, row: &PgRow) -> sqlx::Result { + row.try_get(ID_COLUMN) + } + + fn raw_data_from_row(&self, row: &PgRow) -> sqlx::Result { + row.try_get(DATA_COLUMN) + } + + fn ttl_from_row(&self, row: &PgRow) -> sqlx::Result { + let expires: OffsetDateTime = row.try_get(EXPIRES_COLUMN)?; + let ttl = (expires - OffsetDateTime::now_utc()) + .whole_seconds() + .try_into() + .unwrap_or(0); + Ok(ttl) + } } const ID_COLUMN: &str = "id"; @@ -108,18 +125,16 @@ where } }; - let (raw_str, expires) = match row { + let (raw_str, ttl) = match row { Some(row) => { - let data: String = row.try_get(DATA_COLUMN)?; - let expires: OffsetDateTime = row.try_get(EXPIRES_COLUMN)?; - (data, expires) + let data = self.raw_data_from_row(&row)?; + let ttl = self.ttl_from_row(&row)?; + (data, ttl) } None => return Err(SessionError::NotFound), }; let data = T::try_from(raw_str).map_err(|e| SessionError::Serialization(Box::new(e)))?; - let ttl = (expires - OffsetDateTime::now_utc()).whole_seconds(); - - Ok((data, ttl.try_into().unwrap_or(0))) + Ok((data, ttl)) } async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { @@ -207,9 +222,9 @@ where for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, >::Error: std::error::Error + Send + Sync + 'static, { - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { let sql = format!( - "SELECT {ID_COLUMN}, {DATA_COLUMN} FROM \"{}\" \ + "SELECT {ID_COLUMN}, {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", &self.table_name, T::IDENTIFIER @@ -218,10 +233,11 @@ where let parsed_rows = rows .into_iter() .filter_map(|row| { - let id: String = row.try_get(0).ok()?; - let raw_data: String = row.try_get(1).ok()?; + let id = self.id_from_row(&row).ok()?; + let raw_data = self.raw_data_from_row(&row).ok()?; let data = T::try_from(raw_data).ok()?; - Some((id, data)) + let ttl = self.ttl_from_row(&row).ok()?; + Some((id, data, ttl)) }) .collect(); @@ -238,7 +254,7 @@ where let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; let parsed_rows = rows .into_iter() - .filter_map(|row| row.try_get(0).ok()) + .filter_map(|row| self.id_from_row(&row).ok()) .collect(); Ok(parsed_rows) diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index 81cf4e7..fa13ba3 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -135,10 +135,10 @@ async fn basic_operations(storage_case: &str) { assert_eq!(user1_sessions.len(), 2); assert!(user1_sessions .iter() - .any(|(id, data)| id == "sid1" && data == &session1)); + .any(|(id, data, ttl)| id == "sid1" && data == &session1 && *ttl <= 3600)); assert!(user1_sessions .iter() - .any(|(id, data)| id == "sid2" && data == &session2)); + .any(|(id, data, ttl)| id == "sid2" && data == &session2 && *ttl <= 3600)); let user2_sessions = storage .get_sessions_by_identifier(&"user2".to_string()) @@ -147,7 +147,7 @@ async fn basic_operations(storage_case: &str) { assert_eq!(user2_sessions.len(), 1); assert!(user2_sessions .iter() - .any(|(id, data)| id == "sid3" && data == &session3)); + .any(|(id, data, ttl)| id == "sid3" && data == &session3 && *ttl <= 3600)); // Test get_session_ids_by_identifier let user1_session_ids = storage @@ -225,7 +225,8 @@ async fn invalidate_by_identifier(storage_case: &str) { .await .unwrap(); assert_eq!(user2_sessions.len(), 1); - assert_eq!(user2_sessions[0], ("sid3".to_string(), session3)); + assert_eq!(user2_sessions[0].0, "sid3"); + assert_eq!(user2_sessions[0].1, session3); storage.shutdown().await.unwrap(); if let Some(task) = cleanup_task { @@ -284,7 +285,8 @@ async fn invalidate_all_but_one_by_identifier(storage_case: &str) { .await .unwrap(); assert_eq!(user1_sessions.len(), 1); - assert_eq!(user1_sessions[0], ("sid3".to_string(), session3)); + assert_eq!(user1_sessions[0].0, "sid3"); + assert_eq!(user1_sessions[0].1, session3); storage.shutdown().await.unwrap(); if let Some(task) = cleanup_task { @@ -335,7 +337,7 @@ async fn delete_single_session(storage_case: &str) { assert_eq!(remaining_sessions.len(), 1); assert!(remaining_sessions .iter() - .any(|(id, data)| id == "sid2" && data == &session2)); + .any(|(id, data, ttl)| id == "sid2" && data == &session2 && *ttl <= 3600)); storage.shutdown().await.unwrap(); if let Some(task) = cleanup_task { From 52b5fa4af29c4f75140348c078c424501759e151 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 05:11:14 -0400 Subject: [PATCH 24/28] doc tweaks --- src/session.rs | 31 +++++++++++++++++++++++++++---- src/session_index.rs | 5 +++-- src/storage/redis.rs | 4 ++-- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/session.rs b/src/session.rs index e1e2f36..6c3b92f 100644 --- a/src/session.rs +++ b/src/session.rs @@ -103,7 +103,18 @@ where } /// Get a reference to the current session data via a closure. - /// Data will be `None` if there's no active session. + /// The closure's argument will be `None` if there's no active session. + /// + /// # Example + /// ```rust,ignore + /// session.tap(|data| { + /// if let Some(data) = data { + /// println!("Session data: {:?}", data); + /// } else { + /// println!("No active session"); + /// } + /// }); + /// ``` pub fn tap(&self, f: F) -> R where F: FnOnce(Option<&T>) -> R, @@ -112,7 +123,18 @@ where } /// Get a mutable reference to the current session data via a closure. - /// The function's argument will be `None` if there's no active session. + /// The closure's argument will be `None` if there's no active session. + /// + /// # Example + /// ```rust,ignore + /// session.tap_mut(|data| { + /// if let Some(data) = data { + /// data.foo = new_value; + /// } else { + /// println!("No active session"); + /// } + /// }); + /// ``` pub fn tap_mut(&mut self, f: UpdateFn) -> R where UpdateFn: FnOnce(&mut Option) -> R, @@ -129,7 +151,7 @@ where response } - /// Set/update the session data. Will create a new active session if there isn't one. + /// Set/replace the session data. Will create a new active session if there isn't one. pub fn set(&mut self, new_data: T) { self.get_inner_lock() .set_data(new_data, self.get_default_ttl()); @@ -137,7 +159,8 @@ where } /// Set the TTL of the session in seconds. This can be used to extend the length - /// of the session if needed. This has no effect if there is no active session. + /// of the session if needed. This has no effect if there is no active session, or + /// if you have enabled "rolling" sessions in the [`options`](RocketFlexSessionOptions::rolling). pub fn set_ttl(&mut self, new_ttl: u32) { self.get_inner_lock().set_ttl(new_ttl); self.update_cookies(); diff --git a/src/session_index.rs b/src/session_index.rs index f8c2d6c..f5eabb8 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -33,8 +33,9 @@ pub trait SessionIdentifier { /// The type of the identifier type Id: Send + Sync + Clone; - /// Extract the identifier from the session data. - /// Returns `None` if the session doesn't have an identifier and/or + /// Extract the identifier from the session data. This identifier + /// should be immutable for the lifetime of the session. + /// Can return `None` if a session doesn't have an identifier and/or /// shouldn't be indexed. fn identifier(&self) -> Option<&Self::Id>; } diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 4cc3a41..3b81e23 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -75,9 +75,9 @@ pub struct RedisFredStorage { /// Redis session storage using the [fred.rs](https://docs.rs/fred) crate. This is a wrapper around /// [`RedisFredStorage`] that adds support for indexing sessions by an identifier (e.g. `user_id`). /// -/// In addition to the requirements for [`RedisFredStorage`], your session data type must +/// In addition to the requirements for `RedisFredStorage`, your session data type must /// implement [`SessionIdentifier`], and its [Id](`SessionIdentifier::Id`) type -/// must implement [`ToString`]. Sessions are tracked in Redis sets, with a key format of +/// must implement `ToString`. Sessions are tracked in Redis sets, with a key format of /// `:`. e.g.: `sess:user_id:1` pub struct RedisFredStorageIndexed { base_storage: RedisFredStorage, From 8dcc11b6329fa7ebab198e85bf5a14c5484b7156 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 05:27:39 -0400 Subject: [PATCH 25/28] add debug logs for saving and deleting session --- src/fairing.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/fairing.rs b/src/fairing.rs index 181980b..7a31978 100644 --- a/src/fairing.rs +++ b/src/fairing.rs @@ -160,17 +160,21 @@ where // Handle deleted session if let Some(deleted_id) = deleted { - let delete_result = self.storage.delete(&deleted_id, req.cookies()).await; - if let Err(e) = delete_result { + rocket::debug!("Found deleted session. Deleting session '{deleted_id}'..."); + if let Err(e) = self.storage.delete(&deleted_id, req.cookies()).await { rocket::warn!("Error while deleting session '{deleted_id}': {e}"); + } else { + rocket::debug!("Deleted session '{deleted_id}' successfully"); } } // Handle updated session if let Some((id, pending_data, ttl)) = updated { - let save_result = self.storage.save(&id, pending_data, ttl).await; - if let Err(e) = save_result { + rocket::debug!("Found updated session. Saving session '{id}'..."); + if let Err(e) = self.storage.save(&id, pending_data, ttl).await { rocket::error!("Error while saving session '{id}': {e}"); + } else { + rocket::debug!("Saved session '{id}' successfully"); } } } From e4db35fdf22baac7c4bdb07963d9678bdc371994 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 05:38:49 -0400 Subject: [PATCH 26/28] show session retrieval errors as info level logs --- src/guard.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/guard.rs b/src/guard.rs index 54ba437..32cd2b8 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -77,7 +77,7 @@ async fn fetch_session_data<'r, T: Send + Sync + Clone>( let session_cookie = cookie_jar.get_private(cookie_name); if let Some(cookie) = session_cookie { let id = cookie.value(); - rocket::debug!("Got session id '{}' from cookie. Retrieving session...", id); + rocket::debug!("Got session id '{id}' from cookie. Retrieving session..."); match storage.load(id, rolling_ttl, cookie_jar).await { Ok((data, ttl)) => { rocket::debug!("Session found. Creating existing session..."); @@ -85,7 +85,7 @@ async fn fetch_session_data<'r, T: Send + Sync + Clone>( (Mutex::new(session_inner), None) } Err(e) => { - rocket::debug!("Error from session storage, creating empty session: {}", e); + rocket::info!("Error from session storage, creating empty session: {e}"); (Mutex::default(), Some(e)) } } From dee436dbeff8b97a7e3521489b82e175c28bd657 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 13:16:37 -0400 Subject: [PATCH 27/28] re-organize modules --- src/lib.rs | 3 +- src/storage/redis.rs | 81 +------- src/storage/redis/base.rs | 68 ++++++- src/storage/redis/storage.rs | 2 +- src/storage/redis/storage_indexed.rs | 62 +++--- src/storage/sqlx.rs | 284 +------------------------- src/storage/sqlx/postgres.rs | 287 +++++++++++++++++++++++++++ tests/storages_basic.rs | 17 +- tests/storages_indexed.rs | 8 +- 9 files changed, 408 insertions(+), 404 deletions(-) create mode 100644 src/storage/sqlx/postgres.rs diff --git a/src/lib.rs b/src/lib.rs index 5784e66..615d9d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -185,9 +185,10 @@ This crate supports multiple storage backends with different capabilities: | Storage | Feature Flag | Indexing support | HashMap support | Use Cases | |---------|-------------|------------------|----------|----------| | [`storage::memory::MemoryStorage`] | Built-in | ❌ | ✅ | Development, testing | -| [`storage::memory::MemoryStorageIndexed`] | Built-in | ✅ | ✅ | Development with indexing features | +| [`storage::memory::MemoryStorageIndexed`] | Built-in | ✅ | ❌ | Development with indexing features | | [`storage::cookie::CookieStorage`] | `cookie` | ❌ | ✅ | Client-side storage, stateless servers | | [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | ✅ | Production, distributed systems | +| [`storage::redis::RedisFredStorageIndexed`] | `redis_fred` | ✅ | ❌ | Production, distributed systems | | [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅ | ❌ | Production, existing database | diff --git a/src/storage/redis.rs b/src/storage/redis.rs index 3b81e23..c4d2048 100644 --- a/src/storage/redis.rs +++ b/src/storage/redis.rs @@ -4,82 +4,5 @@ mod base; mod storage; mod storage_indexed; -/// The Redis type to use for the session data -#[derive(Debug)] -pub enum RedisType { - String, - Hash, -} - -/** -Redis session storage using the [fred.rs](https://docs.rs/fred) crate. - -You can store the data as a Redis string or hash. Your session data type must implement [`FromValue`](https://docs.rs/fred/latest/fred/types/trait.FromValue.html) -from the fred.rs crate, as well as the inverse `From` or `TryFrom` for [`Value`](https://docs.rs/fred/latest/fred/types/enum.Value.html) in order -to dictate how the data will be converted to/from the Redis data type. -- For Redis string types, convert to/from `Value::String` -- For Redis hash types, convert to/from `Value::Map` - -💡 Common hashmap types like `HashMap` are automatically supported - make sure to use `RedisType::Hash` -when constructing the storage to ensure they are properly converted and stored as Redis hashes. - -```rust -use fred::prelude::{Builder, ClientLike, Config, FromValue, Value}; -use rocket_flex_session::{error::SessionError, storage::{redis::{RedisFredStorage, RedisType}}}; - -async fn setup_storage() -> RedisFredStorage { - // Setup and initialize a fred.rs Redis pool. - let redis_pool = Builder::default_centralized() - .set_config(Config::from_url("redis://localhost").expect("Valid Redis URL")) - .build_pool(4) - .expect("Should build Redis pool"); - redis_pool.init().await.expect("Should initialize Redis pool"); - - // Construct the storage - let storage = RedisFredStorage::new( - redis_pool, - RedisType::String, // or RedisType::Hash - "sess:" // Prefix for Redis keys - ); - - storage -} - -// If using a custom struct for your session data, implement the following... -struct MySessionData { - user_id: String, -} -// Implement `FromValue` to convert from the Redis value to your session data type -impl FromValue for MySessionData { - fn from_value(value: Value) -> Result { - let data: String = value.convert()?; // fred.rs provides several conversion methods on the Value type - Ok(MySessionData { - user_id: data, - }) - } -} -// Implement the inverse conversion -impl From for Value { - fn from(data: MySessionData) -> Self { - Value::String(data.user_id.into()) - } -} -``` -*/ -pub struct RedisFredStorage { - pool: fred::prelude::Pool, - prefix: String, - redis_type: RedisType, -} - -/// Redis session storage using the [fred.rs](https://docs.rs/fred) crate. This is a wrapper around -/// [`RedisFredStorage`] that adds support for indexing sessions by an identifier (e.g. `user_id`). -/// -/// In addition to the requirements for `RedisFredStorage`, your session data type must -/// implement [`SessionIdentifier`], and its [Id](`SessionIdentifier::Id`) type -/// must implement `ToString`. Sessions are tracked in Redis sets, with a key format of -/// `:`. e.g.: `sess:user_id:1` -pub struct RedisFredStorageIndexed { - base_storage: RedisFredStorage, - index_ttl: u32, -} +pub use base::{RedisFredStorage, RedisType}; +pub use storage_indexed::RedisFredStorageIndexed; diff --git a/src/storage/redis/base.rs b/src/storage/redis/base.rs index 3cbf214..d5778f5 100644 --- a/src/storage/redis/base.rs +++ b/src/storage/redis/base.rs @@ -5,7 +5,73 @@ use fred::{ use crate::error::{SessionError, SessionResult}; -use super::{RedisFredStorage, RedisType}; +/// The Redis type to use for the session data +#[derive(Debug)] +pub enum RedisType { + String, + Hash, +} + +/** +Redis session storage using the [fred.rs](https://docs.rs/fred) crate. + +You can store the data as a Redis string or hash. Your session data type must implement [`FromValue`](https://docs.rs/fred/latest/fred/types/trait.FromValue.html) +from the fred.rs crate, as well as the inverse `From` or `TryFrom` for [`Value`](https://docs.rs/fred/latest/fred/types/enum.Value.html) in order +to dictate how the data will be converted to/from the Redis data type. +- For Redis string types, convert to/from `Value::String` +- For Redis hash types, convert to/from `Value::Map` + +💡 Common hashmap types like `HashMap` are automatically supported - make sure to use `RedisType::Hash` +when constructing the storage to ensure they are properly converted and stored as Redis hashes. + +```rust +use fred::prelude::{Builder, ClientLike, Config, FromValue, Value}; +use rocket_flex_session::{error::SessionError, storage::{redis::{RedisFredStorage, RedisType}}}; + +async fn setup_storage() -> RedisFredStorage { + // Setup and initialize a fred.rs Redis pool. + let redis_pool = Builder::default_centralized() + .set_config(Config::from_url("redis://localhost").expect("Valid Redis URL")) + .build_pool(4) + .expect("Should build Redis pool"); + redis_pool.init().await.expect("Should initialize Redis pool"); + + // Construct the storage + let storage = RedisFredStorage::new( + redis_pool, + RedisType::String, // or RedisType::Hash + "sess:" // Prefix for Redis keys + ); + + storage +} + +// If using a custom struct for your session data, implement the following... +struct MySessionData { + user_id: String, +} +// Implement `FromValue` to convert from the Redis value to your session data type +impl FromValue for MySessionData { + fn from_value(value: Value) -> Result { + let data: String = value.convert()?; // fred.rs provides several conversion methods on the Value type + Ok(MySessionData { + user_id: data, + }) + } +} +// Implement the inverse conversion +impl From for Value { + fn from(data: MySessionData) -> Self { + Value::String(data.user_id.into()) + } +} +``` +*/ +pub struct RedisFredStorage { + pub(super) pool: fred::prelude::Pool, + pub(super) prefix: String, + pub(super) redis_type: RedisType, +} impl RedisFredStorage { /// Create the storage instance. diff --git a/src/storage/redis/storage.rs b/src/storage/redis/storage.rs index 240aad4..a07ffee 100644 --- a/src/storage/redis/storage.rs +++ b/src/storage/redis/storage.rs @@ -34,7 +34,7 @@ where } async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - let _: u8 = self.pool.del(self.session_key(id)).await?; + let _: () = self.pool.del(self.session_key(id)).await?; Ok(()) } } diff --git a/src/storage/redis/storage_indexed.rs b/src/storage/redis/storage_indexed.rs index 494f405..363799a 100644 --- a/src/storage/redis/storage_indexed.rs +++ b/src/storage/redis/storage_indexed.rs @@ -7,15 +7,27 @@ use crate::{ SessionIdentifier, }; -use super::{RedisFredStorage, RedisFredStorageIndexed}; +use super::RedisFredStorage; const DEFAULT_INDEX_TTL: u32 = 60 * 60 * 24 * 7 * 2; // 2 weeks +/// Redis session storage using the [fred.rs](https://docs.rs/fred) crate. This is a wrapper around +/// [`RedisFredStorage`] that adds support for indexing sessions by an identifier (e.g. `user_id`). +/// +/// In addition to the requirements for `RedisFredStorage`, your session data type must +/// implement [`SessionIdentifier`], and its [Id](`SessionIdentifier::Id`) type +/// must implement `ToString`. Sessions are tracked in Redis sets, with a key format of +/// `:`. e.g.: `sess:user_id:1` +pub struct RedisFredStorageIndexed { + base_storage: RedisFredStorage, + index_ttl: u32, +} + impl RedisFredStorageIndexed { /// Create the indexed storage. /// /// # Parameters: - /// - `base_storage`: The base storage to use for session data. + /// - `base_storage`: The [`RedisFredStorage`] instance to use. /// - `index_ttl`: The TTL for the session index - should match /// your longest expected session duration (default: 2 weeks). pub fn new(base_storage: RedisFredStorage, index_ttl: Option) -> Self { @@ -108,6 +120,29 @@ where >::Error: std::error::Error + Send + Sync + 'static, ::Id: ToString, { + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; + + let session_exist_pipeline = self.base_storage.pool.next().pipeline(); + for session_id in &session_ids { + let session_key = self.base_storage.session_key(&session_id); + let _: () = session_exist_pipeline.exists(&session_key).await?; + } + let session_exist_results: Vec = session_exist_pipeline.all().await?; + + let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids + .into_iter() + .zip(session_exist_results.into_iter()) + .partition(|(_, exists)| *exists); + if !stale_sessions.is_empty() { + let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); + self.cleanup_session_index(&index_key, stale_ids).await?; + } + + let sessions = existing_sessions.into_iter().map(|(id, _)| id).collect(); + Ok(sessions) + } + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; @@ -149,29 +184,6 @@ where Ok(sessions) } - async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { - let (session_ids, index_key) = self.fetch_session_index(T::IDENTIFIER, id).await?; - - let session_exist_pipeline = self.base_storage.pool.next().pipeline(); - for session_id in &session_ids { - let session_key = self.base_storage.session_key(&session_id); - let _: () = session_exist_pipeline.exists(&session_key).await?; - } - let session_exist_results: Vec = session_exist_pipeline.all().await?; - - let (existing_sessions, stale_sessions): (Vec<_>, Vec<_>) = session_ids - .into_iter() - .zip(session_exist_results.into_iter()) - .partition(|(_, exists)| *exists); - if !stale_sessions.is_empty() { - let stale_ids: Vec<_> = stale_sessions.into_iter().map(|(id, _)| id).collect(); - self.cleanup_session_index(&index_key, stale_ids).await?; - } - - let sessions = existing_sessions.into_iter().map(|(id, _)| id).collect(); - Ok(sessions) - } - async fn invalidate_sessions_by_identifier( &self, id: &T::Id, diff --git a/src/storage/sqlx.rs b/src/storage/sqlx.rs index 9b1e88e..68ac06b 100644 --- a/src/storage/sqlx.rs +++ b/src/storage/sqlx.rs @@ -1,285 +1,5 @@ //! Session storage via sqlx -use rocket::{ - async_trait, - http::CookieJar, - tokio::{ - self, - sync::{oneshot, Mutex}, - time::interval, - }, -}; -use sqlx::{postgres::PgRow, PgPool, Row}; -use time::{Duration, OffsetDateTime}; +mod postgres; -use crate::{ - error::{SessionError, SessionResult}, - storage::SessionStorageIndexed, - SessionIdentifier, -}; - -use super::interface::SessionStorage; - -/** -Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx) that stores session data as a string, and supports session indexing. - -You'll need to implement `ToString` (or Display) and `TryFrom` for your session data type. You'll also need to implement [`SessionIdentifier`], -and its [`Id`](crate::SessionIdentifier::Id) must be a [type supported by sqlx](https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html). -Expects a table to already exist with the following columns: - -| Name | Type | -|------|---------| -| id | `text` PRIMARY KEY | -| data | `text` NOT NULL (or `jsonb` if using JSON) | -| `` | `` (the name and type should match your [`SessionIdentifier`] impl) | -| expires | `timestamptz` NOT NULL | -*/ -pub struct SqlxPostgresStorage { - pool: PgPool, - table_name: String, - cleanup_interval: Option, - shutdown_tx: Mutex>>, -} - -impl SqlxPostgresStorage { - /// Creates a new [`SqlxPostgresStorage`]. - /// - /// Parameters: - /// - `pool`: An initialized Postgres connection pool. - /// - `table_name`: The name of the table to use for storing sessions. - /// - `cleanup_interval`: Interval to check for and clean up expired sessions. If `None`, - /// expired sessions won't be cleaned up automatically. - pub fn new( - pool: PgPool, - table_name: &str, - cleanup_interval: Option, - ) -> SqlxPostgresStorage { - Self { - pool, - table_name: table_name.to_owned(), - cleanup_interval, - shutdown_tx: Mutex::default(), - } - } - - fn id_from_row(&self, row: &PgRow) -> sqlx::Result { - row.try_get(ID_COLUMN) - } - - fn raw_data_from_row(&self, row: &PgRow) -> sqlx::Result { - row.try_get(DATA_COLUMN) - } - - fn ttl_from_row(&self, row: &PgRow) -> sqlx::Result { - let expires: OffsetDateTime = row.try_get(EXPIRES_COLUMN)?; - let ttl = (expires - OffsetDateTime::now_utc()) - .whole_seconds() - .try_into() - .unwrap_or(0); - Ok(ttl) - } -} - -const ID_COLUMN: &str = "id"; -const DATA_COLUMN: &str = "data"; -const EXPIRES_COLUMN: &str = "expires"; - -#[async_trait] -impl SessionStorage for SqlxPostgresStorage -where - T: SessionIdentifier + TryFrom + ToString + Clone + Send + Sync + 'static, - ::Id: - for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, - >::Error: std::error::Error + Send + Sync + 'static, -{ - async fn load( - &self, - id: &str, - ttl: Option, - _cookie_jar: &CookieJar, - ) -> SessionResult<(T, u32)> { - let row = match ttl { - Some(new_ttl) => { - let sql = format!( - "UPDATE \"{}\" SET {EXPIRES_COLUMN} = $1 \ - WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP \ - RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}", - &self.table_name, - ); - sqlx::query(&sql) - .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) - .bind(id) - .fetch_optional(&self.pool) - .await? - } - None => { - let sql = format!( - "SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ - WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", - &self.table_name, - ); - sqlx::query(&sql) - .bind(id) - .fetch_optional(&self.pool) - .await? - } - }; - - let (raw_str, ttl) = match row { - Some(row) => { - let data = self.raw_data_from_row(&row)?; - let ttl = self.ttl_from_row(&row)?; - (data, ttl) - } - None => return Err(SessionError::NotFound), - }; - let data = T::try_from(raw_str).map_err(|e| SessionError::Serialization(Box::new(e)))?; - Ok((data, ttl)) - } - - async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { - let sql = format!( - "INSERT INTO \"{}\" ({ID_COLUMN}, {}, {DATA_COLUMN}, {EXPIRES_COLUMN}) \ - VALUES ($1, $2, $3, $4) \ - ON CONFLICT ({ID_COLUMN}) DO UPDATE SET \ - {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \ - {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}", - self.table_name, - T::IDENTIFIER - ); - sqlx::query(&sql) - .bind(id) - .bind(data.identifier()) - .bind(data.to_string()) - .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) - .execute(&self.pool) - .await?; - - Ok(()) - } - - async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { - let sql = format!("DELETE FROM {} WHERE {ID_COLUMN} = $1", &self.table_name); - sqlx::query(&sql).bind(id).execute(&self.pool).await?; - - Ok(()) - } - - async fn setup(&self) -> SessionResult<()> { - let Some(cleanup_interval) = self.cleanup_interval else { - return Ok(()); - }; - let (tx, mut rx) = oneshot::channel(); - let pool = self.pool.clone(); - let table_name = self.table_name.clone(); - tokio::spawn(async move { - rocket::info!("Starting session cleanup monitor"); - let mut interval = interval(cleanup_interval); - loop { - tokio::select! { - _ = interval.tick() => { - rocket::debug!("Cleaning up expired sessions"); - if let Err(e) = cleanup_expired_sessions(&table_name, &pool).await { - rocket::error!("Error deleting expired sessions: {e}"); - } - } - _ = &mut rx => { - rocket::info!("Session cleanup monitor shutdown"); - } - } - } - }); - self.shutdown_tx.lock().await.replace(tx); - - Ok(()) - } - - async fn shutdown(&self) -> SessionResult<()> { - if let Some(tx) = self.shutdown_tx.lock().await.take() { - tx.send(()).map_err(|_| { - SessionError::SetupTeardown("Failed to send shutdown signal".to_string()) - })?; - } - Ok(()) - } -} - -async fn cleanup_expired_sessions(table_name: &str, pool: &PgPool) -> Result { - rocket::debug!("Cleaning up expired sessions"); - let sql = format!("DELETE FROM \"{table_name}\" WHERE {EXPIRES_COLUMN} < $1"); - let rows = sqlx::query(&sql) - .bind(OffsetDateTime::now_utc()) - .execute(pool) - .await?; - Ok(rows.rows_affected()) -} - -#[async_trait] -impl SessionStorageIndexed for SqlxPostgresStorage -where - T: SessionIdentifier + TryFrom + ToString + Clone + Send + Sync + 'static, - ::Id: - for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, - >::Error: std::error::Error + Send + Sync + 'static, -{ - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { - let sql = format!( - "SELECT {ID_COLUMN}, {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ - WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", - &self.table_name, - T::IDENTIFIER - ); - let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; - let parsed_rows = rows - .into_iter() - .filter_map(|row| { - let id = self.id_from_row(&row).ok()?; - let raw_data = self.raw_data_from_row(&row).ok()?; - let data = T::try_from(raw_data).ok()?; - let ttl = self.ttl_from_row(&row).ok()?; - Some((id, data, ttl)) - }) - .collect(); - - Ok(parsed_rows) - } - - async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { - let sql = format!( - "SELECT {ID_COLUMN} FROM \"{}\" \ - WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", - &self.table_name, - T::IDENTIFIER - ); - let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; - let parsed_rows = rows - .into_iter() - .filter_map(|row| self.id_from_row(&row).ok()) - .collect(); - - Ok(parsed_rows) - } - - async fn invalidate_sessions_by_identifier( - &self, - id: &T::Id, - excluded_session_id: Option<&str>, - ) -> SessionResult { - let mut sql = format!( - "DELETE FROM \"{}\" WHERE {} = $1", - &self.table_name, - T::IDENTIFIER - ); - if excluded_session_id.is_some() { - sql.push_str(&format!(" AND {ID_COLUMN} != $2")); - } - - let mut query = sqlx::query(&sql).bind(id); - if let Some(excluded_id) = excluded_session_id { - query = query.bind(excluded_id); - } - let rows = query.execute(&self.pool).await?; - - Ok(rows.rows_affected()) - } -} +pub use postgres::SqlxPostgresStorage; diff --git a/src/storage/sqlx/postgres.rs b/src/storage/sqlx/postgres.rs new file mode 100644 index 0000000..702ec1e --- /dev/null +++ b/src/storage/sqlx/postgres.rs @@ -0,0 +1,287 @@ +use rocket::{ + async_trait, + http::CookieJar, + tokio::{ + self, + sync::{oneshot, Mutex}, + time::interval, + }, +}; +use sqlx::{postgres::PgRow, PgPool, Row}; +use time::{Duration, OffsetDateTime}; + +use crate::{ + error::{SessionError, SessionResult}, + storage::{SessionStorage, SessionStorageIndexed}, + SessionIdentifier, +}; + +const ID_COLUMN: &str = "id"; +const DATA_COLUMN: &str = "data"; +const EXPIRES_COLUMN: &str = "expires"; + +/** +Session store using PostgreSQL via [sqlx](https://docs.rs/crate/sqlx) that stores session data as a string, and supports session indexing. + +You'll need to implement `TryInto` and `TryFrom` for your session data type. You'll also need to implement [`SessionIdentifier`], +and its [`Id`](crate::SessionIdentifier::Id) must be a [type supported by sqlx](https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html). +Expects a table to already exist with the following columns: + +| Name | Type | +|------|---------| +| id | `text` PRIMARY KEY | +| data | `text` NOT NULL (or `jsonb` if using JSON) | +| `` | `` (the name and type should match your [`SessionIdentifier`] impl) | +| expires | `timestamptz` NOT NULL | +*/ +pub struct SqlxPostgresStorage { + pool: PgPool, + table_name: String, + cleanup_interval: Option, + shutdown_tx: Mutex>>, +} + +impl SqlxPostgresStorage { + /// Creates a new [`SqlxPostgresStorage`]. + /// + /// Parameters: + /// - `pool`: An initialized Postgres connection pool. + /// - `table_name`: The name of the table to use for storing sessions. + /// - `cleanup_interval`: Interval to check for and clean up expired sessions. If `None`, + /// expired sessions won't be cleaned up automatically. + pub fn new( + pool: PgPool, + table_name: &str, + cleanup_interval: Option, + ) -> SqlxPostgresStorage { + Self { + pool, + table_name: table_name.to_owned(), + cleanup_interval, + shutdown_tx: Mutex::default(), + } + } + + fn id_from_row(&self, row: &PgRow) -> sqlx::Result { + row.try_get(ID_COLUMN) + } + + fn raw_data_from_row(&self, row: &PgRow) -> sqlx::Result { + row.try_get(DATA_COLUMN) + } + + fn ttl_from_row(&self, row: &PgRow) -> sqlx::Result { + let expires: OffsetDateTime = row.try_get(EXPIRES_COLUMN)?; + let ttl = (expires - OffsetDateTime::now_utc()) + .whole_seconds() + .try_into() + .unwrap_or(0); + Ok(ttl) + } +} + +#[async_trait] +impl SessionStorage for SqlxPostgresStorage +where + T: SessionIdentifier + TryFrom + TryInto + Clone + Send + Sync + 'static, + ::Id: + for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, + >::Error: std::error::Error + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn load( + &self, + id: &str, + ttl: Option, + _cookie_jar: &CookieJar, + ) -> SessionResult<(T, u32)> { + let row = match ttl { + Some(new_ttl) => { + let sql = format!( + "UPDATE \"{}\" SET {EXPIRES_COLUMN} = $1 \ + WHERE {ID_COLUMN} = $2 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP \ + RETURNING {DATA_COLUMN}, {EXPIRES_COLUMN}", + &self.table_name, + ); + sqlx::query(&sql) + .bind(OffsetDateTime::now_utc() + Duration::seconds(new_ttl.into())) + .bind(id) + .fetch_optional(&self.pool) + .await? + } + None => { + let sql = format!( + "SELECT {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ + WHERE {ID_COLUMN} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", + &self.table_name, + ); + sqlx::query(&sql) + .bind(id) + .fetch_optional(&self.pool) + .await? + } + }; + + let (raw_str, ttl) = match row { + Some(row) => { + let data = self.raw_data_from_row(&row)?; + let ttl = self.ttl_from_row(&row)?; + (data, ttl) + } + None => return Err(SessionError::NotFound), + }; + let data = T::try_from(raw_str).map_err(|e| SessionError::Serialization(Box::new(e)))?; + Ok((data, ttl)) + } + + async fn save(&self, id: &str, data: T, ttl: u32) -> SessionResult<()> { + let sql = format!( + "INSERT INTO \"{}\" ({ID_COLUMN}, {}, {DATA_COLUMN}, {EXPIRES_COLUMN}) \ + VALUES ($1, $2, $3, $4) \ + ON CONFLICT ({ID_COLUMN}) DO UPDATE SET \ + {DATA_COLUMN} = EXCLUDED.{DATA_COLUMN}, \ + {EXPIRES_COLUMN} = EXCLUDED.{EXPIRES_COLUMN}", + self.table_name, + T::IDENTIFIER + ); + let identifier = data.identifier().cloned(); + let data_str: String = data + .try_into() + .map_err(|e| SessionError::Serialization(Box::new(e)))?; + sqlx::query(&sql) + .bind(id) + .bind(identifier) + .bind(data_str) + .bind(OffsetDateTime::now_utc() + Duration::seconds(ttl.into())) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete(&self, id: &str, _cookie_jar: &CookieJar) -> SessionResult<()> { + let sql = format!("DELETE FROM {} WHERE {ID_COLUMN} = $1", &self.table_name); + sqlx::query(&sql).bind(id).execute(&self.pool).await?; + + Ok(()) + } + + async fn setup(&self) -> SessionResult<()> { + let Some(cleanup_interval) = self.cleanup_interval else { + return Ok(()); + }; + let (tx, mut rx) = oneshot::channel(); + let pool = self.pool.clone(); + let table_name = self.table_name.clone(); + tokio::spawn(async move { + rocket::info!("Starting session cleanup monitor"); + let mut interval = interval(cleanup_interval); + loop { + tokio::select! { + _ = interval.tick() => { + rocket::debug!("Cleaning up expired sessions"); + if let Err(e) = cleanup_expired_sessions(&table_name, &pool).await { + rocket::error!("Error deleting expired sessions: {e}"); + } + } + _ = &mut rx => { + rocket::info!("Session cleanup monitor shutdown"); + } + } + } + }); + self.shutdown_tx.lock().await.replace(tx); + + Ok(()) + } + + async fn shutdown(&self) -> SessionResult<()> { + if let Some(tx) = self.shutdown_tx.lock().await.take() { + tx.send(()).map_err(|_| { + SessionError::SetupTeardown("Failed to send shutdown signal".to_string()) + })?; + } + Ok(()) + } +} + +async fn cleanup_expired_sessions(table_name: &str, pool: &PgPool) -> Result { + rocket::debug!("Cleaning up expired sessions"); + let sql = format!("DELETE FROM \"{table_name}\" WHERE {EXPIRES_COLUMN} < $1"); + let rows = sqlx::query(&sql) + .bind(OffsetDateTime::now_utc()) + .execute(pool) + .await?; + Ok(rows.rows_affected()) +} + +#[async_trait] +impl SessionStorageIndexed for SqlxPostgresStorage +where + T: SessionIdentifier + TryFrom + TryInto + Clone + Send + Sync + 'static, + ::Id: + for<'q> sqlx::Encode<'q, sqlx::Postgres> + sqlx::Type, + >::Error: std::error::Error + Send + Sync + 'static, + >::Error: std::error::Error + Send + Sync + 'static, +{ + async fn get_session_ids_by_identifier(&self, id: &T::Id) -> SessionResult> { + let sql = format!( + "SELECT {ID_COLUMN} FROM \"{}\" \ + WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", + &self.table_name, + T::IDENTIFIER + ); + let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; + let parsed_rows = rows + .into_iter() + .filter_map(|row| self.id_from_row(&row).ok()) + .collect(); + + Ok(parsed_rows) + } + + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + let sql = format!( + "SELECT {ID_COLUMN}, {DATA_COLUMN}, {EXPIRES_COLUMN} FROM \"{}\" \ + WHERE {} = $1 AND {EXPIRES_COLUMN} > CURRENT_TIMESTAMP", + &self.table_name, + T::IDENTIFIER + ); + let rows = sqlx::query(&sql).bind(id).fetch_all(&self.pool).await?; + let parsed_rows = rows + .into_iter() + .filter_map(|row| { + let id = self.id_from_row(&row).ok()?; + let raw_data = self.raw_data_from_row(&row).ok()?; + let data = T::try_from(raw_data).ok()?; + let ttl = self.ttl_from_row(&row).ok()?; + Some((id, data, ttl)) + }) + .collect(); + + Ok(parsed_rows) + } + + async fn invalidate_sessions_by_identifier( + &self, + id: &T::Id, + excluded_session_id: Option<&str>, + ) -> SessionResult { + let mut sql = format!( + "DELETE FROM \"{}\" WHERE {} = $1", + &self.table_name, + T::IDENTIFIER + ); + if excluded_session_id.is_some() { + sql.push_str(&format!(" AND {ID_COLUMN} != $2")); + } + + let mut query = sqlx::query(&sql).bind(id); + if let Some(excluded_id) = excluded_session_id { + query = query.bind(excluded_id); + } + let rows = query.execute(&self.pool).await?; + + Ok(rows.rows_affected()) + } +} diff --git a/tests/storages_basic.rs b/tests/storages_basic.rs index f3f6b0b..1f21a52 100644 --- a/tests/storages_basic.rs +++ b/tests/storages_basic.rs @@ -35,6 +35,11 @@ impl TryFrom for SessionData { Ok(Self { user_id: value }) } } +impl From for String { + fn from(value: SessionData) -> Self { + value.user_id + } +} impl fred::types::FromValue for SessionData { fn from_value(value: fred::prelude::Value) -> Result { Ok(Self { @@ -42,18 +47,6 @@ impl fred::types::FromValue for SessionData { }) } } -impl std::fmt::Display for SessionData { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.user_id) - } -} -impl TryFrom for SessionData { - type Error = SessionError; - fn try_from(value: fred::types::Value) -> Result { - let user_id = value.as_string().ok_or(SessionError::NotFound)?; - Ok(Self { user_id }) - } -} impl From for fred::types::Value { fn from(value: SessionData) -> Self { Self::String(value.user_id.into()) diff --git a/tests/storages_indexed.rs b/tests/storages_indexed.rs index fa13ba3..3181827 100644 --- a/tests/storages_indexed.rs +++ b/tests/storages_indexed.rs @@ -32,9 +32,11 @@ impl SessionIdentifier for TestSession { } // Impls for Sqlx -impl ToString for TestSession { - fn to_string(&self) -> String { - format!("{}:{}", self.user_id, self.data) +impl TryFrom for String { + type Error = std::io::Error; + + fn try_from(value: TestSession) -> Result { + Ok(format!("{}:{}", value.user_id, value.data)) } } impl TryFrom for TestSession { From 286be0a3c0a5b0e6dc2b04730039d13cd27a4714 Mon Sep 17 00:00:00 2001 From: fa-sharp Date: Sun, 7 Sep 2025 14:52:49 -0400 Subject: [PATCH 28/28] add interface for hashmap structures --- src/lib.rs | 62 ++++++++++++++++++++++++-------------- src/session.rs | 66 ++-------------------------------------- src/session_hash.rs | 71 ++++++++++++++++++++++++++++++++++++++++++++ src/session_index.rs | 4 +-- tests/basic.rs | 32 +++++++++++++++----- 5 files changed, 138 insertions(+), 97 deletions(-) create mode 100644 src/session_hash.rs diff --git a/src/lib.rs b/src/lib.rs index 615d9d1..641ff57 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,21 +104,36 @@ For more info and examples of this powerful pattern, please see Rocket's documen ## HashMap session data -Instead of a custom struct, you can use a [HashMap](std::collections::HashMap) as your Session data type if the -storage provider supports it. This is particularly useful if you expect your session data structure to be inconsistent and/or change frequently. -When using a HashMap, there are [some additional helper functions](file:///Users/farshad/Projects/pg-user-manager/api/target/doc/rocket_flex_session/struct.Session.html#method.get_key) -to read and set keys. +If your session data has a hashmap data structure, you can implement [`SessionHashMap`] which will +add [additional helper methods](Session::get_key) to Session to read and set keys. This is particularly useful if you expect your +session data structure to be inconsistent and/or change frequently. ``` -use rocket_flex_session::Session; +use rocket_flex_session::{Session, SessionHashMap}; use std::collections::HashMap; -type MySessionData = HashMap; +#[derive(Clone, Default)] +struct MySession(HashMap); + +impl SessionHashMap for MySession { + type Value = String; + + fn get(&self, key: &str) -> Option<&Self::Value> { + self.0.get(key) + } + fn insert(&mut self, key: String, value: Self::Value) { + self.0.insert(key, value); + } + fn remove(&mut self, key: &str) { + self.0.remove(key); + } +} #[rocket::post("/login")] -fn login(mut session: Session) { +fn login(mut session: Session) { let user_id: Option = session.get_key("user_id"); session.set_key("name".to_owned(), "Bob".to_owned()); + session.remove_key("foobar"); } ``` @@ -182,14 +197,14 @@ This crate supports multiple storage backends with different capabilities: ## Available Storage Providers -| Storage | Feature Flag | Indexing support | HashMap support | Use Cases | -|---------|-------------|------------------|----------|----------| -| [`storage::memory::MemoryStorage`] | Built-in | ❌ | ✅ | Development, testing | -| [`storage::memory::MemoryStorageIndexed`] | Built-in | ✅ | ❌ | Development with indexing features | -| [`storage::cookie::CookieStorage`] | `cookie` | ❌ | ✅ | Client-side storage, stateless servers | -| [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | ✅ | Production, distributed systems | -| [`storage::redis::RedisFredStorageIndexed`] | `redis_fred` | ✅ | ❌ | Production, distributed systems | -| [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅ | ❌ | Production, existing database | +| Storage | Feature Flag | Indexing support | Use Cases | +|---------|-------------|------------------|------------| +| [`storage::memory::MemoryStorage`] | Built-in | ❌ | Development, testing | +| [`storage::memory::MemoryStorageIndexed`] | Built-in | ✅ | Development with indexing features | +| [`storage::cookie::CookieStorage`] | `cookie` | ❌ | Client-side storage, stateless servers | +| [`storage::redis::RedisFredStorage`] | `redis_fred` | ❌ | Production, distributed systems | +| [`storage::redis::RedisFredStorageIndexed`] | `redis_fred` | ✅ | Production, distributed systems | +| [`storage::sqlx::SqlxPostgresStorage`] | `sqlx_postgres` | ✅ | Production, existing database | ## Custom Storage @@ -240,8 +255,8 @@ impl SessionStorageIndexed for MyCustomStorage where T: SessionIdentifier + Send + Sync + Clone + 'static, { - async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { - // Return all (session_id, session_data) pairs for the identifier + async fn get_sessions_by_identifier(&self, id: &T::Id) -> SessionResult> { + // Return all sessions (session_id, session_data, session_ttl) for the given identifier todo!() } // etc... @@ -263,12 +278,11 @@ where ### Implementation Tips -1. **Thread Safety**: All storage implementations must be `Send + Sync` -2. **Trait bounds**: Add additional trait bounds to the session data type as needed -3. **Error Handling**: Use [`error::SessionError::Backend`] for custom errors -4. **TTL Handling**: Respect the TTL parameters in `load` and `save` for session expiration -5. **Indexing Consistency**: Keep identifier indexes in sync with session data -6. **Cleanup**: Implement proper cleanup in `shutdown()` if needed +1. **Trait bounds**: Add additional trait bounds to the session data type `` as needed +2. **Error Handling**: Use [`error::SessionError::Backend`] for custom errors +3. **TTL Handling**: Respect the TTL parameters in `load` and `save` for session expiration +4. **Indexing Consistency**: Keep identifier indexes in sync with session data +5. **Cleanup**: Implement proper cleanup in `shutdown()` if needed # Feature flags @@ -287,6 +301,7 @@ mod fairing; mod guard; mod options; mod session; +mod session_hash; mod session_index; mod session_inner; @@ -295,4 +310,5 @@ pub mod storage; pub use fairing::{RocketFlexSession, RocketFlexSessionBuilder}; pub use options::RocketFlexSessionOptions; pub use session::Session; +pub use session_hash::SessionHashMap; pub use session_index::SessionIdentifier; diff --git a/src/session.rs b/src/session.rs index 6c3b92f..ce09bd2 100644 --- a/src/session.rs +++ b/src/session.rs @@ -3,9 +3,7 @@ use rocket::{ time::{Duration, OffsetDateTime}, }; use std::{ - collections::HashMap, fmt::Display, - hash::Hash, marker::{Send, Sync}, sync::{Mutex, MutexGuard}, }; @@ -214,11 +212,11 @@ where self.inner.lock().expect("Failed to get session data lock") } - fn get_default_ttl(&self) -> u32 { + pub(super) fn get_default_ttl(&self) -> u32 { self.options.ttl.unwrap_or(self.options.max_age) } - fn update_cookies(&self) { + pub(super) fn update_cookies(&self) { let inner = self.get_inner_lock(); let Some(id) = inner.get_id() else { rocket::warn!("Cookies not updated: no active session"); @@ -242,66 +240,6 @@ where } } -impl Session<'_, HashMap> -where - K: Eq + Hash + Send + Sync + Clone, - V: Send + Sync + Clone, -{ - /// Get the value of a key in the session data via cloning - pub fn get_key(&self, key: &Q) -> Option - where - Q: ?Sized + Eq + Hash, - K: std::borrow::Borrow, - { - self.get_inner_lock() - .get_current_data() - .and_then(|h| h.get(key).cloned()) - } - - /// Get the value of a key in the session data via a closure - pub fn tap_key(&self, key: &Q, f: F) -> R - where - Q: ?Sized + Eq + Hash, - K: std::borrow::Borrow, - F: FnOnce(Option<&V>) -> R, - { - f(self - .get_inner_lock() - .get_current_data() - .and_then(|d| d.get(key))) - } - - /// Set the value of a key in the session data. Will create a new session if there isn't one. - pub fn set_key(&mut self, key: K, value: V) { - self.get_inner_lock().tap_data_mut( - |data| data.get_or_insert_default().insert(key, value), - self.get_default_ttl(), - ); - self.update_cookies(); - } - - /// Set multiple keys and values in the session data. Will create a new session if there isn't one. - pub fn set_keys(&mut self, kv_iter: I) - where - I: IntoIterator, - { - self.get_inner_lock().tap_data_mut( - |data| data.get_or_insert_default().extend(kv_iter), - self.get_default_ttl(), - ); - self.update_cookies(); - } - - /// Remove a key from the session data. - pub fn remove_key(&mut self, key: K) { - self.get_inner_lock().tap_data_mut( - |data| data.get_or_insert_default().remove(&key), - self.get_default_ttl(), - ); - self.update_cookies(); - } -} - /// Create the session cookie fn create_session_cookie(id: &str, options: &RocketFlexSessionOptions) -> Cookie<'static> { let mut cookie = Cookie::build((options.cookie_name.to_owned(), id.to_owned())) diff --git a/src/session_hash.rs b/src/session_hash.rs new file mode 100644 index 0000000..55b7f8f --- /dev/null +++ b/src/session_hash.rs @@ -0,0 +1,71 @@ +use crate::Session; + +/// Optional trait for sessions with a hashmap-like data structure. +pub trait SessionHashMap: Send + Sync + Clone + Default { + /// The type of values stored in the session hashmap. + type Value: Send + Sync + Clone; + + /// Get a reference to the value associated with the given key. + fn get(&self, key: &str) -> Option<&Self::Value>; + + /// Inserts or updates a key-value pair into the map. + fn insert(&mut self, key: String, value: Self::Value); + + /// Removes a key from the map. + fn remove(&mut self, key: &str); + + // /// Returns the number of keys in the map. + // fn len(&self) -> usize; + + // /// Returns an iterator over the key-value pairs in the map. + // fn iter(&self) -> std::slice::Iter<'_, (&str, &Self::Value)>; + + // /// Returns an iterator over the key-value pairs in the map, with mutable references. + // fn iter_mut(&mut self) -> std::slice::IterMut<'_, (&str, &mut Self::Value)>; +} + +/// Implementation block for sessions with hashmap-like data structures +impl Session<'_, T> +where + T: SessionHashMap, +{ + /// Get the value of a key in the session data via cloning + pub fn get_key(&self, key: &str) -> Option { + self.get_inner_lock() + .get_current_data() + .and_then(|h| h.get(key).cloned()) + } + + /// Get the value of a key in the session data via a closure + pub fn tap_key(&self, key: &str, f: F) -> R + where + F: FnOnce(Option<&T::Value>) -> R, + { + f(self + .get_inner_lock() + .get_current_data() + .and_then(|d| d.get(key))) + } + + /// Set the value of a key in the session data. Will create a new session if there isn't one. + pub fn set_key(&mut self, key: String, value: T::Value) { + self.get_inner_lock().tap_data_mut( + |data| data.get_or_insert_default().insert(key, value), + self.get_default_ttl(), + ); + self.update_cookies(); + } + + /// Remove a key from the session data. + pub fn remove_key(&mut self, key: &str) { + self.get_inner_lock().tap_data_mut( + |data| { + if let Some(data) = data { + data.remove(key); + } + }, + self.get_default_ttl(), + ); + self.update_cookies(); + } +} diff --git a/src/session_index.rs b/src/session_index.rs index f5eabb8..6077045 100644 --- a/src/session_index.rs +++ b/src/session_index.rs @@ -26,7 +26,7 @@ use crate::{error::SessionError, storage::SessionStorageIndexed, Session}; /// } /// } /// ``` -pub trait SessionIdentifier { +pub trait SessionIdentifier: Send + Sync + Clone { /// The name of the identifier (default: `"user_id"`), that may be used as a field/key name by the storage backend. const IDENTIFIER: &str = "user_id"; @@ -43,7 +43,7 @@ pub trait SessionIdentifier { /// Session implementation block for indexing operations impl<'a, T> Session<'a, T> where - T: SessionIdentifier + Send + Sync + Clone, + T: SessionIdentifier, { /// Get all active sessions for the same user/identifier as the current session. /// Returns the session ID, data, and TTL (in seconds) for each session. diff --git a/tests/basic.rs b/tests/basic.rs index c6c651a..e7494a9 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -6,7 +6,10 @@ use rocket::{ local::blocking::Client, {routes, Build, Rocket}, }; -use rocket_flex_session::{storage::cookie::CookieStorage, RocketFlexSession, Session}; +use rocket_flex_session::{ + storage::cookie::CookieStorage, RocketFlexSession, Session, SessionHashMap, +}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Clone, Debug, PartialEq)] @@ -15,6 +18,23 @@ struct User { name: String, } +#[derive(Clone, Default, Serialize, Deserialize)] +struct SessionHash(HashMap); + +impl SessionHashMap for SessionHash { + type Value = String; + + fn get(&self, key: &str) -> Option<&Self::Value> { + self.0.get(key) + } + fn insert(&mut self, key: String, value: Self::Value) { + self.0.insert(key, value); + } + fn remove(&mut self, key: &str) { + self.0.remove(key); + } +} + #[get("/get_session")] fn get_session(session: Session) -> String { match session.get() { @@ -39,7 +59,7 @@ fn delete_session(mut session: Session) -> &'static str { } #[get("/get_hash_session/")] -fn get_hash_session(session: Session>, key: &str) -> String { +fn get_hash_session(session: Session, key: &str) -> String { match session.get_key(key) { Some(value) => value, None => "No value".to_string(), @@ -47,11 +67,7 @@ fn get_hash_session(session: Session>, key: &str) -> Str } #[post("/set_hash_session//")] -fn set_hash_session( - mut session: Session>, - key: &str, - value: &str, -) -> &'static str { +fn set_hash_session(mut session: Session, key: &str, value: &str) -> &'static str { session.set_key(key.to_owned(), value.to_owned()); "Hash session value set" } @@ -60,7 +76,7 @@ fn create_rocket() -> Rocket { rocket::build() .attach(RocketFlexSession::::default()) .attach( - RocketFlexSession::>::builder() + RocketFlexSession::::builder() .with_options(|opt| opt.cookie_name = "hash_session".to_owned()) .storage( CookieStorage::builder()