From eeaf3219383997f9077a61ba78800fe174af7a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 12 Mar 2026 16:30:27 +0100 Subject: [PATCH 1/4] extend session manager test suite --- .../tests/common/mod.rs | 231 +++++++- .../tests/session_manager/disconnects.rs | 107 ++++ .../tests/session_manager/event_flow.rs | 156 +++-- .../tests/session_manager/mfa.rs | 431 ++++++++++++++ .../tests/session_manager/mod.rs | 2 + .../tests/session_manager/sessions.rs | 538 +++++++++++++++++- .../tests/session_manager/stats.rs | 252 ++++++-- 7 files changed, 1602 insertions(+), 115 deletions(-) create mode 100644 crates/defguard_session_manager/tests/session_manager/disconnects.rs create mode 100644 crates/defguard_session_manager/tests/session_manager/mfa.rs diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 005530c367..ad6cb639e9 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -1,5 +1,9 @@ -use std::net::{IpAddr, Ipv4Addr}; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; +use chrono::{NaiveDateTime, Timelike}; use defguard_common::{ db::{ Id, @@ -7,27 +11,63 @@ use defguard_common::{ Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice, gateway::Gateway, + vpn_client_session::{VpnClientMfaMethod, VpnClientSession}, + vpn_session_stats::VpnSessionStats, wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_session_manager::{SessionManager, events::SessionManagerEvent}; +use defguard_session_manager::{ + IterationOutcome, SESSION_UPDATE_INTERVAL, SessionManager, events::SessionManagerEvent, + run_session_manager_iteration, +}; use ipnetwork::IpNetwork; -use tokio::sync::{broadcast, mpsc}; +use sqlx::{PgExecutor, query, query_as, query_scalar}; +use tokio::{ + sync::{ + broadcast, + mpsc::{self}, + }, + time::interval, +}; pub(crate) struct SessionManagerHarness { pub(crate) manager: SessionManager, stats_tx: mpsc::UnboundedSender, pub(crate) stats_rx: mpsc::UnboundedReceiver, pub(crate) event_rx: mpsc::UnboundedReceiver, + pub(crate) gateway_rx: broadcast::Receiver, +} + +pub(crate) fn assert_no_session_manager_events(harness: &mut SessionManagerHarness) { + match harness.event_rx.try_recv() { + Err(mpsc::error::TryRecvError::Empty) => {} + Err(mpsc::error::TryRecvError::Disconnected) => { + panic!("session manager event channel disconnected unexpectedly") + } + Ok(event) => panic!("unexpected session manager event: {event:?}"), + } +} + +pub(crate) fn assert_no_gateway_events(harness: &mut SessionManagerHarness) { + match harness.gateway_rx.try_recv() { + Err(broadcast::error::TryRecvError::Empty) => {} + Err(broadcast::error::TryRecvError::Closed) => { + panic!("gateway event channel closed unexpectedly") + } + Err(broadcast::error::TryRecvError::Lagged(skipped)) => { + panic!("gateway event channel lagged and skipped {skipped} events") + } + Ok(event) => panic!("unexpected gateway event: {event:?}"), + } } impl SessionManagerHarness { pub(crate) fn new(pool: sqlx::PgPool) -> Self { let (stats_tx, stats_rx) = mpsc::unbounded_channel(); let (event_tx, event_rx) = mpsc::unbounded_channel(); - let (gateway_tx, _gateway_rx) = broadcast::channel(16); + let (gateway_tx, gateway_rx) = broadcast::channel(16); let manager = SessionManager::new(pool, event_tx, gateway_tx); Self { @@ -35,6 +75,7 @@ impl SessionManagerHarness { stats_tx, stats_rx, event_rx, + gateway_rx, } } @@ -43,9 +84,38 @@ impl SessionManagerHarness { .send(update) .expect("failed to send peer stats update"); } + + pub(crate) async fn run_iteration(&mut self) -> IterationOutcome { + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + run_session_manager_iteration( + &mut self.manager, + &mut self.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed") + } + + pub(crate) async fn run_idle_iteration(&mut self) -> IterationOutcome { + let mut session_update_timer = interval(Duration::from_millis(1)); + run_session_manager_iteration( + &mut self.manager, + &mut self.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed") + } } pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork { + create_network_with_mfa_mode(pool, LocationMfaMode::Disabled).await +} + +pub(crate) async fn create_network_with_mfa_mode( + pool: &sqlx::PgPool, + location_mfa_mode: LocationMfaMode, +) -> WireguardNetwork { WireguardNetwork::new( "TestNet".to_string(), vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], @@ -59,7 +129,7 @@ pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork 300, false, false, - LocationMfaMode::Disabled, + location_mfa_mode, ServiceLocationMode::Disabled, ) .save(pool) @@ -82,9 +152,17 @@ pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { } pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device { + create_device_with_pubkey(pool, user_id, "device-pubkey-test").await +} + +pub(crate) async fn create_device_with_pubkey( + pool: &sqlx::PgPool, + user_id: Id, + wireguard_pubkey: &str, +) -> Device { Device::new( "session-test-device".to_string(), - "device-pubkey-test".to_string(), + wireguard_pubkey.to_string(), user_id, DeviceType::User, None, @@ -111,10 +189,19 @@ pub(crate) async fn create_gateway( pool: &sqlx::PgPool, network_id: Id, modified_by: String, +) -> Gateway { + create_gateway_named(pool, network_id, modified_by, "gateway-1").await +} + +pub(crate) async fn create_gateway_named( + pool: &sqlx::PgPool, + network_id: Id, + modified_by: String, + name: &str, ) -> Gateway { Gateway::new( network_id, - "gateway-1".to_string(), + name.to_string(), "127.0.0.1".to_string(), 51820, modified_by, @@ -123,3 +210,133 @@ pub(crate) async fn create_gateway( .await .expect("failed to create gateway") } + +pub(crate) async fn authorize_device_in_network( + pool: &sqlx::PgPool, + network_id: Id, + device_id: Id, + preshared_key: &str, +) { + let mut network_device = WireguardNetworkDevice::find(pool, device_id, network_id) + .await + .expect("failed to load device network info") + .expect("expected device network info"); + network_device.is_authorized = true; + network_device.authorized_at = Some(chrono::Utc::now().naive_utc()); + network_device.preshared_key = Some(preshared_key.to_string()); + network_device + .update(pool) + .await + .expect("failed to authorize device in network"); +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn build_stats_update( + location_id: Id, + gateway_id: Id, + device_pubkey: impl Into, + collected_at: NaiveDateTime, + endpoint: SocketAddr, + upload: u64, + download: u64, + latest_handshake: NaiveDateTime, +) -> PeerStatsUpdate { + PeerStatsUpdate { + location_id, + gateway_id, + device_pubkey: device_pubkey.into(), + collected_at: truncate_timestamp(collected_at), + endpoint, + upload, + download, + latest_handshake: truncate_timestamp(latest_handshake), + } +} + +pub(crate) fn truncate_timestamp(timestamp: NaiveDateTime) -> NaiveDateTime { + timestamp + .with_nanosecond((timestamp.nanosecond() / 1_000) * 1_000) + .expect("failed to truncate timestamp precision") +} + +pub(crate) async fn create_session( + pool: &sqlx::PgPool, + location_id: Id, + user_id: Id, + device_id: Id, + connected_at: Option, + mfa_method: Option, +) -> VpnClientSession { + VpnClientSession::new(location_id, user_id, device_id, connected_at, mfa_method) + .save(pool) + .await + .expect("failed to create vpn client session") +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn create_session_stats( + pool: &sqlx::PgPool, + session_id: Id, + gateway_id: Id, + collected_at: NaiveDateTime, + latest_handshake: NaiveDateTime, + endpoint: SocketAddr, + total_upload: i64, + total_download: i64, + upload_diff: i64, + download_diff: i64, +) -> VpnSessionStats { + VpnSessionStats::new( + session_id, + gateway_id, + collected_at, + latest_handshake, + endpoint.to_string(), + total_upload, + total_download, + upload_diff, + download_diff, + ) + .save(pool) + .await + .expect("failed to create vpn session stats") +} + +pub(crate) async fn set_session_created_at<'e, E: PgExecutor<'e>>( + executor: E, + session_id: Id, + created_at: NaiveDateTime, +) { + query("UPDATE vpn_client_session SET created_at = $1 WHERE id = $2") + .bind(created_at) + .bind(session_id) + .execute(executor) + .await + .expect("failed to update session created_at"); +} + +pub(crate) async fn count_session_stats<'e, E: PgExecutor<'e>>(executor: E, session_id: Id) -> i64 { + query_scalar("SELECT COUNT(*) FROM vpn_session_stats WHERE session_id = $1") + .bind(session_id) + .fetch_one(executor) + .await + .expect("failed to count vpn session stats") +} + +pub(crate) async fn count_stats_for_device_location<'e, E: PgExecutor<'e>>( + executor: E, + device_id: Id, + location_id: Id, +) -> i64 { + query_scalar( + "SELECT COUNT(*) \ + FROM vpn_session_stats stats \ + JOIN vpn_client_session session ON stats.session_id = session.id \ + WHERE session.device_id = $1 AND session.location_id = $2", + ) + .bind(device_id) + .bind(location_id) + .fetch_one(executor) + .await + .expect("failed to count device session stats") +} diff --git a/crates/defguard_session_manager/tests/session_manager/disconnects.rs b/crates/defguard_session_manager/tests/session_manager/disconnects.rs new file mode 100644 index 0000000000..501f4de413 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/disconnects.rs @@ -0,0 +1,107 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::{ + models::vpn_client_session::{VpnClientSession, VpnClientSessionState}, + setup_pool, +}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + +use crate::common::{ + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_session, create_session_stats, create_user, +}; + +#[sqlx::test] +async fn test_inactive_connected_sessions_are_disconnected_after_threshold( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let stale_handshake = Utc::now().naive_utc() - TimeDelta::seconds(301); + let session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(stale_handshake), + None, + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + stale_handshake, + stale_handshake, + "203.0.113.10:51820".parse::().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!( + disconnected_session.state, + VpnClientSessionState::Disconnected + ); + assert!(disconnected_session.disconnected_at.is_some()); +} + +#[sqlx::test] +async fn test_recent_connected_sessions_remain_active(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let recent_handshake = Utc::now().naive_utc() - TimeDelta::seconds(30); + let session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(recent_handshake), + None, + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + recent_handshake, + recent_handshake, + "203.0.113.10:51820".parse::().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert!(refreshed_session.disconnected_at.is_none()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs index 8ffe26d73c..90b8104f0e 100644 --- a/crates/defguard_session_manager/tests/session_manager/event_flow.rs +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -1,20 +1,23 @@ use std::net::SocketAddr; use chrono::{TimeDelta, Utc}; -use defguard_common::{db::setup_pool, messages::peer_stats_update::PeerStatsUpdate}; -use defguard_session_manager::{ - SESSION_UPDATE_INTERVAL, events::SessionManagerEventType, run_session_manager_iteration, -}; +use defguard_common::db::{models::vpn_client_session::VpnClientSession, setup_pool}; +use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, interval}; +use tokio::time::{Duration, timeout}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_user, + SessionManagerHarness, attach_device_to_network, build_stats_update, create_device, + create_gateway, create_network, create_session, create_session_stats, create_user, }; +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); + #[sqlx::test] -async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: PgConnectOptions) { +async fn test_session_manager_emits_connected_event_for_first_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { let pool = setup_pool(options).await; let network = create_network(&pool).await; let user = create_user(&pool).await; @@ -25,33 +28,23 @@ async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: P let mut harness = SessionManagerHarness::new(pool); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); - let base_time = Utc::now().naive_utc(); - let update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time, + let handshake = Utc::now().naive_utc() - TimeDelta::seconds(5); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + handshake, endpoint, - upload: 100, - download: 200, - latest_handshake: base_time - TimeDelta::seconds(5), - }; - - harness.send_stats(update); - - let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; - let event = harness - .event_rx - .recv() + let event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) .await + .expect("timed out waiting for ClientConnected event") .expect("session manager event channel closed"); assert!(matches!( @@ -63,3 +56,102 @@ async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: P assert_eq!(event.context.device.id, device.id); assert_eq!(event.context.public_ip, endpoint.ip()); } + +#[sqlx::test] +async fn test_reusing_existing_connected_session_does_not_emit_duplicate_connected_event( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let connected_at = Utc::now().naive_utc() - TimeDelta::seconds(5); + let _session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(connected_at), + None, + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + connected_at, + endpoint, + 100, + 200, + connected_at, + )); + + let _ = harness.run_iteration().await; + + assert!(harness.event_rx.try_recv().is_err()); +} + +#[sqlx::test] +async fn test_session_manager_emits_disconnect_event_for_inactive_standard_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let stale_handshake = Utc::now().naive_utc() - TimeDelta::seconds(301); + let session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(stale_handshake), + None, + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + stale_handshake, + stale_handshake, + "203.0.113.10:51820".parse().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for ClientDisconnected event") + .expect("session manager event channel closed"); + assert!(matches!( + event.event, + SessionManagerEventType::ClientDisconnected + )); + assert_eq!(event.context.location.id, network.id); + assert_eq!(event.context.user.id, user.id); + assert_eq!(event.context.device.id, device.id); + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert!(disconnected_session.disconnected_at.is_some()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs new file mode 100644 index 0000000000..519591f076 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -0,0 +1,431 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::{ + models::{ + device::WireguardNetworkDevice, + vpn_client_session::{VpnClientMfaMethod, VpnClientSession, VpnClientSessionState}, + vpn_session_stats::VpnSessionStats, + wireguard::LocationMfaMode, + }, + setup_pool, +}; +use defguard_core::grpc::GatewayEvent; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, timeout}; + +use crate::common::{ + SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, + attach_device_to_network, authorize_device_in_network, build_stats_update, count_session_stats, + count_stats_for_device_location, create_device, create_gateway, create_network_with_mfa_mode, + create_session, create_session_stats, create_user, set_session_created_at, truncate_timestamp, +}; + +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); + +#[sqlx::test] +async fn test_mfa_location_stats_do_not_create_missing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let timestamp = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + timestamp, + endpoint, + 100, + 200, + timestamp, + )); + + let _ = harness.run_iteration().await; + + assert!( + VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .is_none() + ); + assert_eq!( + count_stats_for_device_location(&pool, device.id, network.id).await, + 0 + ); +} + +#[sqlx::test] +async fn test_mfa_new_session_upgrades_to_connected_on_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + network.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc()); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(handshake)); + assert_eq!( + count_stats_for_device_location(&pool, device.id, network.id).await, + 1 + ); + + let second_collected_at = handshake + TimeDelta::seconds(30); + let second_handshake = handshake + TimeDelta::seconds(25); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + second_collected_at, + endpoint, + 160, + 280, + second_handshake, + )); + + let _ = harness.run_iteration().await; + + let updated_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(updated_session.state, VpnClientSessionState::Connected); + assert_eq!(updated_session.connected_at, Some(handshake)); + + let active_sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(active_sessions.len(), 1); + assert_eq!(active_sessions[0].id, session.id); + + assert_eq!(count_session_stats(&pool, session.id).await, 2); + assert_eq!( + count_stats_for_device_location(&pool, device.id, network.id).await, + 2 + ); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.total_upload, 160); + assert_eq!(latest_stats.total_download, 280); + assert_eq!(latest_stats.upload_diff, 60); + assert_eq!(latest_stats.download_diff, 80); +} + +#[sqlx::test] +async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + network.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc()); + let duplicate_update = || { + build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + ) + }; + + harness.send_stats(duplicate_update()); + harness.send_stats(duplicate_update()); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(handshake)); + + let active_sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(active_sessions.len(), 1); + assert_eq!(active_sessions[0].id, session.id); + + assert_eq!(count_session_stats(&pool, session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); +} + +#[sqlx::test] +async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + network.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let first_handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(30)); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + first_handshake, + endpoint, + 100, + 200, + first_handshake, + )); + + let _ = harness.run_iteration().await; + + let connected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(connected_session.state, VpnClientSessionState::Connected); + assert_eq!(connected_session.connected_at, Some(first_handshake)); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + let later_collected_at = first_handshake + TimeDelta::seconds(30); + let later_handshake = first_handshake + TimeDelta::seconds(20); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + later_collected_at, + endpoint, + 100, + 200, + later_handshake, + )); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(first_handshake)); + + let active_sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(active_sessions.len(), 1); + assert_eq!(active_sessions[0].id, session.id); + + assert_eq!(count_session_stats(&pool, session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); +} + +#[sqlx::test] +async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + authorize_device_in_network(&pool, network.id, device.id, "psk-before-disconnect").await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let stale_handshake = Utc::now().naive_utc() - TimeDelta::seconds(301); + let session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(stale_handshake), + Some(VpnClientMfaMethod::Totp), + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + stale_handshake, + stale_handshake, + "203.0.113.10:51820".parse().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!( + disconnected_session.state, + VpnClientSessionState::Disconnected + ); + + let network_device = WireguardNetworkDevice::find(&pool, device.id, network.id) + .await + .expect("failed to query network device") + .expect("expected network device"); + assert!(!network_device.is_authorized); + assert_eq!(network_device.preshared_key, None); + + let gateway_event = timeout(RECEIVE_TIMEOUT, harness.gateway_rx.recv()) + .await + .expect("timed out waiting for MFA disconnect gateway event") + .expect("gateway event channel closed"); + match gateway_event { + GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + assert_eq!(location_id, network.id); + assert_eq!(disconnected_device.id, device.id); + } + other => panic!("unexpected gateway event: {other:?}"), + } +} + +#[sqlx::test] +async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + network.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + set_session_created_at( + &pool, + session.id, + Utc::now().naive_utc() - TimeDelta::seconds(301), + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!( + disconnected_session.state, + VpnClientSessionState::Disconnected + ); + assert!(disconnected_session.disconnected_at.is_some()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/mod.rs b/crates/defguard_session_manager/tests/session_manager/mod.rs index a16bea70cf..1cf8461d46 100644 --- a/crates/defguard_session_manager/tests/session_manager/mod.rs +++ b/crates/defguard_session_manager/tests/session_manager/mod.rs @@ -1,3 +1,5 @@ +mod disconnects; mod event_flow; +mod mfa; mod sessions; mod stats; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index 8d05bc51fe..094aeddb2c 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -1,57 +1,539 @@ +use std::net::SocketAddr; + use chrono::{TimeDelta, Utc}; -use defguard_common::{ - db::{ - models::vpn_client_session::{VpnClientSession, VpnClientSessionState}, - setup_pool, +use defguard_common::db::{ + models::{ + vpn_client_session::{VpnClientSession, VpnClientSessionState}, + vpn_session_stats::VpnSessionStats, }, - messages::peer_stats_update::PeerStatsUpdate, + setup_pool, }; -use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; +use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, interval}; +use tokio::time::{Duration, timeout}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_user, + SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, + attach_device_to_network, build_stats_update, count_session_stats, + count_stats_for_device_location, create_device, create_device_with_pubkey, create_gateway, + create_network, create_session, create_session_stats, create_user, truncate_timestamp, }; +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); + +#[sqlx::test] +async fn test_session_manager_creates_connected_session_from_first_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(5)); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; + + let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + assert_eq!(session.state, VpnClientSessionState::Connected); + assert_eq!(session.connected_at, Some(handshake)); +} + #[sqlx::test] -async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: PgConnectOptions) { +async fn test_stale_first_stats_update_does_not_create_session_or_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { let pool = setup_pool(options).await; let network = create_network(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; attach_device_to_network(&pool, network.id, device.id).await; let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let collected_at = truncate_timestamp(Utc::now().naive_utc()); + let stale_handshake = collected_at - TimeDelta::seconds(301); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + collected_at, + endpoint, + 100, + 200, + stale_handshake, + )); + + let _ = harness.run_iteration().await; + + assert!( + VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .is_none() + ); + assert_eq!( + count_stats_for_device_location(&pool, device.id, network.id).await, + 0 + ); + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); +} +#[sqlx::test] +async fn test_duplicate_stats_in_same_batch_reuse_existing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); - let update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time, - endpoint: "203.0.113.10:51820".parse().unwrap(), - upload: 100, - download: 200, - latest_handshake: base_time - TimeDelta::seconds(5), + let duplicate_update = || { + build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + ) }; + harness.send_stats(duplicate_update()); + harness.send_stats(duplicate_update()); + + let _ = harness.run_iteration().await; + + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for ClientConnected event in duplicate same-batch stats test") + .expect("session manager event channel closed while waiting for duplicate same-batch stats event"); + assert!(matches!( + connected_event.event, + SessionManagerEventType::ClientConnected + )); + assert_eq!(connected_event.context.location.id, network.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + let active_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, active_session.id); + + let session = sessions.first().expect("expected active session"); + assert_eq!(count_session_stats(&pool, session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); +} + +#[sqlx::test] +async fn test_duplicate_stats_across_iterations_reuse_existing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let update = build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + ); harness.send_stats(update); + let _ = harness.run_iteration().await; - let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + let first_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect( + "timed out waiting for ClientConnected event in duplicate cross-iteration stats test", + ) + .expect( + "session manager event channel closed while waiting for duplicate cross-iteration stats event", + ); + assert!(matches!( + connected_event.event, + SessionManagerEventType::ClientConnected + )); + assert_eq!(connected_event.context.location.id, network.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, first_session.id); + assert_eq!(count_session_stats(&pool, first_session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, first_session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); +} + +#[sqlx::test] +async fn test_existing_new_session_becomes_connected_on_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let existing_session = create_session(&pool, network.id, user.id, device.id, None, None).await; + assert_eq!(existing_session.state, VpnClientSessionState::New); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc()); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; + + let updated_session = VpnClientSession::find_by_id(&pool, existing_session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(updated_session.state, VpnClientSessionState::Connected); + assert_eq!(updated_session.connected_at, Some(handshake)); + assert_eq!(count_session_stats(&pool, updated_session.id).await, 1); +} + +#[sqlx::test] +async fn test_invalid_device_pubkey_updates_are_discarded( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device_with_pubkey(&pool, user.id, "device-pubkey-valid").await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let timestamp = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + "missing-pubkey", + timestamp, + endpoint, + 100, + 200, + timestamp, + )); + + let _ = harness.run_iteration().await; + + let maybe_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session"); + assert!(maybe_session.is_none()); + assert_eq!( + count_stats_for_device_location(&pool, device.id, network.id).await, + 0 + ); +} + +#[sqlx::test] +async fn test_out_of_order_peer_updates_are_discarded(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); - assert_eq!(session.state, VpnClientSessionState::Connected); + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time - TimeDelta::seconds(1), + endpoint, + 150, + 260, + base_time + TimeDelta::seconds(1), + )); + let _ = harness.run_iteration().await; + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(2), + endpoint, + 150, + 260, + base_time - TimeDelta::seconds(1), + )); + let _ = harness.run_iteration().await; + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(3), + endpoint, + 90, + 190, + base_time + TimeDelta::seconds(3), + )); + let _ = harness.run_iteration().await; + assert_eq!(count_session_stats(&pool, session.id).await, 1); +} + +#[sqlx::test] +async fn test_device_public_key_change_reuses_existing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let mut device = + create_device_with_pubkey(&pool, user.id, "device-pubkey-before-rotation").await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + let existing_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + + device.wireguard_pubkey = "device-pubkey-after-rotation".to_string(); + device + .save(&pool) + .await + .expect("failed to update device pubkey"); + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(10), + endpoint, + 150, + 260, + base_time + TimeDelta::seconds(10), + )); + let _ = harness.run_iteration().await; + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, existing_session.id); + assert_eq!(count_session_stats(&pool, existing_session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, existing_session.id); + assert_eq!(latest_stats.total_upload, 150); + assert_eq!(latest_stats.total_download, 260); +} + +#[sqlx::test] +async fn test_existing_session_in_db_is_reused_instead_of_creating_duplicate( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let existing_session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(base_time - TimeDelta::seconds(5)), + None, + ) + .await; + create_session_stats( + &pool, + existing_session.id, + gateway.id, + base_time - TimeDelta::seconds(5), + base_time - TimeDelta::seconds(5), + endpoint, + 100, + 200, + 0, + 0, + ) + .await; + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 160, + 280, + base_time, + )); + let _ = harness.run_iteration().await; + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, existing_session.id); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, existing_session.id); + assert_eq!(latest_stats.upload_diff, 60); + assert_eq!(latest_stats.download_diff, 80); } diff --git a/crates/defguard_session_manager/tests/session_manager/stats.rs b/crates/defguard_session_manager/tests/session_manager/stats.rs index 542cf5467f..4364878454 100644 --- a/crates/defguard_session_manager/tests/session_manager/stats.rs +++ b/crates/defguard_session_manager/tests/session_manager/stats.rs @@ -1,53 +1,44 @@ use std::net::SocketAddr; use chrono::{TimeDelta, Utc}; -use defguard_common::{ - db::{models::vpn_session_stats::VpnSessionStats, setup_pool}, - messages::peer_stats_update::PeerStatsUpdate, +use defguard_common::db::{ + models::{vpn_client_session::VpnClientSession, vpn_session_stats::VpnSessionStats}, + setup_pool, }; -use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, interval}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_user, + SessionManagerHarness, attach_device_to_network, build_stats_update, count_session_stats, + create_device, create_gateway, create_gateway_named, create_network, create_session, + create_session_stats, create_user, }; #[sqlx::test] -async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnectOptions) { +async fn test_session_manager_updates_stats_deltas_across_iterations( + _: PgPoolOptions, + options: PgConnectOptions, +) { let pool = setup_pool(options).await; let network = create_network(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; attach_device_to_network(&pool, network.id, device.id).await; let gateway = create_gateway(&pool, network.id, user.fullname()).await; - let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); - let first_update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time, + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time, endpoint, - upload: 100, - download: 200, - latest_handshake: base_time - TimeDelta::seconds(5), - }; - - harness.send_stats(first_update); - - let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + 100, + 200, + base_time - TimeDelta::seconds(5), + )); + let _ = harness.run_iteration().await; let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) .await @@ -56,26 +47,17 @@ async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnect assert_eq!(first_stats.upload_diff, 0); assert_eq!(first_stats.download_diff, 0); - let second_update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time + TimeDelta::seconds(10), + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(10), endpoint, - upload: 150, - download: 260, - latest_handshake: base_time + TimeDelta::seconds(10), - }; - - harness.send_stats(second_update); - - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + 150, + 260, + base_time + TimeDelta::seconds(10), + )); + let _ = harness.run_iteration().await; let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) .await @@ -83,4 +65,178 @@ async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnect .expect("expected session stats"); assert_eq!(second_stats.upload_diff, 50); assert_eq!(second_stats.download_diff, 60); + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(20), + endpoint, + 180, + 330, + base_time + TimeDelta::seconds(20), + )); + let _ = harness.run_iteration().await; + + let third_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query session stats") + .expect("expected session stats"); + assert_eq!(third_stats.upload_diff, 30); + assert_eq!(third_stats.download_diff, 70); +} + +#[sqlx::test] +async fn test_session_manager_calculates_stats_per_gateway( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway_one = create_gateway_named(&pool, network.id, user.fullname(), "gateway-1").await; + let gateway_two = create_gateway_named(&pool, network.id, user.fullname(), "gateway-2").await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + network.id, + gateway_one.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + harness.send_stats(build_stats_update( + network.id, + gateway_one.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(10), + endpoint, + 130, + 240, + base_time + TimeDelta::seconds(10), + )); + let _ = harness.run_iteration().await; + + harness.send_stats(build_stats_update( + network.id, + gateway_two.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(20), + endpoint, + 500, + 700, + base_time + TimeDelta::seconds(20), + )); + let _ = harness.run_iteration().await; + + harness.send_stats(build_stats_update( + network.id, + gateway_two.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(30), + endpoint, + 560, + 780, + base_time + TimeDelta::seconds(30), + )); + let _ = harness.run_iteration().await; + + let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + let gateway_stats = session + .get_latest_stats_for_all_gateways(&pool) + .await + .expect("failed to query gateway stats"); + assert_eq!(gateway_stats.len(), 2); + + let stats_for_gateway_one = gateway_stats + .iter() + .find(|stats| stats.gateway_id == gateway_one.id) + .expect("expected gateway one stats"); + assert_eq!(stats_for_gateway_one.total_upload, 130); + assert_eq!(stats_for_gateway_one.total_download, 240); + assert_eq!(stats_for_gateway_one.upload_diff, 30); + assert_eq!(stats_for_gateway_one.download_diff, 40); + + let stats_for_gateway_two = gateway_stats + .iter() + .find(|stats| stats.gateway_id == gateway_two.id) + .expect("expected gateway two stats"); + assert_eq!(stats_for_gateway_two.total_upload, 560); + assert_eq!(stats_for_gateway_two.total_download, 780); + assert_eq!(stats_for_gateway_two.upload_diff, 60); + assert_eq!(stats_for_gateway_two.download_diff, 80); + + assert_eq!(count_session_stats(&pool, session.id).await, 4); +} + +#[sqlx::test] +async fn test_out_of_order_updates_for_existing_db_session_are_discarded( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let first_handshake = Utc::now().naive_utc() - TimeDelta::seconds(5); + let existing_session = create_session( + &pool, + network.id, + user.id, + device.id, + Some(first_handshake), + None, + ) + .await; + create_session_stats( + &pool, + existing_session.id, + gateway.id, + first_handshake, + first_handshake, + endpoint, + 100, + 200, + 0, + 0, + ) + .await; + + harness.send_stats(build_stats_update( + network.id, + gateway.id, + &device.wireguard_pubkey, + first_handshake - TimeDelta::seconds(1), + endpoint, + 110, + 210, + first_handshake, + )); + let _ = harness.run_iteration().await; + + assert_eq!(count_session_stats(&pool, existing_session.id).await, 1); + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, existing_session.id); + assert_eq!(latest_stats.total_upload, 100); + assert_eq!(latest_stats.total_download, 200); } From 0b50ee08a9afce461b73885d9b3eef6853bd9493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 12 Mar 2026 16:47:25 +0100 Subject: [PATCH 2/4] remove unused import --- crates/defguard_session_manager/tests/common/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index ad6cb639e9..4dead55dab 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -23,7 +23,7 @@ use defguard_session_manager::{ run_session_manager_iteration, }; use ipnetwork::IpNetwork; -use sqlx::{PgExecutor, query, query_as, query_scalar}; +use sqlx::{PgExecutor, query, query_scalar}; use tokio::{ sync::{ broadcast, From c84da109f5bec486fc63a20606772f852e940fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 12 Mar 2026 17:05:53 +0100 Subject: [PATCH 3/4] review fixes --- crates/defguard_session_manager/tests/common/mod.rs | 11 ++++++++++- .../tests/session_manager/disconnects.rs | 4 ++-- .../tests/session_manager/event_flow.rs | 3 ++- .../tests/session_manager/mfa.rs | 12 ++++-------- .../tests/session_manager/sessions.rs | 5 +++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 4dead55dab..c310ec8c22 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -3,7 +3,7 @@ use std::{ time::Duration, }; -use chrono::{NaiveDateTime, Timelike}; +use chrono::{NaiveDateTime, TimeDelta, Timelike, Utc}; use defguard_common::{ db::{ Id, @@ -259,6 +259,15 @@ pub(crate) fn truncate_timestamp(timestamp: NaiveDateTime) -> NaiveDateTime { .expect("failed to truncate timestamp precision") } +pub(crate) fn stale_session_timestamp(location: &WireguardNetwork) -> NaiveDateTime { + let reference_time = Utc::now().naive_utc(); + reference_time + .checked_sub_signed(TimeDelta::seconds( + i64::from(location.peer_disconnect_threshold) + 1, + )) + .expect("reference timestamp should stay within range") +} + pub(crate) async fn create_session( pool: &sqlx::PgPool, location_id: Id, diff --git a/crates/defguard_session_manager/tests/session_manager/disconnects.rs b/crates/defguard_session_manager/tests/session_manager/disconnects.rs index 501f4de413..bc1d4b428b 100644 --- a/crates/defguard_session_manager/tests/session_manager/disconnects.rs +++ b/crates/defguard_session_manager/tests/session_manager/disconnects.rs @@ -9,7 +9,7 @@ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use crate::common::{ SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_session, create_session_stats, create_user, + create_session, create_session_stats, create_user, stale_session_timestamp, }; #[sqlx::test] @@ -25,7 +25,7 @@ async fn test_inactive_connected_sessions_are_disconnected_after_threshold( let gateway = create_gateway(&pool, network.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let stale_handshake = Utc::now().naive_utc() - TimeDelta::seconds(301); + let stale_handshake = stale_session_timestamp(&network); let session = create_session( &pool, network.id, diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs index 90b8104f0e..47409db637 100644 --- a/crates/defguard_session_manager/tests/session_manager/event_flow.rs +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout}; use crate::common::{ SessionManagerHarness, attach_device_to_network, build_stats_update, create_device, create_gateway, create_network, create_session, create_session_stats, create_user, + stale_session_timestamp, }; const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); @@ -111,7 +112,7 @@ async fn test_session_manager_emits_disconnect_event_for_inactive_standard_sessi let gateway = create_gateway(&pool, network.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let stale_handshake = Utc::now().naive_utc() - TimeDelta::seconds(301); + let stale_handshake = stale_session_timestamp(&network); let session = create_session( &pool, network.id, diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index 519591f076..762e4fa65d 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -18,7 +18,8 @@ use crate::common::{ SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, attach_device_to_network, authorize_device_in_network, build_stats_update, count_session_stats, count_stats_for_device_location, create_device, create_gateway, create_network_with_mfa_mode, - create_session, create_session_stats, create_user, set_session_created_at, truncate_timestamp, + create_session, create_session_stats, create_user, set_session_created_at, + stale_session_timestamp, truncate_timestamp, }; const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); @@ -334,7 +335,7 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization let gateway = create_gateway(&pool, network.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let stale_handshake = Utc::now().naive_utc() - TimeDelta::seconds(301); + let stale_handshake = stale_session_timestamp(&network); let session = create_session( &pool, network.id, @@ -410,12 +411,7 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( Some(VpnClientMfaMethod::Totp), ) .await; - set_session_created_at( - &pool, - session.id, - Utc::now().naive_utc() - TimeDelta::seconds(301), - ) - .await; + set_session_created_at(&pool, session.id, stale_session_timestamp(&network)).await; let _ = harness.run_idle_iteration().await; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index 094aeddb2c..ac25075800 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -16,7 +16,8 @@ use crate::common::{ SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, attach_device_to_network, build_stats_update, count_session_stats, count_stats_for_device_location, create_device, create_device_with_pubkey, create_gateway, - create_network, create_session, create_session_stats, create_user, truncate_timestamp, + create_network, create_session, create_session_stats, create_user, stale_session_timestamp, + truncate_timestamp, }; const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); @@ -72,7 +73,7 @@ async fn test_stale_first_stats_update_does_not_create_session_or_stats( let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let collected_at = truncate_timestamp(Utc::now().naive_utc()); - let stale_handshake = collected_at - TimeDelta::seconds(301); + let stale_handshake = stale_session_timestamp(&network); harness.send_stats(build_stats_update( network.id, gateway.id, From 01efd19eebc0e97806805e109bd4e23e784b5ad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 12 Mar 2026 17:12:43 +0100 Subject: [PATCH 4/4] cleanup --- .../tests/common/mod.rs | 30 ++-- .../tests/session_manager/disconnects.rs | 22 +-- .../tests/session_manager/event_flow.rs | 36 ++--- .../tests/session_manager/mfa.rs | 94 ++++++------- .../tests/session_manager/sessions.rs | 130 +++++++++--------- .../tests/session_manager/stats.rs | 52 +++---- 6 files changed, 182 insertions(+), 182 deletions(-) diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index c310ec8c22..f967e088cb 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -108,11 +108,11 @@ impl SessionManagerHarness { } } -pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork { - create_network_with_mfa_mode(pool, LocationMfaMode::Disabled).await +pub(crate) async fn create_location(pool: &sqlx::PgPool) -> WireguardNetwork { + create_location_with_mfa_mode(pool, LocationMfaMode::Disabled).await } -pub(crate) async fn create_network_with_mfa_mode( +pub(crate) async fn create_location_with_mfa_mode( pool: &sqlx::PgPool, location_mfa_mode: LocationMfaMode, ) -> WireguardNetwork { @@ -134,7 +134,7 @@ pub(crate) async fn create_network_with_mfa_mode( ) .save(pool) .await - .expect("failed to create Wireguard network") + .expect("failed to create Wireguard location") } pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { @@ -173,34 +173,34 @@ pub(crate) async fn create_device_with_pubkey( .expect("failed to create device") } -pub(crate) async fn attach_device_to_network(pool: &sqlx::PgPool, network_id: Id, device_id: Id) { +pub(crate) async fn attach_device_to_location(pool: &sqlx::PgPool, location_id: Id, device_id: Id) { let network_device = WireguardNetworkDevice::new( - network_id, + location_id, device_id, vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 10))], ); network_device .insert(pool) .await - .expect("failed to attach device to network"); + .expect("failed to attach device to location"); } pub(crate) async fn create_gateway( pool: &sqlx::PgPool, - network_id: Id, + location_id: Id, modified_by: String, ) -> Gateway { - create_gateway_named(pool, network_id, modified_by, "gateway-1").await + create_gateway_named(pool, location_id, modified_by, "gateway-1").await } pub(crate) async fn create_gateway_named( pool: &sqlx::PgPool, - network_id: Id, + location_id: Id, modified_by: String, name: &str, ) -> Gateway { Gateway::new( - network_id, + location_id, name.to_string(), "127.0.0.1".to_string(), 51820, @@ -211,13 +211,13 @@ pub(crate) async fn create_gateway_named( .expect("failed to create gateway") } -pub(crate) async fn authorize_device_in_network( +pub(crate) async fn authorize_device_in_location( pool: &sqlx::PgPool, - network_id: Id, + location_id: Id, device_id: Id, preshared_key: &str, ) { - let mut network_device = WireguardNetworkDevice::find(pool, device_id, network_id) + let mut network_device = WireguardNetworkDevice::find(pool, device_id, location_id) .await .expect("failed to load device network info") .expect("expected device network info"); @@ -227,7 +227,7 @@ pub(crate) async fn authorize_device_in_network( network_device .update(pool) .await - .expect("failed to authorize device in network"); + .expect("failed to authorize device in location"); } #[allow(clippy::too_many_arguments)] diff --git a/crates/defguard_session_manager/tests/session_manager/disconnects.rs b/crates/defguard_session_manager/tests/session_manager/disconnects.rs index bc1d4b428b..938c428761 100644 --- a/crates/defguard_session_manager/tests/session_manager/disconnects.rs +++ b/crates/defguard_session_manager/tests/session_manager/disconnects.rs @@ -8,8 +8,8 @@ use defguard_common::db::{ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_session, create_session_stats, create_user, stale_session_timestamp, + SessionManagerHarness, attach_device_to_location, create_device, create_gateway, + create_location, create_session, create_session_stats, create_user, stale_session_timestamp, }; #[sqlx::test] @@ -18,17 +18,17 @@ async fn test_inactive_connected_sessions_are_disconnected_after_threshold( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let stale_handshake = stale_session_timestamp(&network); + let stale_handshake = stale_session_timestamp(&location); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(stale_handshake), @@ -65,17 +65,17 @@ async fn test_inactive_connected_sessions_are_disconnected_after_threshold( #[sqlx::test] async fn test_recent_connected_sessions_remain_active(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let recent_handshake = Utc::now().naive_utc() - TimeDelta::seconds(30); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(recent_handshake), diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs index 47409db637..744ddd0893 100644 --- a/crates/defguard_session_manager/tests/session_manager/event_flow.rs +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -7,8 +7,8 @@ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, build_stats_update, create_device, - create_gateway, create_network, create_session, create_session_stats, create_user, + SessionManagerHarness, attach_device_to_location, build_stats_update, create_device, + create_gateway, create_location, create_session, create_session_stats, create_user, stale_session_timestamp, }; @@ -20,18 +20,18 @@ async fn test_session_manager_emits_connected_event_for_first_stats( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let handshake = Utc::now().naive_utc() - TimeDelta::seconds(5); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, handshake, @@ -52,7 +52,7 @@ async fn test_session_manager_emits_connected_event_for_first_stats( event.event, SessionManagerEventType::ClientConnected )); - assert_eq!(event.context.location.id, network.id); + assert_eq!(event.context.location.id, location.id); assert_eq!(event.context.user.id, user.id); assert_eq!(event.context.device.id, device.id); assert_eq!(event.context.public_ip, endpoint.ip()); @@ -64,17 +64,17 @@ async fn test_reusing_existing_connected_session_does_not_emit_duplicate_connect options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let connected_at = Utc::now().naive_utc() - TimeDelta::seconds(5); let _session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(connected_at), @@ -84,7 +84,7 @@ async fn test_reusing_existing_connected_session_does_not_emit_duplicate_connect let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, connected_at, @@ -105,17 +105,17 @@ async fn test_session_manager_emits_disconnect_event_for_inactive_standard_sessi options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let stale_handshake = stale_session_timestamp(&network); + let stale_handshake = stale_session_timestamp(&location); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(stale_handshake), @@ -146,7 +146,7 @@ async fn test_session_manager_emits_disconnect_event_for_inactive_standard_sessi event.event, SessionManagerEventType::ClientDisconnected )); - assert_eq!(event.context.location.id, network.id); + assert_eq!(event.context.location.id, location.id); assert_eq!(event.context.user.id, user.id); assert_eq!(event.context.device.id, device.id); diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index 762e4fa65d..f167211854 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -16,10 +16,10 @@ use tokio::time::{Duration, timeout}; use crate::common::{ SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, - attach_device_to_network, authorize_device_in_network, build_stats_update, count_session_stats, - count_stats_for_device_location, create_device, create_gateway, create_network_with_mfa_mode, - create_session, create_session_stats, create_user, set_session_created_at, - stale_session_timestamp, truncate_timestamp, + attach_device_to_location, authorize_device_in_location, build_stats_update, + count_session_stats, count_stats_for_device_location, create_device, create_gateway, + create_location_with_mfa_mode, create_session, create_session_stats, create_user, + set_session_created_at, stale_session_timestamp, truncate_timestamp, }; const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); @@ -30,17 +30,17 @@ async fn test_mfa_location_stats_do_not_create_missing_session( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let timestamp = Utc::now().naive_utc(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, timestamp, @@ -53,13 +53,13 @@ async fn test_mfa_location_stats_do_not_create_missing_session( let _ = harness.run_iteration().await; assert!( - VpnClientSession::try_get_active_session(&pool, network.id, device.id) + VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .is_none() ); assert_eq!( - count_stats_for_device_location(&pool, device.id, network.id).await, + count_stats_for_device_location(&pool, device.id, location.id).await, 0 ); } @@ -70,16 +70,16 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, None, @@ -90,7 +90,7 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let handshake = truncate_timestamp(Utc::now().naive_utc()); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, handshake, @@ -109,14 +109,14 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); assert_eq!(refreshed_session.connected_at, Some(handshake)); assert_eq!( - count_stats_for_device_location(&pool, device.id, network.id).await, + count_stats_for_device_location(&pool, device.id, location.id).await, 1 ); let second_collected_at = handshake + TimeDelta::seconds(30); let second_handshake = handshake + TimeDelta::seconds(25); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, second_collected_at, @@ -136,7 +136,7 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( assert_eq!(updated_session.connected_at, Some(handshake)); let active_sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(active_sessions.len(), 1); @@ -144,11 +144,11 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( assert_eq!(count_session_stats(&pool, session.id).await, 2); assert_eq!( - count_stats_for_device_location(&pool, device.id, network.id).await, + count_stats_for_device_location(&pool, device.id, location.id).await, 2 ); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); @@ -165,16 +165,16 @@ async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, None, @@ -186,7 +186,7 @@ async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( let handshake = truncate_timestamp(Utc::now().naive_utc()); let duplicate_update = || { build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, handshake, @@ -210,7 +210,7 @@ async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( assert_eq!(refreshed_session.connected_at, Some(handshake)); let active_sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(active_sessions.len(), 1); @@ -218,7 +218,7 @@ async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( assert_eq!(count_session_stats(&pool, session.id).await, 2); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); @@ -236,16 +236,16 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, None, @@ -256,7 +256,7 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let first_handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(30)); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, first_handshake, @@ -281,7 +281,7 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( let later_collected_at = first_handshake + TimeDelta::seconds(30); let later_handshake = first_handshake + TimeDelta::seconds(20); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, later_collected_at, @@ -301,7 +301,7 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( assert_eq!(refreshed_session.connected_at, Some(first_handshake)); let active_sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(active_sessions.len(), 1); @@ -309,7 +309,7 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( assert_eq!(count_session_stats(&pool, session.id).await, 2); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); @@ -327,18 +327,18 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - authorize_device_in_network(&pool, network.id, device.id, "psk-before-disconnect").await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + authorize_device_in_location(&pool, location.id, device.id, "psk-before-disconnect").await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let stale_handshake = stale_session_timestamp(&network); + let stale_handshake = stale_session_timestamp(&location); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(stale_handshake), @@ -370,7 +370,7 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization VpnClientSessionState::Disconnected ); - let network_device = WireguardNetworkDevice::find(&pool, device.id, network.id) + let network_device = WireguardNetworkDevice::find(&pool, device.id, location.id) .await .expect("failed to query network device") .expect("expected network device"); @@ -383,7 +383,7 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization .expect("gateway event channel closed"); match gateway_event { GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { - assert_eq!(location_id, network.id); + assert_eq!(location_id, location.id); assert_eq!(disconnected_device.id, device.id); } other => panic!("unexpected gateway event: {other:?}"), @@ -396,22 +396,22 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; + attach_device_to_location(&pool, location.id, device.id).await; let mut harness = SessionManagerHarness::new(pool.clone()); let session = create_session( &pool, - network.id, + location.id, user.id, device.id, None, Some(VpnClientMfaMethod::Totp), ) .await; - set_session_created_at(&pool, session.id, stale_session_timestamp(&network)).await; + set_session_created_at(&pool, session.id, stale_session_timestamp(&location)).await; let _ = harness.run_idle_iteration().await; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index ac25075800..77bd99d639 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -14,9 +14,9 @@ use tokio::time::{Duration, timeout}; use crate::common::{ SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, - attach_device_to_network, build_stats_update, count_session_stats, + attach_device_to_location, build_stats_update, count_session_stats, count_stats_for_device_location, create_device, create_device_with_pubkey, create_gateway, - create_network, create_session, create_session_stats, create_user, stale_session_timestamp, + create_location, create_session, create_session_stats, create_user, stale_session_timestamp, truncate_timestamp, }; @@ -28,17 +28,17 @@ async fn test_session_manager_creates_connected_session_from_first_stats( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(5)); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, handshake, @@ -50,7 +50,7 @@ async fn test_session_manager_creates_connected_session_from_first_stats( let _ = harness.run_iteration().await; - let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); @@ -64,18 +64,18 @@ async fn test_stale_first_stats_update_does_not_create_session_or_stats( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let collected_at = truncate_timestamp(Utc::now().naive_utc()); - let stale_handshake = stale_session_timestamp(&network); + let stale_handshake = stale_session_timestamp(&location); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, collected_at, @@ -88,13 +88,13 @@ async fn test_stale_first_stats_update_does_not_create_session_or_stats( let _ = harness.run_iteration().await; assert!( - VpnClientSession::try_get_active_session(&pool, network.id, device.id) + VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .is_none() ); assert_eq!( - count_stats_for_device_location(&pool, device.id, network.id).await, + count_stats_for_device_location(&pool, device.id, location.id).await, 0 ); assert_no_session_manager_events(&mut harness); @@ -107,18 +107,18 @@ async fn test_duplicate_stats_in_same_batch_reuse_existing_session( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); let duplicate_update = || { build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -141,20 +141,20 @@ async fn test_duplicate_stats_in_same_batch_reuse_existing_session( connected_event.event, SessionManagerEventType::ClientConnected )); - assert_eq!(connected_event.context.location.id, network.id); + assert_eq!(connected_event.context.location.id, location.id); assert_eq!(connected_event.context.user.id, user.id); assert_eq!(connected_event.context.device.id, device.id); assert_eq!(connected_event.context.public_ip, endpoint.ip()); assert_no_session_manager_events(&mut harness); assert_no_gateway_events(&mut harness); - let active_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let active_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); let sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(sessions.len(), 1); @@ -163,7 +163,7 @@ async fn test_duplicate_stats_in_same_batch_reuse_existing_session( let session = sessions.first().expect("expected active session"); assert_eq!(count_session_stats(&pool, session.id).await, 2); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); @@ -178,17 +178,17 @@ async fn test_duplicate_stats_across_iterations_reuse_existing_session( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); let update = build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -201,7 +201,7 @@ async fn test_duplicate_stats_across_iterations_reuse_existing_session( harness.send_stats(update); let _ = harness.run_iteration().await; - let first_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let first_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); @@ -218,7 +218,7 @@ async fn test_duplicate_stats_across_iterations_reuse_existing_session( connected_event.event, SessionManagerEventType::ClientConnected )); - assert_eq!(connected_event.context.location.id, network.id); + assert_eq!(connected_event.context.location.id, location.id); assert_eq!(connected_event.context.user.id, user.id); assert_eq!(connected_event.context.device.id, device.id); assert_eq!(connected_event.context.public_ip, endpoint.ip()); @@ -226,7 +226,7 @@ async fn test_duplicate_stats_across_iterations_reuse_existing_session( assert_no_gateway_events(&mut harness); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -241,14 +241,14 @@ async fn test_duplicate_stats_across_iterations_reuse_existing_session( assert_no_gateway_events(&mut harness); let sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(sessions.len(), 1); assert_eq!(sessions[0].id, first_session.id); assert_eq!(count_session_stats(&pool, first_session.id).await, 2); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); @@ -263,20 +263,20 @@ async fn test_existing_new_session_becomes_connected_on_stats( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); - let existing_session = create_session(&pool, network.id, user.id, device.id, None, None).await; + let existing_session = create_session(&pool, location.id, user.id, device.id, None, None).await; assert_eq!(existing_session.state, VpnClientSessionState::New); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let handshake = truncate_timestamp(Utc::now().naive_utc()); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, handshake, @@ -303,17 +303,17 @@ async fn test_invalid_device_pubkey_updates_are_discarded( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device_with_pubkey(&pool, user.id, "device-pubkey-valid").await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let timestamp = Utc::now().naive_utc(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, "missing-pubkey", timestamp, @@ -325,12 +325,12 @@ async fn test_invalid_device_pubkey_updates_are_discarded( let _ = harness.run_iteration().await; - let maybe_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let maybe_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session"); assert!(maybe_session.is_none()); assert_eq!( - count_stats_for_device_location(&pool, device.id, network.id).await, + count_stats_for_device_location(&pool, device.id, location.id).await, 0 ); } @@ -338,17 +338,17 @@ async fn test_invalid_device_pubkey_updates_are_discarded( #[sqlx::test] async fn test_out_of_order_peer_updates_are_discarded(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -359,14 +359,14 @@ async fn test_out_of_order_peer_updates_are_discarded(_: PgPoolOptions, options: )); let _ = harness.run_iteration().await; - let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); assert_eq!(count_session_stats(&pool, session.id).await, 1); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time - TimeDelta::seconds(1), @@ -379,7 +379,7 @@ async fn test_out_of_order_peer_updates_are_discarded(_: PgPoolOptions, options: assert_eq!(count_session_stats(&pool, session.id).await, 1); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(2), @@ -392,7 +392,7 @@ async fn test_out_of_order_peer_updates_are_discarded(_: PgPoolOptions, options: assert_eq!(count_session_stats(&pool, session.id).await, 1); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(3), @@ -411,18 +411,18 @@ async fn test_device_public_key_change_reuses_existing_session( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let mut device = create_device_with_pubkey(&pool, user.id, "device-pubkey-before-rotation").await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -433,7 +433,7 @@ async fn test_device_public_key_change_reuses_existing_session( )); let _ = harness.run_iteration().await; - let existing_session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let existing_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); @@ -445,7 +445,7 @@ async fn test_device_public_key_change_reuses_existing_session( .expect("failed to update device pubkey"); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(10), @@ -457,14 +457,14 @@ async fn test_device_public_key_change_reuses_existing_session( let _ = harness.run_iteration().await; let sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(sessions.len(), 1); assert_eq!(sessions[0].id, existing_session.id); assert_eq!(count_session_stats(&pool, existing_session.id).await, 2); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); @@ -479,18 +479,18 @@ async fn test_existing_session_in_db_is_reused_instead_of_creating_duplicate( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); let existing_session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(base_time - TimeDelta::seconds(5)), @@ -512,7 +512,7 @@ async fn test_existing_session_in_db_is_reused_instead_of_creating_duplicate( .await; harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -524,13 +524,13 @@ async fn test_existing_session_in_db_is_reused_instead_of_creating_duplicate( let _ = harness.run_iteration().await; let sessions = - VpnClientSession::get_all_active_device_sessions_in_location(&pool, network.id, device.id) + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) .await .expect("failed to query active sessions"); assert_eq!(sessions.len(), 1); assert_eq!(sessions[0].id, existing_session.id); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats"); diff --git a/crates/defguard_session_manager/tests/session_manager/stats.rs b/crates/defguard_session_manager/tests/session_manager/stats.rs index 4364878454..a8a6cf3293 100644 --- a/crates/defguard_session_manager/tests/session_manager/stats.rs +++ b/crates/defguard_session_manager/tests/session_manager/stats.rs @@ -8,8 +8,8 @@ use defguard_common::db::{ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, build_stats_update, count_session_stats, - create_device, create_gateway, create_gateway_named, create_network, create_session, + SessionManagerHarness, attach_device_to_location, build_stats_update, count_session_stats, + create_device, create_gateway, create_gateway_named, create_location, create_session, create_session_stats, create_user, }; @@ -19,17 +19,17 @@ async fn test_session_manager_updates_stats_deltas_across_iterations( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time, @@ -40,7 +40,7 @@ async fn test_session_manager_updates_stats_deltas_across_iterations( )); let _ = harness.run_iteration().await; - let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query session stats") .expect("expected session stats"); @@ -48,7 +48,7 @@ async fn test_session_manager_updates_stats_deltas_across_iterations( assert_eq!(first_stats.download_diff, 0); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(10), @@ -59,7 +59,7 @@ async fn test_session_manager_updates_stats_deltas_across_iterations( )); let _ = harness.run_iteration().await; - let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query session stats") .expect("expected session stats"); @@ -67,7 +67,7 @@ async fn test_session_manager_updates_stats_deltas_across_iterations( assert_eq!(second_stats.download_diff, 60); harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(20), @@ -78,7 +78,7 @@ async fn test_session_manager_updates_stats_deltas_across_iterations( )); let _ = harness.run_iteration().await; - let third_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let third_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query session stats") .expect("expected session stats"); @@ -92,18 +92,18 @@ async fn test_session_manager_calculates_stats_per_gateway( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway_one = create_gateway_named(&pool, network.id, user.fullname(), "gateway-1").await; - let gateway_two = create_gateway_named(&pool, network.id, user.fullname(), "gateway-2").await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway_one = create_gateway_named(&pool, location.id, user.fullname(), "gateway-1").await; + let gateway_two = create_gateway_named(&pool, location.id, user.fullname(), "gateway-2").await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); harness.send_stats(build_stats_update( - network.id, + location.id, gateway_one.id, &device.wireguard_pubkey, base_time, @@ -115,7 +115,7 @@ async fn test_session_manager_calculates_stats_per_gateway( let _ = harness.run_iteration().await; harness.send_stats(build_stats_update( - network.id, + location.id, gateway_one.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(10), @@ -127,7 +127,7 @@ async fn test_session_manager_calculates_stats_per_gateway( let _ = harness.run_iteration().await; harness.send_stats(build_stats_update( - network.id, + location.id, gateway_two.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(20), @@ -139,7 +139,7 @@ async fn test_session_manager_calculates_stats_per_gateway( let _ = harness.run_iteration().await; harness.send_stats(build_stats_update( - network.id, + location.id, gateway_two.id, &device.wireguard_pubkey, base_time + TimeDelta::seconds(30), @@ -150,7 +150,7 @@ async fn test_session_manager_calculates_stats_per_gateway( )); let _ = harness.run_iteration().await; - let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); @@ -187,18 +187,18 @@ async fn test_out_of_order_updates_for_existing_db_session_are_discarded( options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let first_handshake = Utc::now().naive_utc() - TimeDelta::seconds(5); let existing_session = create_session( &pool, - network.id, + location.id, user.id, device.id, Some(first_handshake), @@ -220,7 +220,7 @@ async fn test_out_of_order_updates_for_existing_db_session_are_discarded( .await; harness.send_stats(build_stats_update( - network.id, + location.id, gateway.id, &device.wireguard_pubkey, first_handshake - TimeDelta::seconds(1), @@ -232,7 +232,7 @@ async fn test_out_of_order_updates_for_existing_db_session_are_discarded( let _ = harness.run_iteration().await; assert_eq!(count_session_stats(&pool, existing_session.id).await, 1); - let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query latest stats") .expect("expected latest stats");