Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions crates/defguard_common/src/db/models/vpn_client_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,29 @@ impl VpnClientSession<Id> {
"SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \
mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" \
FROM vpn_client_session \
WHERE location_id = $1 AND device_id = $2",
WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')",
location_id,
device_id
)
.fetch_optional(executor)
.await
}

pub async fn try_get_latest_stats<'e, E: sqlx::PgExecutor<'e>>(
/// Returns latest stats in a given session for each gateway
pub async fn get_latest_stats_for_all_gateways<'e, E: sqlx::PgExecutor<'e>>(
&self,
executor: E,
) -> Result<Option<VpnSessionStats<Id>>, SqlxError> {
) -> Result<Vec<VpnSessionStats<Id>>, SqlxError> {
query_as!(
VpnSessionStats,
"SELECT id, session_id, gateway_id, collected_at, latest_handshake, endpoint, \
"SELECT DISTINCT ON (gateway_id) id, session_id, gateway_id, collected_at, latest_handshake, endpoint, \
total_upload, total_download, upload_diff, download_diff
FROM vpn_session_stats \
WHERE session_id = $1 \
ORDER BY collected_at DESC LIMIT 1",
ORDER BY gateway_id, collected_at DESC",
self.id
)
.fetch_optional(executor)
.fetch_all(executor)
.await
}

Expand Down
75 changes: 52 additions & 23 deletions crates/defguard_session_manager/src/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,27 @@ use crate::{
events::{SessionManagerEvent, SessionManagerEventContext, SessionManagerEventType},
};

/// Helper map to store latest stats update for each gateway in a given location
pub(crate) struct LastGatewayUpdate(HashMap<Id, LastStatsUpdate>);

impl LastGatewayUpdate {
fn new() -> Self {
Self(HashMap::new())
}

/// Store latest stats for a given gateway
///
/// We assume that at this point the update has already been validated.
fn update(&mut self, session_stats: VpnSessionStats<Id>) {
let gateway_id = session_stats.gateway_id;
let latest_stats = LastStatsUpdate::from(session_stats);

debug!("Replacing latest stats update for gateway {gateway_id} with {latest_stats:?}");
let _maybe_previous = self.0.insert(gateway_id, latest_stats);
}
}

#[derive(Debug)]
struct LastStatsUpdate {
collected_at: NaiveDateTime,
latest_handshake: NaiveDateTime,
Expand Down Expand Up @@ -66,37 +87,42 @@ impl From<VpnSessionStats<Id>> for LastStatsUpdate {
/// State of a specific VPN client session
pub(crate) struct SessionState {
session_id: Id,
last_stats_update: Option<LastStatsUpdate>,
last_stats_update: LastGatewayUpdate,
}

impl SessionState {
fn new(session_id: Id) -> Self {
Self {
session_id,
last_stats_update: None,
last_stats_update: LastGatewayUpdate::new(),
}
}

fn try_get_last_stats_update(&self, gateway_id: Id) -> Option<&LastStatsUpdate> {
self.last_stats_update.0.get(&gateway_id)
}

/// Updates session stats based on received peer update
pub(crate) async fn update_stats(
&mut self,
transaction: &mut PgConnection,
peer_stats_update: PeerStatsUpdate,
) -> Result<(), SessionManagerError> {
// get previous stats if available and calculate transfer change
let (upload_diff, download_diff) = match &self.last_stats_update {
Some(last_stats_update) => {
// validate current update against latest value
last_stats_update.validate_update(&peer_stats_update)?;

// calculate transfer change
(
peer_stats_update.upload as i64 - last_stats_update.total_upload,
peer_stats_update.download as i64 - last_stats_update.total_download,
)
}
None => (0, 0),
};
// get previous stats for a given gateway if available and calculate transfer change
let (upload_diff, download_diff) =
match self.try_get_last_stats_update(peer_stats_update.gateway_id) {
Some(last_stats_update) => {
// validate current update against latest value
last_stats_update.validate_update(&peer_stats_update)?;

// calculate transfer change
(
peer_stats_update.upload as i64 - last_stats_update.total_upload,
peer_stats_update.download as i64 - last_stats_update.total_download,
)
}
None => (0, 0),
};

let vpn_session_stats = VpnSessionStats::new(
self.session_id,
Expand All @@ -114,7 +140,7 @@ impl SessionState {
let stats = vpn_session_stats.save(transaction).await?;

// update latest stats
self.last_stats_update = Some(LastStatsUpdate::from(stats));
self.last_stats_update.update(stats);

Ok(())
}
Expand All @@ -124,7 +150,7 @@ impl From<&VpnClientSession<Id>> for SessionState {
fn from(value: &VpnClientSession<Id>) -> Self {
Self {
session_id: value.id,
last_stats_update: None,
last_stats_update: LastGatewayUpdate::new(),
}
}
}
Expand All @@ -143,7 +169,6 @@ impl SessionMap {
}
}

// TODO(mwojcik): handle multiple gateways per location
/// Helper struct to hold session maps for all locations and object cache to avoid repeated DB queries
///
/// Since we want to support HA core deployments this structure
Expand Down Expand Up @@ -195,13 +220,17 @@ impl ActiveSessionsMap {
Some(db_session) => {
let mut session_state = SessionState::from(&db_session);

// try to fetch latest available stats for a given session
if let Some(latest_stats) = db_session.try_get_latest_stats(transaction).await? {
session_state.last_stats_update = Some(LastStatsUpdate::from(latest_stats));
};
// fetch latest available stats for each gateway for a given session
let latest_gateway_stats = db_session
.get_latest_stats_for_all_gateways(transaction)
.await?;
for stats in latest_gateway_stats {
session_state.last_stats_update.update(stats);
}

// put session state in map
let maybe_existing_session = session_map.insert(device_id, session_state);

// if a session exists already there was an error in earlier logic
assert!(maybe_existing_session.is_none());

Expand Down
Loading