diff --git a/.sqlx/query-83722331508d9f6347db04c44546ddc6c1c82aad42f16dbda45003f13a1f6e33.json b/.sqlx/query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json similarity index 91% rename from .sqlx/query-83722331508d9f6347db04c44546ddc6c1c82aad42f16dbda45003f13a1f6e33.json rename to .sqlx/query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json index 08ecdf129a..043b020675 100644 --- a/.sqlx/query-83722331508d9f6347db04c44546ddc6c1c82aad42f16dbda45003f13a1f6e33.json +++ b/.sqlx/query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND device_id = $2", + "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')", "describe": { "columns": [ { @@ -89,5 +89,5 @@ false ] }, - "hash": "83722331508d9f6347db04c44546ddc6c1c82aad42f16dbda45003f13a1f6e33" + "hash": "98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37" } diff --git a/.sqlx/query-c154aea1df6c3f273a1e7ab9c9a1c5b4da1599de659cd8315e331d6caf203fa4.json b/.sqlx/query-b2196d15ed73268487293a74bb8fd9393571a3f98b1e323f566493e464bbf1e3.json similarity index 78% rename from .sqlx/query-c154aea1df6c3f273a1e7ab9c9a1c5b4da1599de659cd8315e331d6caf203fa4.json rename to .sqlx/query-b2196d15ed73268487293a74bb8fd9393571a3f98b1e323f566493e464bbf1e3.json index e5b2104a6a..eb0d67ae11 100644 --- a/.sqlx/query-c154aea1df6c3f273a1e7ab9c9a1c5b4da1599de659cd8315e331d6caf203fa4.json +++ b/.sqlx/query-b2196d15ed73268487293a74bb8fd9393571a3f98b1e323f566493e464bbf1e3.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, session_id, gateway_id, collected_at, latest_handshake, endpoint, total_upload, total_download, upload_diff, download_diff\n \tFROM vpn_session_stats WHERE session_id = $1 ORDER BY collected_at DESC LIMIT 1", + "query": "SELECT DISTINCT ON (gateway_id) id, session_id, gateway_id, collected_at, latest_handshake, endpoint, total_upload, total_download, upload_diff, download_diff\n \tFROM vpn_session_stats WHERE session_id = $1 ORDER BY gateway_id, collected_at DESC", "describe": { "columns": [ { @@ -72,5 +72,5 @@ false ] }, - "hash": "c154aea1df6c3f273a1e7ab9c9a1c5b4da1599de659cd8315e331d6caf203fa4" + "hash": "b2196d15ed73268487293a74bb8fd9393571a3f98b1e323f566493e464bbf1e3" } diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index 3ca957d93b..8050f4ecc7 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -77,7 +77,7 @@ impl VpnClientSession { "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" \ FROM vpn_client_session \ - WHERE location_id = $1 AND device_id = $2", + WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')", location_id, device_id ) @@ -85,20 +85,21 @@ impl VpnClientSession { .await } - pub async fn try_get_latest_stats<'e, E: sqlx::PgExecutor<'e>>( + /// Returns latest stats in a given session for each gateway + pub async fn get_latest_stats_for_all_gateways<'e, E: sqlx::PgExecutor<'e>>( &self, executor: E, - ) -> Result>, SqlxError> { + ) -> Result>, SqlxError> { query_as!( VpnSessionStats, - "SELECT id, session_id, gateway_id, collected_at, latest_handshake, endpoint, \ + "SELECT DISTINCT ON (gateway_id) id, session_id, gateway_id, collected_at, latest_handshake, endpoint, \ total_upload, total_download, upload_diff, download_diff FROM vpn_session_stats \ WHERE session_id = $1 \ - ORDER BY collected_at DESC LIMIT 1", + ORDER BY gateway_id, collected_at DESC", self.id ) - .fetch_optional(executor) + .fetch_all(executor) .await } diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index c11815c5cf..422a830d75 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -20,6 +20,27 @@ use crate::{ events::{SessionManagerEvent, SessionManagerEventContext, SessionManagerEventType}, }; +/// Helper map to store latest stats update for each gateway in a given location +pub(crate) struct LastGatewayUpdate(HashMap); + +impl LastGatewayUpdate { + fn new() -> Self { + Self(HashMap::new()) + } + + /// Store latest stats for a given gateway + /// + /// We assume that at this point the update has already been validated. + fn update(&mut self, session_stats: VpnSessionStats) { + let gateway_id = session_stats.gateway_id; + let latest_stats = LastStatsUpdate::from(session_stats); + + debug!("Replacing latest stats update for gateway {gateway_id} with {latest_stats:?}"); + let _maybe_previous = self.0.insert(gateway_id, latest_stats); + } +} + +#[derive(Debug)] struct LastStatsUpdate { collected_at: NaiveDateTime, latest_handshake: NaiveDateTime, @@ -66,37 +87,42 @@ impl From> for LastStatsUpdate { /// State of a specific VPN client session pub(crate) struct SessionState { session_id: Id, - last_stats_update: Option, + last_stats_update: LastGatewayUpdate, } impl SessionState { fn new(session_id: Id) -> Self { Self { session_id, - last_stats_update: None, + last_stats_update: LastGatewayUpdate::new(), } } + fn try_get_last_stats_update(&self, gateway_id: Id) -> Option<&LastStatsUpdate> { + self.last_stats_update.0.get(&gateway_id) + } + /// Updates session stats based on received peer update pub(crate) async fn update_stats( &mut self, transaction: &mut PgConnection, peer_stats_update: PeerStatsUpdate, ) -> Result<(), SessionManagerError> { - // get previous stats if available and calculate transfer change - let (upload_diff, download_diff) = match &self.last_stats_update { - Some(last_stats_update) => { - // validate current update against latest value - last_stats_update.validate_update(&peer_stats_update)?; - - // calculate transfer change - ( - peer_stats_update.upload as i64 - last_stats_update.total_upload, - peer_stats_update.download as i64 - last_stats_update.total_download, - ) - } - None => (0, 0), - }; + // get previous stats for a given gateway if available and calculate transfer change + let (upload_diff, download_diff) = + match self.try_get_last_stats_update(peer_stats_update.gateway_id) { + Some(last_stats_update) => { + // validate current update against latest value + last_stats_update.validate_update(&peer_stats_update)?; + + // calculate transfer change + ( + peer_stats_update.upload as i64 - last_stats_update.total_upload, + peer_stats_update.download as i64 - last_stats_update.total_download, + ) + } + None => (0, 0), + }; let vpn_session_stats = VpnSessionStats::new( self.session_id, @@ -114,7 +140,7 @@ impl SessionState { let stats = vpn_session_stats.save(transaction).await?; // update latest stats - self.last_stats_update = Some(LastStatsUpdate::from(stats)); + self.last_stats_update.update(stats); Ok(()) } @@ -124,7 +150,7 @@ impl From<&VpnClientSession> for SessionState { fn from(value: &VpnClientSession) -> Self { Self { session_id: value.id, - last_stats_update: None, + last_stats_update: LastGatewayUpdate::new(), } } } @@ -143,7 +169,6 @@ impl SessionMap { } } -// TODO(mwojcik): handle multiple gateways per location /// Helper struct to hold session maps for all locations and object cache to avoid repeated DB queries /// /// Since we want to support HA core deployments this structure @@ -195,13 +220,17 @@ impl ActiveSessionsMap { Some(db_session) => { let mut session_state = SessionState::from(&db_session); - // try to fetch latest available stats for a given session - if let Some(latest_stats) = db_session.try_get_latest_stats(transaction).await? { - session_state.last_stats_update = Some(LastStatsUpdate::from(latest_stats)); - }; + // fetch latest available stats for each gateway for a given session + let latest_gateway_stats = db_session + .get_latest_stats_for_all_gateways(transaction) + .await?; + for stats in latest_gateway_stats { + session_state.last_stats_update.update(stats); + } // put session state in map let maybe_existing_session = session_map.insert(device_id, session_state); + // if a session exists already there was an error in earlier logic assert!(maybe_existing_session.is_none());