diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 005530c367..f967e088cb 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -1,5 +1,9 @@ -use std::net::{IpAddr, Ipv4Addr}; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; +use chrono::{NaiveDateTime, TimeDelta, Timelike, Utc}; use defguard_common::{ db::{ Id, @@ -7,27 +11,63 @@ use defguard_common::{ Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice, gateway::Gateway, + vpn_client_session::{VpnClientMfaMethod, VpnClientSession}, + vpn_session_stats::VpnSessionStats, wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_session_manager::{SessionManager, events::SessionManagerEvent}; +use defguard_session_manager::{ + IterationOutcome, SESSION_UPDATE_INTERVAL, SessionManager, events::SessionManagerEvent, + run_session_manager_iteration, +}; use ipnetwork::IpNetwork; -use tokio::sync::{broadcast, mpsc}; +use sqlx::{PgExecutor, query, query_scalar}; +use tokio::{ + sync::{ + broadcast, + mpsc::{self}, + }, + time::interval, +}; pub(crate) struct SessionManagerHarness { pub(crate) manager: SessionManager, stats_tx: mpsc::UnboundedSender, pub(crate) stats_rx: mpsc::UnboundedReceiver, pub(crate) event_rx: mpsc::UnboundedReceiver, + pub(crate) gateway_rx: broadcast::Receiver, +} + +pub(crate) fn assert_no_session_manager_events(harness: &mut SessionManagerHarness) { + match harness.event_rx.try_recv() { + Err(mpsc::error::TryRecvError::Empty) => {} + Err(mpsc::error::TryRecvError::Disconnected) => { + panic!("session manager event channel disconnected unexpectedly") + } + Ok(event) => panic!("unexpected session manager event: {event:?}"), + } +} + +pub(crate) fn assert_no_gateway_events(harness: &mut SessionManagerHarness) { + match harness.gateway_rx.try_recv() { + Err(broadcast::error::TryRecvError::Empty) => {} + Err(broadcast::error::TryRecvError::Closed) => { + panic!("gateway event channel closed unexpectedly") + } + Err(broadcast::error::TryRecvError::Lagged(skipped)) => { + panic!("gateway event channel lagged and skipped {skipped} events") + } + Ok(event) => panic!("unexpected gateway event: {event:?}"), + } } impl SessionManagerHarness { pub(crate) fn new(pool: sqlx::PgPool) -> Self { let (stats_tx, stats_rx) = mpsc::unbounded_channel(); let (event_tx, event_rx) = mpsc::unbounded_channel(); - let (gateway_tx, _gateway_rx) = broadcast::channel(16); + let (gateway_tx, gateway_rx) = broadcast::channel(16); let manager = SessionManager::new(pool, event_tx, gateway_tx); Self { @@ -35,6 +75,7 @@ impl SessionManagerHarness { stats_tx, stats_rx, event_rx, + gateway_rx, } } @@ -43,9 +84,38 @@ impl SessionManagerHarness { .send(update) .expect("failed to send peer stats update"); } + + pub(crate) async fn run_iteration(&mut self) -> IterationOutcome { + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + run_session_manager_iteration( + &mut self.manager, + &mut self.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed") + } + + pub(crate) async fn run_idle_iteration(&mut self) -> IterationOutcome { + let mut session_update_timer = interval(Duration::from_millis(1)); + run_session_manager_iteration( + &mut self.manager, + &mut self.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed") + } +} + +pub(crate) async fn create_location(pool: &sqlx::PgPool) -> WireguardNetwork { + create_location_with_mfa_mode(pool, LocationMfaMode::Disabled).await } -pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork { +pub(crate) async fn create_location_with_mfa_mode( + pool: &sqlx::PgPool, + location_mfa_mode: LocationMfaMode, +) -> WireguardNetwork { WireguardNetwork::new( "TestNet".to_string(), vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], @@ -59,12 +129,12 @@ pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork 300, false, false, - LocationMfaMode::Disabled, + location_mfa_mode, ServiceLocationMode::Disabled, ) .save(pool) .await - .expect("failed to create Wireguard network") + .expect("failed to create Wireguard location") } pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { @@ -82,9 +152,17 @@ pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { } pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device { + create_device_with_pubkey(pool, user_id, "device-pubkey-test").await +} + +pub(crate) async fn create_device_with_pubkey( + pool: &sqlx::PgPool, + user_id: Id, + wireguard_pubkey: &str, +) -> Device { Device::new( "session-test-device".to_string(), - "device-pubkey-test".to_string(), + wireguard_pubkey.to_string(), user_id, DeviceType::User, None, @@ -95,26 +173,35 @@ pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device Gateway { + create_gateway_named(pool, location_id, modified_by, "gateway-1").await +} + +pub(crate) async fn create_gateway_named( + pool: &sqlx::PgPool, + location_id: Id, modified_by: String, + name: &str, ) -> Gateway { Gateway::new( - network_id, - "gateway-1".to_string(), + location_id, + name.to_string(), "127.0.0.1".to_string(), 51820, modified_by, @@ -123,3 +210,142 @@ pub(crate) async fn create_gateway( .await .expect("failed to create gateway") } + +pub(crate) async fn authorize_device_in_location( + pool: &sqlx::PgPool, + location_id: Id, + device_id: Id, + preshared_key: &str, +) { + let mut network_device = WireguardNetworkDevice::find(pool, device_id, location_id) + .await + .expect("failed to load device network info") + .expect("expected device network info"); + network_device.is_authorized = true; + network_device.authorized_at = Some(chrono::Utc::now().naive_utc()); + network_device.preshared_key = Some(preshared_key.to_string()); + network_device + .update(pool) + .await + .expect("failed to authorize device in location"); +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn build_stats_update( + location_id: Id, + gateway_id: Id, + device_pubkey: impl Into, + collected_at: NaiveDateTime, + endpoint: SocketAddr, + upload: u64, + download: u64, + latest_handshake: NaiveDateTime, +) -> PeerStatsUpdate { + PeerStatsUpdate { + location_id, + gateway_id, + device_pubkey: device_pubkey.into(), + collected_at: truncate_timestamp(collected_at), + endpoint, + upload, + download, + latest_handshake: truncate_timestamp(latest_handshake), + } +} + +pub(crate) fn truncate_timestamp(timestamp: NaiveDateTime) -> NaiveDateTime { + timestamp + .with_nanosecond((timestamp.nanosecond() / 1_000) * 1_000) + .expect("failed to truncate timestamp precision") +} + +pub(crate) fn stale_session_timestamp(location: &WireguardNetwork) -> NaiveDateTime { + let reference_time = Utc::now().naive_utc(); + reference_time + .checked_sub_signed(TimeDelta::seconds( + i64::from(location.peer_disconnect_threshold) + 1, + )) + .expect("reference timestamp should stay within range") +} + +pub(crate) async fn create_session( + pool: &sqlx::PgPool, + location_id: Id, + user_id: Id, + device_id: Id, + connected_at: Option, + mfa_method: Option, +) -> VpnClientSession { + VpnClientSession::new(location_id, user_id, device_id, connected_at, mfa_method) + .save(pool) + .await + .expect("failed to create vpn client session") +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn create_session_stats( + pool: &sqlx::PgPool, + session_id: Id, + gateway_id: Id, + collected_at: NaiveDateTime, + latest_handshake: NaiveDateTime, + endpoint: SocketAddr, + total_upload: i64, + total_download: i64, + upload_diff: i64, + download_diff: i64, +) -> VpnSessionStats { + VpnSessionStats::new( + session_id, + gateway_id, + collected_at, + latest_handshake, + endpoint.to_string(), + total_upload, + total_download, + upload_diff, + download_diff, + ) + .save(pool) + .await + .expect("failed to create vpn session stats") +} + +pub(crate) async fn set_session_created_at<'e, E: PgExecutor<'e>>( + executor: E, + session_id: Id, + created_at: NaiveDateTime, +) { + query("UPDATE vpn_client_session SET created_at = $1 WHERE id = $2") + .bind(created_at) + .bind(session_id) + .execute(executor) + .await + .expect("failed to update session created_at"); +} + +pub(crate) async fn count_session_stats<'e, E: PgExecutor<'e>>(executor: E, session_id: Id) -> i64 { + query_scalar("SELECT COUNT(*) FROM vpn_session_stats WHERE session_id = $1") + .bind(session_id) + .fetch_one(executor) + .await + .expect("failed to count vpn session stats") +} + +pub(crate) async fn count_stats_for_device_location<'e, E: PgExecutor<'e>>( + executor: E, + device_id: Id, + location_id: Id, +) -> i64 { + query_scalar( + "SELECT COUNT(*) \ + FROM vpn_session_stats stats \ + JOIN vpn_client_session session ON stats.session_id = session.id \ + WHERE session.device_id = $1 AND session.location_id = $2", + ) + .bind(device_id) + .bind(location_id) + .fetch_one(executor) + .await + .expect("failed to count device session stats") +} diff --git a/crates/defguard_session_manager/tests/session_manager/disconnects.rs b/crates/defguard_session_manager/tests/session_manager/disconnects.rs new file mode 100644 index 0000000000..938c428761 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/disconnects.rs @@ -0,0 +1,107 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::{ + models::vpn_client_session::{VpnClientSession, VpnClientSessionState}, + setup_pool, +}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + +use crate::common::{ + SessionManagerHarness, attach_device_to_location, create_device, create_gateway, + create_location, create_session, create_session_stats, create_user, stale_session_timestamp, +}; + +#[sqlx::test] +async fn test_inactive_connected_sessions_are_disconnected_after_threshold( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let stale_handshake = stale_session_timestamp(&location); + let session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(stale_handshake), + None, + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + stale_handshake, + stale_handshake, + "203.0.113.10:51820".parse::().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!( + disconnected_session.state, + VpnClientSessionState::Disconnected + ); + assert!(disconnected_session.disconnected_at.is_some()); +} + +#[sqlx::test] +async fn test_recent_connected_sessions_remain_active(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let recent_handshake = Utc::now().naive_utc() - TimeDelta::seconds(30); + let session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(recent_handshake), + None, + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + recent_handshake, + recent_handshake, + "203.0.113.10:51820".parse::().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert!(refreshed_session.disconnected_at.is_none()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs index 8ffe26d73c..744ddd0893 100644 --- a/crates/defguard_session_manager/tests/session_manager/event_flow.rs +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -1,65 +1,158 @@ use std::net::SocketAddr; use chrono::{TimeDelta, Utc}; -use defguard_common::{db::setup_pool, messages::peer_stats_update::PeerStatsUpdate}; -use defguard_session_manager::{ - SESSION_UPDATE_INTERVAL, events::SessionManagerEventType, run_session_manager_iteration, -}; +use defguard_common::db::{models::vpn_client_session::VpnClientSession, setup_pool}; +use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, interval}; +use tokio::time::{Duration, timeout}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_user, + SessionManagerHarness, attach_device_to_location, build_stats_update, create_device, + create_gateway, create_location, create_session, create_session_stats, create_user, + stale_session_timestamp, }; +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); + #[sqlx::test] -async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: PgConnectOptions) { +async fn test_session_manager_emits_connected_event_for_first_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); - let base_time = Utc::now().naive_utc(); - let update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time, + let handshake = Utc::now().naive_utc() - TimeDelta::seconds(5); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + handshake, endpoint, - upload: 100, - download: 200, - latest_handshake: base_time - TimeDelta::seconds(5), - }; - - harness.send_stats(update); - - let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; - let event = harness - .event_rx - .recv() + let event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) .await + .expect("timed out waiting for ClientConnected event") .expect("session manager event channel closed"); assert!(matches!( event.event, SessionManagerEventType::ClientConnected )); - assert_eq!(event.context.location.id, network.id); + assert_eq!(event.context.location.id, location.id); assert_eq!(event.context.user.id, user.id); assert_eq!(event.context.device.id, device.id); assert_eq!(event.context.public_ip, endpoint.ip()); } + +#[sqlx::test] +async fn test_reusing_existing_connected_session_does_not_emit_duplicate_connected_event( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let connected_at = Utc::now().naive_utc() - TimeDelta::seconds(5); + let _session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(connected_at), + None, + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + connected_at, + endpoint, + 100, + 200, + connected_at, + )); + + let _ = harness.run_iteration().await; + + assert!(harness.event_rx.try_recv().is_err()); +} + +#[sqlx::test] +async fn test_session_manager_emits_disconnect_event_for_inactive_standard_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let stale_handshake = stale_session_timestamp(&location); + let session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(stale_handshake), + None, + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + stale_handshake, + stale_handshake, + "203.0.113.10:51820".parse().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for ClientDisconnected event") + .expect("session manager event channel closed"); + assert!(matches!( + event.event, + SessionManagerEventType::ClientDisconnected + )); + assert_eq!(event.context.location.id, location.id); + assert_eq!(event.context.user.id, user.id); + assert_eq!(event.context.device.id, device.id); + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert!(disconnected_session.disconnected_at.is_some()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs new file mode 100644 index 0000000000..f167211854 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -0,0 +1,427 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::{ + models::{ + device::WireguardNetworkDevice, + vpn_client_session::{VpnClientMfaMethod, VpnClientSession, VpnClientSessionState}, + vpn_session_stats::VpnSessionStats, + wireguard::LocationMfaMode, + }, + setup_pool, +}; +use defguard_core::grpc::GatewayEvent; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, timeout}; + +use crate::common::{ + SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, + attach_device_to_location, authorize_device_in_location, build_stats_update, + count_session_stats, count_stats_for_device_location, create_device, create_gateway, + create_location_with_mfa_mode, create_session, create_session_stats, create_user, + set_session_created_at, stale_session_timestamp, truncate_timestamp, +}; + +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); + +#[sqlx::test] +async fn test_mfa_location_stats_do_not_create_missing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let timestamp = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + timestamp, + endpoint, + 100, + 200, + timestamp, + )); + + let _ = harness.run_iteration().await; + + assert!( + VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .is_none() + ); + assert_eq!( + count_stats_for_device_location(&pool, device.id, location.id).await, + 0 + ); +} + +#[sqlx::test] +async fn test_mfa_new_session_upgrades_to_connected_on_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + location.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc()); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(handshake)); + assert_eq!( + count_stats_for_device_location(&pool, device.id, location.id).await, + 1 + ); + + let second_collected_at = handshake + TimeDelta::seconds(30); + let second_handshake = handshake + TimeDelta::seconds(25); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + second_collected_at, + endpoint, + 160, + 280, + second_handshake, + )); + + let _ = harness.run_iteration().await; + + let updated_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(updated_session.state, VpnClientSessionState::Connected); + assert_eq!(updated_session.connected_at, Some(handshake)); + + let active_sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(active_sessions.len(), 1); + assert_eq!(active_sessions[0].id, session.id); + + assert_eq!(count_session_stats(&pool, session.id).await, 2); + assert_eq!( + count_stats_for_device_location(&pool, device.id, location.id).await, + 2 + ); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.total_upload, 160); + assert_eq!(latest_stats.total_download, 280); + assert_eq!(latest_stats.upload_diff, 60); + assert_eq!(latest_stats.download_diff, 80); +} + +#[sqlx::test] +async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + location.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc()); + let duplicate_update = || { + build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + ) + }; + + harness.send_stats(duplicate_update()); + harness.send_stats(duplicate_update()); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(handshake)); + + let active_sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(active_sessions.len(), 1); + assert_eq!(active_sessions[0].id, session.id); + + assert_eq!(count_session_stats(&pool, session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); +} + +#[sqlx::test] +async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + location.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let first_handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(30)); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + first_handshake, + endpoint, + 100, + 200, + first_handshake, + )); + + let _ = harness.run_iteration().await; + + let connected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(connected_session.state, VpnClientSessionState::Connected); + assert_eq!(connected_session.connected_at, Some(first_handshake)); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + let later_collected_at = first_handshake + TimeDelta::seconds(30); + let later_handshake = first_handshake + TimeDelta::seconds(20); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + later_collected_at, + endpoint, + 100, + 200, + later_handshake, + )); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(first_handshake)); + + let active_sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(active_sessions.len(), 1); + assert_eq!(active_sessions[0].id, session.id); + + assert_eq!(count_session_stats(&pool, session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); +} + +#[sqlx::test] +async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + authorize_device_in_location(&pool, location.id, device.id, "psk-before-disconnect").await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let stale_handshake = stale_session_timestamp(&location); + let session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(stale_handshake), + Some(VpnClientMfaMethod::Totp), + ) + .await; + create_session_stats( + &pool, + session.id, + gateway.id, + stale_handshake, + stale_handshake, + "203.0.113.10:51820".parse().unwrap(), + 100, + 200, + 0, + 0, + ) + .await; + + let _ = harness.run_idle_iteration().await; + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!( + disconnected_session.state, + VpnClientSessionState::Disconnected + ); + + let network_device = WireguardNetworkDevice::find(&pool, device.id, location.id) + .await + .expect("failed to query network device") + .expect("expected network device"); + assert!(!network_device.is_authorized); + assert_eq!(network_device.preshared_key, None); + + let gateway_event = timeout(RECEIVE_TIMEOUT, harness.gateway_rx.recv()) + .await + .expect("timed out waiting for MFA disconnect gateway event") + .expect("gateway event channel closed"); + match gateway_event { + GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + assert_eq!(location_id, location.id); + assert_eq!(disconnected_device.id, device.id); + } + other => panic!("unexpected gateway event: {other:?}"), + } +} + +#[sqlx::test] +async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + location.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + set_session_created_at(&pool, session.id, stale_session_timestamp(&location)).await; + + let _ = harness.run_idle_iteration().await; + + let disconnected_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!( + disconnected_session.state, + VpnClientSessionState::Disconnected + ); + assert!(disconnected_session.disconnected_at.is_some()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/mod.rs b/crates/defguard_session_manager/tests/session_manager/mod.rs index a16bea70cf..1cf8461d46 100644 --- a/crates/defguard_session_manager/tests/session_manager/mod.rs +++ b/crates/defguard_session_manager/tests/session_manager/mod.rs @@ -1,3 +1,5 @@ +mod disconnects; mod event_flow; +mod mfa; mod sessions; mod stats; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index 8d05bc51fe..77bd99d639 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -1,57 +1,540 @@ +use std::net::SocketAddr; + use chrono::{TimeDelta, Utc}; -use defguard_common::{ - db::{ - models::vpn_client_session::{VpnClientSession, VpnClientSessionState}, - setup_pool, +use defguard_common::db::{ + models::{ + vpn_client_session::{VpnClientSession, VpnClientSessionState}, + vpn_session_stats::VpnSessionStats, }, - messages::peer_stats_update::PeerStatsUpdate, + setup_pool, }; -use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; +use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, interval}; +use tokio::time::{Duration, timeout}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_user, + SessionManagerHarness, assert_no_gateway_events, assert_no_session_manager_events, + attach_device_to_location, build_stats_update, count_session_stats, + count_stats_for_device_location, create_device, create_device_with_pubkey, create_gateway, + create_location, create_session, create_session_stats, create_user, stale_session_timestamp, + truncate_timestamp, }; +const RECEIVE_TIMEOUT: Duration = Duration::from_secs(1); + #[sqlx::test] -async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: PgConnectOptions) { +async fn test_session_manager_creates_connected_session_from_first_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(5)); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; + + let session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + assert_eq!(session.state, VpnClientSessionState::Connected); + assert_eq!(session.connected_at, Some(handshake)); +} +#[sqlx::test] +async fn test_stale_first_stats_update_does_not_create_session_or_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let collected_at = truncate_timestamp(Utc::now().naive_utc()); + let stale_handshake = stale_session_timestamp(&location); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + collected_at, + endpoint, + 100, + 200, + stale_handshake, + )); + + let _ = harness.run_iteration().await; + + assert!( + VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .is_none() + ); + assert_eq!( + count_stats_for_device_location(&pool, device.id, location.id).await, + 0 + ); + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); +} + +#[sqlx::test] +async fn test_duplicate_stats_in_same_batch_reuse_existing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); - let update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time, - endpoint: "203.0.113.10:51820".parse().unwrap(), - upload: 100, - download: 200, - latest_handshake: base_time - TimeDelta::seconds(5), + let duplicate_update = || { + build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + ) }; + harness.send_stats(duplicate_update()); + harness.send_stats(duplicate_update()); + + let _ = harness.run_iteration().await; + + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for ClientConnected event in duplicate same-batch stats test") + .expect("session manager event channel closed while waiting for duplicate same-batch stats event"); + assert!(matches!( + connected_event.event, + SessionManagerEventType::ClientConnected + )); + assert_eq!(connected_event.context.location.id, location.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + let active_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, active_session.id); + + let session = sessions.first().expect("expected active session"); + assert_eq!(count_session_stats(&pool, session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); +} + +#[sqlx::test] +async fn test_duplicate_stats_across_iterations_reuse_existing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let update = build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + ); harness.send_stats(update); + let _ = harness.run_iteration().await; - let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + let first_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect( + "timed out waiting for ClientConnected event in duplicate cross-iteration stats test", + ) + .expect( + "session manager event channel closed while waiting for duplicate cross-iteration stats event", + ); + assert!(matches!( + connected_event.event, + SessionManagerEventType::ClientConnected + )); + assert_eq!(connected_event.context.location.id, location.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, first_session.id); + assert_eq!(count_session_stats(&pool, first_session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, first_session.id); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); +} + +#[sqlx::test] +async fn test_existing_new_session_becomes_connected_on_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let existing_session = create_session(&pool, location.id, user.id, device.id, None, None).await; + assert_eq!(existing_session.state, VpnClientSessionState::New); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let handshake = truncate_timestamp(Utc::now().naive_utc()); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + handshake, + endpoint, + 100, + 200, + handshake, + )); + + let _ = harness.run_iteration().await; + + let updated_session = VpnClientSession::find_by_id(&pool, existing_session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(updated_session.state, VpnClientSessionState::Connected); + assert_eq!(updated_session.connected_at, Some(handshake)); + assert_eq!(count_session_stats(&pool, updated_session.id).await, 1); +} + +#[sqlx::test] +async fn test_invalid_device_pubkey_updates_are_discarded( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device_with_pubkey(&pool, user.id, "device-pubkey-valid").await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); - let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let timestamp = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + "missing-pubkey", + timestamp, + endpoint, + 100, + 200, + timestamp, + )); + + let _ = harness.run_iteration().await; + + let maybe_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session"); + assert!(maybe_session.is_none()); + assert_eq!( + count_stats_for_device_location(&pool, device.id, location.id).await, + 0 + ); +} + +#[sqlx::test] +async fn test_out_of_order_peer_updates_are_discarded(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + let session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) .await .expect("failed to query active session") .expect("expected active session"); - assert_eq!(session.state, VpnClientSessionState::Connected); + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time - TimeDelta::seconds(1), + endpoint, + 150, + 260, + base_time + TimeDelta::seconds(1), + )); + let _ = harness.run_iteration().await; + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(2), + endpoint, + 150, + 260, + base_time - TimeDelta::seconds(1), + )); + let _ = harness.run_iteration().await; + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(3), + endpoint, + 90, + 190, + base_time + TimeDelta::seconds(3), + )); + let _ = harness.run_iteration().await; + assert_eq!(count_session_stats(&pool, session.id).await, 1); +} + +#[sqlx::test] +async fn test_device_public_key_change_reuses_existing_session( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let mut device = + create_device_with_pubkey(&pool, user.id, "device-pubkey-before-rotation").await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + let existing_session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + + device.wireguard_pubkey = "device-pubkey-after-rotation".to_string(); + device + .save(&pool) + .await + .expect("failed to update device pubkey"); + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(10), + endpoint, + 150, + 260, + base_time + TimeDelta::seconds(10), + )); + let _ = harness.run_iteration().await; + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, existing_session.id); + assert_eq!(count_session_stats(&pool, existing_session.id).await, 2); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, existing_session.id); + assert_eq!(latest_stats.total_upload, 150); + assert_eq!(latest_stats.total_download, 260); +} + +#[sqlx::test] +async fn test_existing_session_in_db_is_reused_instead_of_creating_duplicate( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let existing_session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(base_time - TimeDelta::seconds(5)), + None, + ) + .await; + create_session_stats( + &pool, + existing_session.id, + gateway.id, + base_time - TimeDelta::seconds(5), + base_time - TimeDelta::seconds(5), + endpoint, + 100, + 200, + 0, + 0, + ) + .await; + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 160, + 280, + base_time, + )); + let _ = harness.run_iteration().await; + + let sessions = + VpnClientSession::get_all_active_device_sessions_in_location(&pool, location.id, device.id) + .await + .expect("failed to query active sessions"); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, existing_session.id); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, existing_session.id); + assert_eq!(latest_stats.upload_diff, 60); + assert_eq!(latest_stats.download_diff, 80); } diff --git a/crates/defguard_session_manager/tests/session_manager/stats.rs b/crates/defguard_session_manager/tests/session_manager/stats.rs index 542cf5467f..a8a6cf3293 100644 --- a/crates/defguard_session_manager/tests/session_manager/stats.rs +++ b/crates/defguard_session_manager/tests/session_manager/stats.rs @@ -1,86 +1,242 @@ use std::net::SocketAddr; use chrono::{TimeDelta, Utc}; -use defguard_common::{ - db::{models::vpn_session_stats::VpnSessionStats, setup_pool}, - messages::peer_stats_update::PeerStatsUpdate, +use defguard_common::db::{ + models::{vpn_client_session::VpnClientSession, vpn_session_stats::VpnSessionStats}, + setup_pool, }; -use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, interval}; use crate::common::{ - SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, - create_user, + SessionManagerHarness, attach_device_to_location, build_stats_update, count_session_stats, + create_device, create_gateway, create_gateway_named, create_location, create_session, + create_session_stats, create_user, }; #[sqlx::test] -async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnectOptions) { +async fn test_session_manager_updates_stats_deltas_across_iterations( + _: PgPoolOptions, + options: PgConnectOptions, +) { let pool = setup_pool(options).await; - let network = create_network(&pool).await; + let location = create_location(&pool).await; let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; - attach_device_to_network(&pool, network.id, device.id).await; - let gateway = create_gateway(&pool, network.id, user.fullname()).await; - + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; let mut harness = SessionManagerHarness::new(pool.clone()); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); - let first_update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time, + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time, endpoint, - upload: 100, - download: 200, - latest_handshake: base_time - TimeDelta::seconds(5), - }; - - harness.send_stats(first_update); - - let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); + 100, + 200, + base_time - TimeDelta::seconds(5), + )); + let _ = harness.run_iteration().await; - let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query session stats") .expect("expected session stats"); assert_eq!(first_stats.upload_diff, 0); assert_eq!(first_stats.download_diff, 0); - let second_update = PeerStatsUpdate { - location_id: network.id, - gateway_id: gateway.id, - device_pubkey: device.wireguard_pubkey.clone(), - collected_at: base_time + TimeDelta::seconds(10), + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(10), endpoint, - upload: 150, - download: 260, - latest_handshake: base_time + TimeDelta::seconds(10), - }; + 150, + 260, + base_time + TimeDelta::seconds(10), + )); + let _ = harness.run_iteration().await; - harness.send_stats(second_update); - - let _ = run_session_manager_iteration( - &mut harness.manager, - &mut harness.stats_rx, - &mut session_update_timer, - ) - .await - .expect("session manager iteration failed"); - - let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) .await .expect("failed to query session stats") .expect("expected session stats"); assert_eq!(second_stats.upload_diff, 50); assert_eq!(second_stats.download_diff, 60); + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(20), + endpoint, + 180, + 330, + base_time + TimeDelta::seconds(20), + )); + let _ = harness.run_iteration().await; + + let third_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query session stats") + .expect("expected session stats"); + assert_eq!(third_stats.upload_diff, 30); + assert_eq!(third_stats.download_diff, 70); +} + +#[sqlx::test] +async fn test_session_manager_calculates_stats_per_gateway( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway_one = create_gateway_named(&pool, location.id, user.fullname(), "gateway-1").await; + let gateway_two = create_gateway_named(&pool, location.id, user.fullname(), "gateway-2").await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + harness.send_stats(build_stats_update( + location.id, + gateway_one.id, + &device.wireguard_pubkey, + base_time, + endpoint, + 100, + 200, + base_time, + )); + let _ = harness.run_iteration().await; + + harness.send_stats(build_stats_update( + location.id, + gateway_one.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(10), + endpoint, + 130, + 240, + base_time + TimeDelta::seconds(10), + )); + let _ = harness.run_iteration().await; + + harness.send_stats(build_stats_update( + location.id, + gateway_two.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(20), + endpoint, + 500, + 700, + base_time + TimeDelta::seconds(20), + )); + let _ = harness.run_iteration().await; + + harness.send_stats(build_stats_update( + location.id, + gateway_two.id, + &device.wireguard_pubkey, + base_time + TimeDelta::seconds(30), + endpoint, + 560, + 780, + base_time + TimeDelta::seconds(30), + )); + let _ = harness.run_iteration().await; + + let session = VpnClientSession::try_get_active_session(&pool, location.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + let gateway_stats = session + .get_latest_stats_for_all_gateways(&pool) + .await + .expect("failed to query gateway stats"); + assert_eq!(gateway_stats.len(), 2); + + let stats_for_gateway_one = gateway_stats + .iter() + .find(|stats| stats.gateway_id == gateway_one.id) + .expect("expected gateway one stats"); + assert_eq!(stats_for_gateway_one.total_upload, 130); + assert_eq!(stats_for_gateway_one.total_download, 240); + assert_eq!(stats_for_gateway_one.upload_diff, 30); + assert_eq!(stats_for_gateway_one.download_diff, 40); + + let stats_for_gateway_two = gateway_stats + .iter() + .find(|stats| stats.gateway_id == gateway_two.id) + .expect("expected gateway two stats"); + assert_eq!(stats_for_gateway_two.total_upload, 560); + assert_eq!(stats_for_gateway_two.total_download, 780); + assert_eq!(stats_for_gateway_two.upload_diff, 60); + assert_eq!(stats_for_gateway_two.download_diff, 80); + + assert_eq!(count_session_stats(&pool, session.id).await, 4); +} + +#[sqlx::test] +async fn test_out_of_order_updates_for_existing_db_session_are_discarded( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let first_handshake = Utc::now().naive_utc() - TimeDelta::seconds(5); + let existing_session = create_session( + &pool, + location.id, + user.id, + device.id, + Some(first_handshake), + None, + ) + .await; + create_session_stats( + &pool, + existing_session.id, + gateway.id, + first_handshake, + first_handshake, + endpoint, + 100, + 200, + 0, + 0, + ) + .await; + + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + first_handshake - TimeDelta::seconds(1), + endpoint, + 110, + 210, + first_handshake, + )); + let _ = harness.run_iteration().await; + + assert_eq!(count_session_stats(&pool, existing_session.id).await, 1); + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, existing_session.id); + assert_eq!(latest_stats.total_upload, 100); + assert_eq!(latest_stats.total_download, 200); }