From c4b7dff6eca5bd079b959efedcd3228eac1e07a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:10:58 +0100 Subject: [PATCH 01/18] emit disconnect events during MFA reconnect --- .../src/db/models/activity_log/mod.rs | 2 + crates/defguard_core/src/events.rs | 4 + .../src/grpc/proxy/client_mfa.rs | 286 +++++++++++++++++- crates/defguard_event_logger/Cargo.toml | 3 + .../defguard_event_logger/src/description.rs | 6 + crates/defguard_event_logger/src/lib.rs | 197 +++++++++--- crates/defguard_event_logger/src/message.rs | 8 + .../src/handlers/bidi.rs | 7 + .../src/handlers/session_manager.rs | 14 +- 9 files changed, 479 insertions(+), 48 deletions(-) diff --git a/crates/defguard_core/src/db/models/activity_log/mod.rs b/crates/defguard_core/src/db/models/activity_log/mod.rs index 0938d190b6..9dfa15c6eb 100644 --- a/crates/defguard_core/src/db/models/activity_log/mod.rs +++ b/crates/defguard_core/src/db/models/activity_log/mod.rs @@ -75,6 +75,8 @@ pub enum EventType { // VPN client events VpnClientConnected, VpnClientDisconnected, + VpnClientMfaConnected, + VpnClientMfaDisconnected, VpnClientMfaSuccess, VpnClientMfaFailed, // Enrollment events diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index fbd458762f..a9f7eab7a7 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -392,4 +392,8 @@ pub enum DesktopClientMfaEvent { method: ClientMFAMethod, message: String, }, + Disconnected { + device: Device, + location: WireguardNetwork, + }, } diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 9f27f76af1..d7a34274dc 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -787,7 +787,7 @@ impl ClientMfaServer { // disconnect all active sessions for session in active_sessions { debug!("Disconnecting previous active MFA VPN session {session:?}."); - self.disconnect_session(&mut *conn, session, location, device) + self.disconnect_mfa_session(&mut *conn, session, location, user, device) .await?; } @@ -800,11 +800,12 @@ impl ClientMfaServer { } /// Update session state as disconnected and send relevant gateway update - async fn disconnect_session( + async fn disconnect_mfa_session( &self, conn: &mut PgConnection, mut session: VpnClientSession, location: &WireguardNetwork, + user: &User, device: &Device, ) -> Result<(), Status> { // update session state in DB @@ -839,14 +840,283 @@ impl ClientMfaServer { Status::internal("unexpected error") })?; } - let event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); - self.wireguard_tx.send(event).map_err(|err| { - error!("Error sending WireGuard event: {err}"); - Status::internal("unexpected error") - })?; - // FIXME: add audit log event + // only emit disconnect events if MFA session has actually connected + // and not for New sessions + if session.state == VpnClientSessionState::Connected { + let gateway_event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + self.wireguard_tx.send(gateway_event).map_err(|err| { + error!("Error sending WireGuard event: {err}"); + Status::internal("unexpected error") + })?; + + let context = BidiRequestContext { + timestamp: disconnect_timestamp, + user_id: user.id, + username: user.username.clone(), + ip: std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), + device_name: format!("{} (ID {})", device.name, device.id), + }; + self.emit_event(BidiStreamEvent { + context, + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Disconnected { + location: location.clone(), + device: device.clone(), + }, + )), + }) + .map_err(Status::from)?; + } Ok(()) } } + +#[cfg(test)] +mod tests { + use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr}, + sync::{Arc, RwLock}, + }; + + use defguard_common::db::{ + models::{DeviceType, device::WireguardNetworkDevice, wireguard::ServiceLocationMode}, + setup_pool, + }; + use ipnetwork::IpNetwork; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + use tokio::sync::{broadcast, mpsc::unbounded_channel, oneshot}; + + use super::*; + + #[sqlx::test] + async fn test_replacing_connected_mfa_session_emits_disconnect_event( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let location = create_mfa_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let old_session = VpnClientSession::new( + location.id, + user.id, + device.id, + Some(Utc::now().naive_utc()), + Some(VpnClientMfaMethod::Totp), + ) + .save(&pool) + .await + .expect("failed to create existing MFA session"); + + let (server, mut event_rx, _gateway_rx) = make_server(pool.clone()); + let mut conn = pool.acquire().await.expect("failed to acquire connection"); + + server + .create_new_mfa_session( + &mut conn, + &location, + &user, + &device, + VpnClientMfaMethod::Totp, + ) + .await + .expect("should replace connected MFA session"); + + let event = event_rx + .try_recv() + .expect("expected MFA disconnect event for replaced connected session"); + assert!(matches!( + event.event, + BidiStreamEventType::DesktopClientMfa(event) + if matches!(*event, DesktopClientMfaEvent::Disconnected { .. }) + )); + assert_eq!(event.context.user_id, user.id); + assert_eq!(event.context.username, user.username); + + let old_session = VpnClientSession::find_by_id(&pool, old_session.id) + .await + .expect("failed to query old session") + .expect("expected old session"); + assert_eq!(old_session.state, VpnClientSessionState::Disconnected); + } + + #[sqlx::test] + async fn test_replacing_new_mfa_session_does_not_emit_disconnect_event( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let location = create_mfa_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let old_session = VpnClientSession::new( + location.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .save(&pool) + .await + .expect("failed to create existing new MFA session"); + + let (server, mut event_rx, _gateway_rx) = make_server(pool.clone()); + let mut conn = pool.acquire().await.expect("failed to acquire connection"); + + server + .create_new_mfa_session( + &mut conn, + &location, + &user, + &device, + VpnClientMfaMethod::Totp, + ) + .await + .expect("should replace new MFA session"); + + assert!(event_rx.try_recv().is_err()); + + let old_session = VpnClientSession::find_by_id(&pool, old_session.id) + .await + .expect("failed to query old session") + .expect("expected old session"); + assert_eq!(old_session.state, VpnClientSessionState::Disconnected); + } + + #[sqlx::test] + async fn test_replacing_connected_non_mfa_session_does_not_emit_mfa_disconnect_event( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let location = create_mfa_location(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let old_session = VpnClientSession::new( + location.id, + user.id, + device.id, + Some(Utc::now().naive_utc()), + None, + ) + .save(&pool) + .await + .expect("failed to create existing connected non-MFA session"); + + let (server, mut event_rx, _gateway_rx) = make_server(pool.clone()); + let mut conn = pool.acquire().await.expect("failed to acquire connection"); + + server + .create_new_mfa_session( + &mut conn, + &location, + &user, + &device, + VpnClientMfaMethod::Totp, + ) + .await + .expect("should replace connected non-MFA session"); + + assert!(event_rx.try_recv().is_err()); + + let old_session = VpnClientSession::find_by_id(&pool, old_session.id) + .await + .expect("failed to query old session") + .expect("expected old session"); + assert_eq!(old_session.state, VpnClientSessionState::Disconnected); + } + + fn make_server( + pool: PgPool, + ) -> ( + ClientMfaServer, + tokio::sync::mpsc::UnboundedReceiver, + tokio::sync::broadcast::Receiver, + ) { + let (wireguard_tx, wireguard_rx) = broadcast::channel(8); + let (bidi_event_tx, bidi_event_rx) = unbounded_channel(); + let remote_mfa_responses: Arc>>> = + Arc::default(); + let sessions: Arc>> = Arc::default(); + + ( + ClientMfaServer::new( + pool, + wireguard_tx, + bidi_event_tx, + remote_mfa_responses, + sessions, + ), + bidi_event_rx, + wireguard_rx, + ) + } + + async fn create_user(pool: &PgPool) -> User { + User::new( + "client-mfa-test", + Some("pass123"), + "Tester", + "ClientMfa", + "client-mfa@example.com", + None, + ) + .save(pool) + .await + .expect("failed to create user") + } + + async fn create_device(pool: &PgPool, user_id: Id) -> Device { + Device::new( + "client-mfa-device".to_string(), + "client-mfa-pubkey".to_string(), + user_id, + DeviceType::User, + None, + true, + ) + .save(pool) + .await + .expect("failed to create device") + } + + async fn create_mfa_location(pool: &PgPool) -> WireguardNetwork { + WireguardNetwork::new( + "client-mfa-location".to_string(), + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 10, 0, 0)), 24).unwrap()], + 51820, + "vpn.example.com".to_string(), + None, + 1420, + 0, + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], + true, + 25, + 300, + false, + false, + LocationMfaMode::Internal, + ServiceLocationMode::Disabled, + ) + .save(pool) + .await + .expect("failed to create location") + } + + async fn attach_device_to_location(pool: &PgPool, location_id: Id, device_id: Id) { + WireguardNetworkDevice::new( + location_id, + device_id, + vec![IpAddr::V4(Ipv4Addr::new(10, 10, 0, 10))], + ) + .insert(pool) + .await + .expect("failed to attach device to location"); + } +} diff --git a/crates/defguard_event_logger/Cargo.toml b/crates/defguard_event_logger/Cargo.toml index 6761794377..30de520e51 100644 --- a/crates/defguard_event_logger/Cargo.toml +++ b/crates/defguard_event_logger/Cargo.toml @@ -21,3 +21,6 @@ sqlx.workspace = true thiserror.workspace = true tokio.workspace = true tracing.workspace = true + +[dev-dependencies] +ipnetwork.workspace = true diff --git a/crates/defguard_event_logger/src/description.rs b/crates/defguard_event_logger/src/description.rs index 856494df81..69b08bebfc 100644 --- a/crates/defguard_event_logger/src/description.rs +++ b/crates/defguard_event_logger/src/description.rs @@ -292,6 +292,12 @@ pub fn get_vpn_event_description(event: &VpnEvent) -> Option { VpnEvent::DisconnectedFromLocation { location, device } => Some(format!( "Device {device} disconnected from location {location}" )), + VpnEvent::MfaConnectedToLocation { location, device } => Some(format!( + "Device {device} connected to MFA location {location}" + )), + VpnEvent::MfaDisconnectedFromLocation { location, device } => Some(format!( + "Device {device} disconnected from MFA location {location}" + )), } } diff --git a/crates/defguard_event_logger/src/lib.rs b/crates/defguard_event_logger/src/lib.rs index 3d954d2dd8..5332efccfe 100644 --- a/crates/defguard_event_logger/src/lib.rs +++ b/crates/defguard_event_logger/src/lib.rs @@ -37,6 +37,55 @@ pub mod message; const MESSAGE_LIMIT: usize = 100; +fn map_vpn_event(event: VpnEvent) -> (EventType, Option) { + match event { + VpnEvent::ClientMfaFailed { + location, + device, + method, + message, + } => ( + EventType::VpnClientMfaFailed, + serde_json::to_value(VpnClientMfaFailedMetadata { + location, + device, + method, + message, + }) + .ok(), + ), + VpnEvent::ClientMfaSuccess { + location, + device, + method, + } => ( + EventType::VpnClientMfaSuccess, + serde_json::to_value(VpnClientMfaMetadata { + location, + device, + method, + }) + .ok(), + ), + VpnEvent::ConnectedToLocation { location, device } => ( + EventType::VpnClientConnected, + serde_json::to_value(VpnClientMetadata { location, device }).ok(), + ), + VpnEvent::DisconnectedFromLocation { location, device } => ( + EventType::VpnClientDisconnected, + serde_json::to_value(VpnClientMetadata { location, device }).ok(), + ), + VpnEvent::MfaConnectedToLocation { location, device } => ( + EventType::VpnClientMfaConnected, + serde_json::to_value(VpnClientMetadata { location, device }).ok(), + ), + VpnEvent::MfaDisconnectedFromLocation { location, device } => ( + EventType::VpnClientMfaDisconnected, + serde_json::to_value(VpnClientMetadata { location, device }).ok(), + ), + } +} + /// Run the event logger service /// /// This function runs in an infinite loop, receiving messages from the event_logger_rx channel @@ -493,44 +542,7 @@ pub async fn run_event_logger( let module = ActivityLogModule::Vpn; let description = get_vpn_event_description(&event); - let (event_type, metadata) = match *event { - VpnEvent::ClientMfaFailed { - location, - device, - method, - message, - } => ( - EventType::VpnClientMfaFailed, - serde_json::to_value(VpnClientMfaFailedMetadata { - location, - device, - method, - message, - }) - .ok(), - ), - VpnEvent::ClientMfaSuccess { - location, - device, - method, - } => ( - EventType::VpnClientMfaSuccess, - serde_json::to_value(VpnClientMfaMetadata { - location, - device, - method, - }) - .ok(), - ), - VpnEvent::ConnectedToLocation { location, device } => ( - EventType::VpnClientConnected, - serde_json::to_value(VpnClientMetadata { location, device }).ok(), - ), - VpnEvent::DisconnectedFromLocation { location, device } => ( - EventType::VpnClientDisconnected, - serde_json::to_value(VpnClientMetadata { location, device }).ok(), - ), - }; + let (event_type, metadata) = map_vpn_event(*event); (module, event_type, description, metadata) } LoggerEvent::Enrollment(event) => { @@ -610,3 +622,112 @@ pub async fn run_event_logger( transaction.commit().await?; } } + +#[cfg(test)] +mod tests { + use defguard_common::db::{ + NoId, + models::{ + Device, DeviceType, WireguardNetwork, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, + }; + use ipnetwork::IpNetwork; + use std::net::{IpAddr, Ipv4Addr}; + + use super::*; + + fn sample_device() -> Device { + Device::new( + "vpn-device".to_string(), + "pubkey".to_string(), + 1, + DeviceType::User, + None, + true, + ) + .save_placeholder_id(20) + } + + fn sample_location() -> WireguardNetwork { + WireguardNetwork::new( + "vpn-location".to_string(), + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], + 51820, + "vpn.example.com".to_string(), + None, + 1420, + 0, + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], + true, + 25, + 300, + false, + false, + LocationMfaMode::Internal, + ServiceLocationMode::Disabled, + ) + .save_placeholder_id(10) + } + + #[test] + fn maps_mfa_vpn_connect_and_disconnect_events() { + let location = sample_location(); + let device = sample_device(); + + let (event_type, _) = map_vpn_event(VpnEvent::MfaConnectedToLocation { + location: location.clone(), + device: device.clone(), + }); + assert!(matches!(event_type, EventType::VpnClientMfaConnected)); + + let (event_type, _) = + map_vpn_event(VpnEvent::MfaDisconnectedFromLocation { location, device }); + assert!(matches!(event_type, EventType::VpnClientMfaDisconnected)); + } + + trait WithPlaceholderId { + fn save_placeholder_id(self, id: i64) -> T; + } + + impl WithPlaceholderId> for Device { + fn save_placeholder_id(self, id: i64) -> Device { + Device { + id, + name: self.name, + wireguard_pubkey: self.wireguard_pubkey, + user_id: self.user_id, + created: self.created, + device_type: self.device_type, + description: self.description, + configured: self.configured, + } + } + } + + impl WithPlaceholderId> for WireguardNetwork { + fn save_placeholder_id(self, id: i64) -> WireguardNetwork { + WireguardNetwork { + id, + name: self.name, + address: self.address, + port: self.port, + pubkey: self.pubkey, + prvkey: self.prvkey, + endpoint: self.endpoint, + dns: self.dns, + mtu: self.mtu, + fwmark: self.fwmark, + allowed_ips: self.allowed_ips, + allow_all_groups: self.allow_all_groups, + connected_at: self.connected_at, + acl_enabled: self.acl_enabled, + acl_default_allow: self.acl_default_allow, + keepalive_interval: self.keepalive_interval, + peer_disconnect_threshold: self.peer_disconnect_threshold, + location_mfa_mode: self.location_mfa_mode, + service_location_mode: self.service_location_mode, + } + } + } +} diff --git a/crates/defguard_event_logger/src/message.rs b/crates/defguard_event_logger/src/message.rs index 94626063d2..7a2dc95997 100644 --- a/crates/defguard_event_logger/src/message.rs +++ b/crates/defguard_event_logger/src/message.rs @@ -360,6 +360,14 @@ pub enum VpnEvent { location: WireguardNetwork, device: Device, }, + MfaConnectedToLocation { + location: WireguardNetwork, + device: Device, + }, + MfaDisconnectedFromLocation { + location: WireguardNetwork, + device: Device, + }, } /// Represents activity log events related to user enrollment process diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 669a0e0518..8a43ef419f 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -71,6 +71,13 @@ impl EventRouter { })), Some(location), ), + DesktopClientMfaEvent::Disconnected { location, device } => ( + LoggerEvent::Vpn(Box::new(VpnEvent::MfaDisconnectedFromLocation { + location: location.clone(), + device, + })), + Some(location), + ), }, }; diff --git a/crates/defguard_event_router/src/handlers/session_manager.rs b/crates/defguard_event_router/src/handlers/session_manager.rs index e818c4d0f4..ef19a6398a 100644 --- a/crates/defguard_event_router/src/handlers/session_manager.rs +++ b/crates/defguard_event_router/src/handlers/session_manager.rs @@ -27,8 +27,18 @@ impl EventRouter { device, })) } - SessionManagerEventType::MfaClientConnected => todo!(), - SessionManagerEventType::MfaClientDisconnected => todo!(), + SessionManagerEventType::MfaClientConnected => { + LoggerEvent::Vpn(Box::new(VpnEvent::MfaConnectedToLocation { + location, + device, + })) + } + SessionManagerEventType::MfaClientDisconnected => { + LoggerEvent::Vpn(Box::new(VpnEvent::MfaDisconnectedFromLocation { + location, + device, + })) + } }; self.log_event( EventContext::from_session_manager_context(context), From 519288fc6a83ed7eaec10bd1b9063bea0ebfaeee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:23:38 +0100 Subject: [PATCH 02/18] add connected/disconnected events helpers --- crates/defguard_session_manager/src/events.rs | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/crates/defguard_session_manager/src/events.rs b/crates/defguard_session_manager/src/events.rs index e22495fa41..fbbc26c76b 100644 --- a/crates/defguard_session_manager/src/events.rs +++ b/crates/defguard_session_manager/src/events.rs @@ -2,8 +2,8 @@ use std::net::IpAddr; use chrono::NaiveDateTime; use defguard_common::db::{ - Id, models::{Device, User, WireguardNetwork}, + Id, }; #[derive(Debug)] @@ -12,6 +12,36 @@ pub struct SessionManagerEvent { pub event: SessionManagerEventType, } +impl SessionManagerEvent { + #[must_use] + pub fn connected_for_session( + context: SessionManagerEventContext, + is_mfa_session: bool, + ) -> Self { + let event = if is_mfa_session { + SessionManagerEventType::MfaClientConnected + } else { + SessionManagerEventType::ClientConnected + }; + + Self { context, event } + } + + #[must_use] + pub fn disconnected_for_session( + context: SessionManagerEventContext, + is_mfa_session: bool, + ) -> Self { + let event = if is_mfa_session { + SessionManagerEventType::MfaClientDisconnected + } else { + SessionManagerEventType::ClientDisconnected + }; + + Self { context, event } + } +} + #[derive(Debug)] pub struct SessionManagerEventContext { pub timestamp: NaiveDateTime, From 0406dca130e6eb3a0b24c9af6e9b5e630ba6667b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:30:21 +0100 Subject: [PATCH 03/18] emit connected event for new mfa sessions --- crates/defguard_session_manager/src/lib.rs | 18 ++++++---- .../src/session_state.rs | 33 ++++++++++++++++--- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 9ee8e4a463..27993de125 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -25,7 +25,7 @@ use tracing::{debug, error, info, trace, warn}; use crate::{ error::SessionManagerError, - events::{SessionManagerEvent, SessionManagerEventContext, SessionManagerEventType}, + events::{SessionManagerEvent, SessionManagerEventContext}, session_state::ActiveSessionsMap, }; @@ -192,7 +192,9 @@ impl SessionManager { if let Some(session) = maybe_session { // update session stats - session.update_stats(transaction, message).await?; + session + .update_stats(transaction, message, &self.session_manager_event_tx) + .await?; } trace!("Finished processing peer stats update"); @@ -315,11 +317,13 @@ impl SessionManager { // FIXME: this is a workaround since we require an IP for each audit log event public_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), }; - let event = SessionManagerEvent { - context, - event: SessionManagerEventType::ClientDisconnected, - }; - self.session_manager_event_tx.send(event)?; + if session.connected_at.is_some() { + let event = SessionManagerEvent::disconnected_for_session( + context, + session.mfa_method.is_some(), + ); + self.session_manager_event_tx.send(event)?; + } Ok(()) } diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index f225814f0a..b0b18d56b9 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -19,7 +19,7 @@ use tracing::{debug, warn}; use crate::{ error::SessionManagerError, - events::{SessionManagerEvent, SessionManagerEventContext, SessionManagerEventType}, + events::{SessionManagerEvent, SessionManagerEventContext}, }; /// Helper map to store latest stats update for each gateway in a given location @@ -103,6 +103,7 @@ impl SessionState { &mut self, transaction: &mut PgConnection, peer_stats_update: PeerStatsUpdate, + event_tx: &UnboundedSender, ) -> Result<(), SessionManagerError> { // mark new MFA session as connected if necessary if self.state == VpnClientSessionState::New { @@ -117,6 +118,31 @@ impl SessionState { db_session.connected_at = Some(peer_stats_update.latest_handshake); db_session.save(&mut *transaction).await?; + let user = User::find_by_id(&mut *transaction, db_session.user_id) + .await? + .ok_or(SessionManagerError::UserDoesNotExistError(db_session.user_id))?; + let device = Device::find_by_id(&mut *transaction, db_session.device_id) + .await? + .ok_or(SessionManagerError::DeviceDoesNotExistError(db_session.device_id))?; + let location = WireguardNetwork::find_by_id(&mut *transaction, db_session.location_id) + .await? + .ok_or(SessionManagerError::LocationDoesNotExistError( + db_session.location_id, + ))?; + + let context = SessionManagerEventContext { + timestamp: peer_stats_update.latest_handshake, + location, + user, + device, + public_ip: peer_stats_update.endpoint.ip(), + }; + let event = SessionManagerEvent::connected_for_session( + context, + db_session.mfa_method.is_some(), + ); + event_tx.send(event)?; + // update local session state self.state = VpnClientSessionState::Connected; } @@ -339,10 +365,7 @@ impl ActiveSessionsMap { device, public_ip, }; - let event = SessionManagerEvent { - context, - event: SessionManagerEventType::ClientConnected, - }; + let event = SessionManagerEvent::connected_for_session(context, false); event_tx.send(event)?; Ok(session_map.0.get_mut(&device_id)) From a11c0f10faaf32ef108b7d636c04711efa20e26e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:30:39 +0100 Subject: [PATCH 04/18] update session manager tests --- Cargo.lock | 1 + .../tests/session_manager/mfa.rs | 57 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 13844a5a7e..6a7b890406 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1449,6 +1449,7 @@ dependencies = [ "defguard_common", "defguard_core", "defguard_session_manager", + "ipnetwork", "serde_json", "sqlx", "thiserror 2.0.18", diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index f167211854..c55c880c7b 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -11,6 +11,7 @@ use defguard_common::db::{ setup_pool, }; use defguard_core::grpc::GatewayEvent; +use defguard_session_manager::events::SessionManagerEventType; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::time::{Duration, timeout}; @@ -113,6 +114,19 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( 1 ); + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for MfaClientConnected event") + .expect("session manager event channel closed"); + assert!(matches!( + connected_event.event, + SessionManagerEventType::MfaClientConnected + )); + assert_eq!(connected_event.context.location.id, location.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + let second_collected_at = handshake + TimeDelta::seconds(30); let second_handshake = handshake + TimeDelta::seconds(25); harness.send_stats(build_stats_update( @@ -157,6 +171,9 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( assert_eq!(latest_stats.total_download, 280); assert_eq!(latest_stats.upload_diff, 60); assert_eq!(latest_stats.download_diff, 80); + + assert_no_session_manager_events(&mut harness); + assert_no_gateway_events(&mut harness); } #[sqlx::test] @@ -226,6 +243,19 @@ async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( assert_eq!(latest_stats.upload_diff, 0); assert_eq!(latest_stats.download_diff, 0); + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for MfaClientConnected event in duplicate first-stats test") + .expect("session manager event channel closed"); + assert!(matches!( + connected_event.event, + SessionManagerEventType::MfaClientConnected + )); + assert_eq!(connected_event.context.location.id, location.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_no_session_manager_events(&mut harness); assert_no_gateway_events(&mut harness); } @@ -275,6 +305,19 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( assert_eq!(connected_session.state, VpnClientSessionState::Connected); assert_eq!(connected_session.connected_at, Some(first_handshake)); + let connected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for MfaClientConnected event in repeated-stats test") + .expect("session manager event channel closed"); + assert!(matches!( + connected_event.event, + SessionManagerEventType::MfaClientConnected + )); + assert_eq!(connected_event.context.location.id, location.id); + assert_eq!(connected_event.context.user.id, user.id); + assert_eq!(connected_event.context.device.id, device.id); + assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_no_session_manager_events(&mut harness); assert_no_gateway_events(&mut harness); @@ -388,6 +431,18 @@ async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization } other => panic!("unexpected gateway event: {other:?}"), } + + let disconnected_event = timeout(RECEIVE_TIMEOUT, harness.event_rx.recv()) + .await + .expect("timed out waiting for MfaClientDisconnected event") + .expect("session manager event channel closed"); + assert!(matches!( + disconnected_event.event, + SessionManagerEventType::MfaClientDisconnected + )); + assert_eq!(disconnected_event.context.location.id, location.id); + assert_eq!(disconnected_event.context.user.id, user.id); + assert_eq!(disconnected_event.context.device.id, device.id); } #[sqlx::test] @@ -424,4 +479,6 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( VpnClientSessionState::Disconnected ); assert!(disconnected_session.disconnected_at.is_some()); + + assert_no_session_manager_events(&mut harness); } From c04c88b08a870527bb724156c4e5e849e76bb6e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:31:19 +0100 Subject: [PATCH 05/18] formatting --- crates/defguard_session_manager/src/events.rs | 2 +- crates/defguard_session_manager/src/session_state.rs | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/defguard_session_manager/src/events.rs b/crates/defguard_session_manager/src/events.rs index fbbc26c76b..a2eb71de42 100644 --- a/crates/defguard_session_manager/src/events.rs +++ b/crates/defguard_session_manager/src/events.rs @@ -2,8 +2,8 @@ use std::net::IpAddr; use chrono::NaiveDateTime; use defguard_common::db::{ - models::{Device, User, WireguardNetwork}, Id, + models::{Device, User, WireguardNetwork}, }; #[derive(Debug)] diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index b0b18d56b9..14d4c95a83 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -120,10 +120,14 @@ impl SessionState { let user = User::find_by_id(&mut *transaction, db_session.user_id) .await? - .ok_or(SessionManagerError::UserDoesNotExistError(db_session.user_id))?; + .ok_or(SessionManagerError::UserDoesNotExistError( + db_session.user_id, + ))?; let device = Device::find_by_id(&mut *transaction, db_session.device_id) .await? - .ok_or(SessionManagerError::DeviceDoesNotExistError(db_session.device_id))?; + .ok_or(SessionManagerError::DeviceDoesNotExistError( + db_session.device_id, + ))?; let location = WireguardNetwork::find_by_id(&mut *transaction, db_session.location_id) .await? .ok_or(SessionManagerError::LocationDoesNotExistError( From 33e9af24b6ca1ef02b32ef28737d1c59cb15a67a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:48:56 +0100 Subject: [PATCH 06/18] support mfa events in frontend --- web/messages/en/activity.json | 2 ++ web/src/shared/api/activity-log-types.ts | 2 ++ 2 files changed, 4 insertions(+) diff --git a/web/messages/en/activity.json b/web/messages/en/activity.json index bd707c532f..9a0284f982 100644 --- a/web/messages/en/activity.json +++ b/web/messages/en/activity.json @@ -29,6 +29,8 @@ "activity_event_activity_log_stream_removed": "Activity log stream removed", "activity_event_vpn_client_connected": "VPN client connected", "activity_event_vpn_client_disconnected": "VPN client disconnected", + "activity_event_vpn_client_mfa_connected": "VPN client MFA connected", + "activity_event_vpn_client_mfa_disconnected": "VPN client MFA disconnected", "activity_event_vpn_client_mfa_success": "VPN client MFA success", "activity_event_vpn_client_mfa_failed": "VPN client MFA failed", "activity_event_enrollment_token_added": "Enrollment token added", diff --git a/web/src/shared/api/activity-log-types.ts b/web/src/shared/api/activity-log-types.ts index 57129647b7..4747353d63 100644 --- a/web/src/shared/api/activity-log-types.ts +++ b/web/src/shared/api/activity-log-types.ts @@ -46,6 +46,8 @@ export const ActivityLogEventType = { VpnClientConnected: 'vpn_client_connected', VpnClientDisconnected: 'vpn_client_disconnected', + VpnClientMfaConnected: 'vpn_client_mfa_connected', + VpnClientMfaDisconnected: 'vpn_client_mfa_disconnected', VpnClientMfaSuccess: 'vpn_client_mfa_success', VpnClientMfaFailed: 'vpn_client_mfa_failed', From 98b6f1be0bc16bcebee01b378260471e7913b85a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 07:56:25 +0100 Subject: [PATCH 07/18] cleanup --- crates/defguard_core/src/grpc/proxy/client_mfa.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index d7a34274dc..7e7762b070 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -787,7 +787,7 @@ impl ClientMfaServer { // disconnect all active sessions for session in active_sessions { debug!("Disconnecting previous active MFA VPN session {session:?}."); - self.disconnect_mfa_session(&mut *conn, session, location, user, device) + self.disconnect_session(&mut *conn, session, location, user, device) .await?; } @@ -800,7 +800,7 @@ impl ClientMfaServer { } /// Update session state as disconnected and send relevant gateway update - async fn disconnect_mfa_session( + async fn disconnect_session( &self, conn: &mut PgConnection, mut session: VpnClientSession, @@ -808,6 +808,9 @@ impl ClientMfaServer { user: &User, device: &Device, ) -> Result<(), Status> { + let is_connected = session.state == VpnClientSessionState::Connected; + let is_mfa_session = session.mfa_method.is_some(); + // update session state in DB let disconnect_timestamp = Utc::now().naive_utc(); session.disconnected_at = Some(disconnect_timestamp); @@ -841,9 +844,7 @@ impl ClientMfaServer { })?; } - // only emit disconnect events if MFA session has actually connected - // and not for New sessions - if session.state == VpnClientSessionState::Connected { + if is_connected && is_mfa_session { let gateway_event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); self.wireguard_tx.send(gateway_event).map_err(|err| { error!("Error sending WireGuard event: {err}"); From d14fec7fd7dafee9c86fd701aa1146d3e919b49d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 08:24:29 +0100 Subject: [PATCH 08/18] disconnect both types of sessions on MFA re-auth --- crates/defguard_core/src/events.rs | 1 + .../defguard_core/src/grpc/proxy/client_mfa.rs | 18 ++++++++++++------ .../defguard_event_router/src/handlers/bidi.rs | 6 +++++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index a9f7eab7a7..941b2c9a23 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -395,5 +395,6 @@ pub enum DesktopClientMfaEvent { Disconnected { device: Device, location: WireguardNetwork, + is_mfa_session: bool, }, } diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 7e7762b070..1814c39caf 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -844,12 +844,17 @@ impl ClientMfaServer { })?; } - if is_connected && is_mfa_session { - let gateway_event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); - self.wireguard_tx.send(gateway_event).map_err(|err| { - error!("Error sending WireGuard event: {err}"); - Status::internal("unexpected error") - })?; + // only emit disconnect events if a session has actually been connected + if is_connected { + // gateway update is only needed to remove peer for MFA sessions + if is_mfa_session { + let gateway_event = + GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + self.wireguard_tx.send(gateway_event).map_err(|err| { + error!("Error sending WireGuard event: {err}"); + Status::internal("unexpected error") + })?; + } let context = BidiRequestContext { timestamp: disconnect_timestamp, @@ -864,6 +869,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Disconnected { location: location.clone(), device: device.clone(), + is_mfa_session, }, )), }) diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 8a43ef419f..4137b74045 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -71,7 +71,11 @@ impl EventRouter { })), Some(location), ), - DesktopClientMfaEvent::Disconnected { location, device } => ( + DesktopClientMfaEvent::Disconnected { + location, + device, + is_mfa_session, + } => ( LoggerEvent::Vpn(Box::new(VpnEvent::MfaDisconnectedFromLocation { location: location.clone(), device, From 1934a172ddf5a89918ea61cc5c4ad95cb9e18f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 09:10:28 +0100 Subject: [PATCH 09/18] handle event routing for both types of disconnects --- Cargo.lock | 1 + crates/defguard_event_router/Cargo.toml | 3 + .../src/handlers/bidi.rs | 209 +++++++++++++++++- 3 files changed, 206 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6a7b890406..609c252df7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1461,6 +1461,7 @@ dependencies = [ name = "defguard_event_router" version = "0.0.0" dependencies = [ + "defguard_common", "defguard_core", "defguard_event_logger", "defguard_session_manager", diff --git a/crates/defguard_event_router/Cargo.toml b/crates/defguard_event_router/Cargo.toml index bb38e84a48..8b64231497 100644 --- a/crates/defguard_event_router/Cargo.toml +++ b/crates/defguard_event_router/Cargo.toml @@ -17,3 +17,6 @@ defguard_session_manager = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } + +[dev-dependencies] +defguard_common = { workspace = true } diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 4137b74045..025a6b8465 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -75,13 +75,21 @@ impl EventRouter { location, device, is_mfa_session, - } => ( - LoggerEvent::Vpn(Box::new(VpnEvent::MfaDisconnectedFromLocation { - location: location.clone(), - device, - })), - Some(location), - ), + } => { + let vpn_event = if is_mfa_session { + VpnEvent::MfaDisconnectedFromLocation { + location: location.clone(), + device, + } + } else { + VpnEvent::DisconnectedFromLocation { + location: location.clone(), + device, + } + }; + + (LoggerEvent::Vpn(Box::new(vpn_event)), Some(location)) + } }, }; @@ -91,3 +99,190 @@ impl EventRouter { ) } } + +#[cfg(test)] +mod tests { + use std::{ + net::{IpAddr, Ipv4Addr}, + sync::Arc, + }; + + use defguard_common::db::{ + NoId, + models::{ + Device, DeviceType, WireguardNetwork, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, + }; + use defguard_core::{ + events::{BidiRequestContext, BidiStreamEventType}, + grpc::GatewayEvent, + }; + use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; + + use super::*; + use crate::RouterReceiverSet; + + #[test] + fn maps_mfa_disconnect_bidi_events_to_mfa_disconnect_logger_events() { + let message = route_disconnect_event(true); + + match message.event { + LoggerEvent::Vpn(event) => match *event { + VpnEvent::MfaDisconnectedFromLocation { location, device } => { + assert_eq!(location.id, sample_location().id); + assert_eq!(device.id, sample_device().id); + } + _ => panic!("expected MFA disconnect vpn event"), + }, + _ => panic!("expected vpn logger event"), + } + } + + #[test] + fn maps_non_mfa_disconnect_bidi_events_to_standard_disconnect_logger_events() { + let message = route_disconnect_event(false); + + match message.event { + LoggerEvent::Vpn(event) => match *event { + VpnEvent::DisconnectedFromLocation { location, device } => { + assert_eq!(location.id, sample_location().id); + assert_eq!(device.id, sample_device().id); + } + _ => panic!("expected standard disconnect vpn event"), + }, + _ => panic!("expected vpn logger event"), + } + } + + fn sample_router() -> ( + EventRouter, + tokio::sync::mpsc::UnboundedReceiver, + ) { + let (_api_tx, api_rx) = unbounded_channel(); + let (_bidi_tx, bidi_rx) = unbounded_channel(); + let (_session_manager_tx, session_manager_rx) = unbounded_channel(); + let (event_logger_tx, event_logger_rx) = unbounded_channel(); + let (wireguard_tx, _wireguard_rx) = broadcast::channel::(1); + + ( + EventRouter::new( + RouterReceiverSet::new(api_rx, bidi_rx, session_manager_rx), + event_logger_tx, + wireguard_tx, + Arc::new(Notify::new()), + ), + event_logger_rx, + ) + } + + fn route_disconnect_event( + is_mfa_session: bool, + ) -> defguard_event_logger::message::EventLoggerMessage { + let (router, mut event_logger_rx) = sample_router(); + + router + .handle_bidi_event(BidiStreamEvent { + context: sample_context(), + event: BidiStreamEventType::DesktopClientMfa(Box::new( + DesktopClientMfaEvent::Disconnected { + location: sample_location(), + device: sample_device(), + is_mfa_session, + }, + )), + }) + .expect("bidi disconnect event should be routed"); + + event_logger_rx + .try_recv() + .expect("router should emit an activity log message") + } + + fn sample_context() -> BidiRequestContext { + BidiRequestContext::new( + 1, + "alice".to_string(), + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + "desktop-app".to_string(), + ) + } + + fn sample_device() -> Device { + Device::new( + "vpn-device".to_string(), + "pubkey".to_string(), + 1, + DeviceType::User, + None, + true, + ) + .save_placeholder_id(20) + } + + fn sample_location() -> WireguardNetwork { + WireguardNetwork::new( + "vpn-location".to_string(), + vec!["10.0.0.0/24".parse().unwrap()], + 51820, + "vpn.example.com".to_string(), + None, + 1420, + 0, + vec!["0.0.0.0/0".parse().unwrap()], + true, + 25, + 300, + false, + false, + LocationMfaMode::Internal, + ServiceLocationMode::Disabled, + ) + .save_placeholder_id(10) + } + + trait WithPlaceholderId { + fn save_placeholder_id(self, id: i64) -> T; + } + + impl WithPlaceholderId> for Device { + fn save_placeholder_id(self, id: i64) -> Device { + Device { + id, + name: self.name, + wireguard_pubkey: self.wireguard_pubkey, + user_id: self.user_id, + created: self.created, + device_type: self.device_type, + description: self.description, + configured: self.configured, + } + } + } + + impl WithPlaceholderId> for WireguardNetwork { + fn save_placeholder_id(self, id: i64) -> WireguardNetwork { + WireguardNetwork { + id, + name: self.name, + address: self.address, + port: self.port, + pubkey: self.pubkey, + prvkey: self.prvkey, + endpoint: self.endpoint, + dns: self.dns, + mtu: self.mtu, + fwmark: self.fwmark, + allowed_ips: self.allowed_ips, + allow_all_groups: self.allow_all_groups, + connected_at: self.connected_at, + acl_enabled: self.acl_enabled, + acl_default_allow: self.acl_default_allow, + keepalive_interval: self.keepalive_interval, + peer_disconnect_threshold: self.peer_disconnect_threshold, + location_mfa_mode: self.location_mfa_mode, + service_location_mode: self.service_location_mode, + } + } + } +} From ff82cadba262f489e9924d3fed4d1f4547a16176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 09:21:11 +0100 Subject: [PATCH 10/18] update tests --- .../src/grpc/proxy/client_mfa.rs | 80 +++++++++++++++---- .../src/handlers/bidi.rs | 12 +-- 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 1814c39caf..e8f902fb65 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -899,7 +899,7 @@ mod tests { use super::*; #[sqlx::test] - async fn test_replacing_connected_mfa_session_emits_disconnect_event( + async fn test_replacing_connected_mfa_session_emits_mfa_disconnect_event( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -919,7 +919,7 @@ mod tests { .await .expect("failed to create existing MFA session"); - let (server, mut event_rx, _gateway_rx) = make_server(pool.clone()); + let (server, mut event_rx, mut gateway_rx) = make_server(pool.clone()); let mut conn = pool.acquire().await.expect("failed to acquire connection"); server @@ -933,14 +933,35 @@ mod tests { .await .expect("should replace connected MFA session"); + let gateway_event = gateway_rx + .try_recv() + .expect("expected MFA gateway disconnect event for replaced connected session"); + match gateway_event { + GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + assert_eq!(location_id, location.id); + assert_eq!(disconnected_device.id, device.id); + } + other => panic!("unexpected gateway event: {other:?}"), + } + let event = event_rx .try_recv() - .expect("expected MFA disconnect event for replaced connected session"); - assert!(matches!( - event.event, - BidiStreamEventType::DesktopClientMfa(event) - if matches!(*event, DesktopClientMfaEvent::Disconnected { .. }) - )); + .expect("expected MFA disconnect audit event for replaced connected session"); + match event.event { + BidiStreamEventType::DesktopClientMfa(event) => match *event { + DesktopClientMfaEvent::Disconnected { + location: event_location, + device: event_device, + is_mfa_session, + } => { + assert_eq!(event_location.id, location.id); + assert_eq!(event_device.id, device.id); + assert!(is_mfa_session); + } + other => panic!("unexpected bidi event: {other:?}"), + }, + other => panic!("unexpected bidi stream event type: {other:?}"), + } assert_eq!(event.context.user_id, user.id); assert_eq!(event.context.username, user.username); @@ -952,7 +973,7 @@ mod tests { } #[sqlx::test] - async fn test_replacing_new_mfa_session_does_not_emit_disconnect_event( + async fn test_replacing_new_mfa_session_marks_session_disconnected_without_disconnect_event( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -972,7 +993,7 @@ mod tests { .await .expect("failed to create existing new MFA session"); - let (server, mut event_rx, _gateway_rx) = make_server(pool.clone()); + let (server, mut event_rx, mut gateway_rx) = make_server(pool.clone()); let mut conn = pool.acquire().await.expect("failed to acquire connection"); server @@ -986,7 +1007,14 @@ mod tests { .await .expect("should replace new MFA session"); - assert!(event_rx.try_recv().is_err()); + assert!(matches!( + event_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + )); + assert!(matches!( + gateway_rx.try_recv(), + Err(broadcast::error::TryRecvError::Empty) + )); let old_session = VpnClientSession::find_by_id(&pool, old_session.id) .await @@ -996,7 +1024,7 @@ mod tests { } #[sqlx::test] - async fn test_replacing_connected_non_mfa_session_does_not_emit_mfa_disconnect_event( + async fn test_replacing_connected_non_mfa_session_emits_standard_disconnect_event( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -1016,7 +1044,7 @@ mod tests { .await .expect("failed to create existing connected non-MFA session"); - let (server, mut event_rx, _gateway_rx) = make_server(pool.clone()); + let (server, mut event_rx, mut gateway_rx) = make_server(pool.clone()); let mut conn = pool.acquire().await.expect("failed to acquire connection"); server @@ -1030,7 +1058,31 @@ mod tests { .await .expect("should replace connected non-MFA session"); - assert!(event_rx.try_recv().is_err()); + assert!(matches!( + gateway_rx.try_recv(), + Err(broadcast::error::TryRecvError::Empty) + )); + + let event = event_rx + .try_recv() + .expect("expected standard disconnect audit event for replaced connected non-MFA session"); + match event.event { + BidiStreamEventType::DesktopClientMfa(event) => match *event { + DesktopClientMfaEvent::Disconnected { + location: event_location, + device: event_device, + is_mfa_session, + } => { + assert_eq!(event_location.id, location.id); + assert_eq!(event_device.id, device.id); + assert!(!is_mfa_session); + } + other => panic!("unexpected bidi event: {other:?}"), + }, + other => panic!("unexpected bidi stream event type: {other:?}"), + } + assert_eq!(event.context.user_id, user.id); + assert_eq!(event.context.username, user.username); let old_session = VpnClientSession::find_by_id(&pool, old_session.id) .await diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 025a6b8465..f05c321bd4 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -4,7 +4,7 @@ use defguard_core::events::{ use defguard_event_logger::message::{EnrollmentEvent, EventContext, LoggerEvent, VpnEvent}; use tracing::debug; -use crate::{EventRouter, error::EventRouterError}; +use crate::{error::EventRouterError, EventRouter}; impl EventRouter { pub(crate) fn handle_bidi_event(&self, event: BidiStreamEvent) -> Result<(), EventRouterError> { @@ -108,23 +108,23 @@ mod tests { }; use defguard_common::db::{ - NoId, models::{ - Device, DeviceType, WireguardNetwork, wireguard::{LocationMfaMode, ServiceLocationMode}, + Device, DeviceType, WireguardNetwork, }, + NoId, }; use defguard_core::{ events::{BidiRequestContext, BidiStreamEventType}, grpc::GatewayEvent, }; - use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; + use tokio::sync::{broadcast, mpsc::unbounded_channel, Notify}; use super::*; use crate::RouterReceiverSet; #[test] - fn maps_mfa_disconnect_bidi_events_to_mfa_disconnect_logger_events() { + fn maps_disconnect_bidi_events_from_mfa_sessions_to_mfa_disconnect_logger_events() { let message = route_disconnect_event(true); match message.event { @@ -140,7 +140,7 @@ mod tests { } #[test] - fn maps_non_mfa_disconnect_bidi_events_to_standard_disconnect_logger_events() { + fn maps_disconnect_bidi_events_from_non_mfa_sessions_to_standard_disconnect_logger_events() { let message = route_disconnect_event(false); match message.event { From ca445dde6a5d5fcdde2d423f4b471b30aa74ff21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 09:21:30 +0100 Subject: [PATCH 11/18] formatting --- crates/defguard_core/src/grpc/proxy/client_mfa.rs | 6 +++--- crates/defguard_event_router/src/handlers/bidi.rs | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index e8f902fb65..8a72c5d94e 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -1063,9 +1063,9 @@ mod tests { Err(broadcast::error::TryRecvError::Empty) )); - let event = event_rx - .try_recv() - .expect("expected standard disconnect audit event for replaced connected non-MFA session"); + let event = event_rx.try_recv().expect( + "expected standard disconnect audit event for replaced connected non-MFA session", + ); match event.event { BidiStreamEventType::DesktopClientMfa(event) => match *event { DesktopClientMfaEvent::Disconnected { diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index f05c321bd4..a3155d98d3 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -4,7 +4,7 @@ use defguard_core::events::{ use defguard_event_logger::message::{EnrollmentEvent, EventContext, LoggerEvent, VpnEvent}; use tracing::debug; -use crate::{error::EventRouterError, EventRouter}; +use crate::{EventRouter, error::EventRouterError}; impl EventRouter { pub(crate) fn handle_bidi_event(&self, event: BidiStreamEvent) -> Result<(), EventRouterError> { @@ -108,17 +108,17 @@ mod tests { }; use defguard_common::db::{ + NoId, models::{ - wireguard::{LocationMfaMode, ServiceLocationMode}, Device, DeviceType, WireguardNetwork, + wireguard::{LocationMfaMode, ServiceLocationMode}, }, - NoId, }; use defguard_core::{ events::{BidiRequestContext, BidiStreamEventType}, grpc::GatewayEvent, }; - use tokio::sync::{broadcast, mpsc::unbounded_channel, Notify}; + use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; use super::*; use crate::RouterReceiverSet; From 8a95cf74bd09476df4805da728478f35cacc32fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 10:52:58 +0100 Subject: [PATCH 12/18] review fixes --- crates/defguard_session_manager/src/error.rs | 2 + .../src/session_state.rs | 130 ++++++++++++------ .../tests/common/mod.rs | 4 + .../tests/session_manager/mfa.rs | 74 ++++++++++ 4 files changed, 171 insertions(+), 39 deletions(-) diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 4f065ce6e0..4904a26038 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -27,6 +27,8 @@ pub enum SessionManagerError { LocationDoesNotExistError(Id), #[error("VPN client session with ID {0} does not exist")] SessionDoesNotExistError(Id), + #[error("VPN client session {0} is missing cached event context for transition")] + MissingSessionEventContextError(Id), #[error("Received out of order peer stats update")] PeerStatsUpdateOutOfOrderError, #[error("Peer stats channel closed")] diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 14d4c95a83..562528f45a 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -1,4 +1,7 @@ -use std::collections::{HashMap, hash_map::Entry}; +use std::{ + collections::{HashMap, hash_map::Entry}, + net::IpAddr, +}; use chrono::{NaiveDateTime, TimeDelta}; use defguard_common::{ @@ -91,9 +94,29 @@ pub(crate) struct SessionState { session_id: Id, state: VpnClientSessionState, last_stats_update: LastGatewayUpdate, + event_context_data: Option, +} + +struct SessionEventContextData { + location: WireguardNetwork, + user: User, + device: Device, + is_mfa_session: bool, } impl SessionState { + fn new( + session: &VpnClientSession, + event_context_data: Option, + ) -> Self { + Self { + session_id: session.id, + state: session.state.clone(), + last_stats_update: LastGatewayUpdate::new(), + event_context_data, + } + } + fn try_get_last_stats_update(&self, gateway_id: Id) -> Option<&LastStatsUpdate> { self.last_stats_update.0.get(&gateway_id) } @@ -107,6 +130,10 @@ impl SessionState { ) -> Result<(), SessionManagerError> { // mark new MFA session as connected if necessary if self.state == VpnClientSessionState::New { + let event_context_data = self.event_context_data.as_ref().ok_or( + SessionManagerError::MissingSessionEventContextError(self.session_id), + )?; + // fetch DB session let mut db_session = VpnClientSession::find_by_id(&mut *transaction, self.session_id) .await? @@ -118,37 +145,18 @@ impl SessionState { db_session.connected_at = Some(peer_stats_update.latest_handshake); db_session.save(&mut *transaction).await?; - let user = User::find_by_id(&mut *transaction, db_session.user_id) - .await? - .ok_or(SessionManagerError::UserDoesNotExistError( - db_session.user_id, - ))?; - let device = Device::find_by_id(&mut *transaction, db_session.device_id) - .await? - .ok_or(SessionManagerError::DeviceDoesNotExistError( - db_session.device_id, - ))?; - let location = WireguardNetwork::find_by_id(&mut *transaction, db_session.location_id) - .await? - .ok_or(SessionManagerError::LocationDoesNotExistError( - db_session.location_id, - ))?; + // update local session state before event emission so the transition stays idempotent + // even if the event channel is closed. + self.state = VpnClientSessionState::Connected; - let context = SessionManagerEventContext { - timestamp: peer_stats_update.latest_handshake, - location, - user, - device, - public_ip: peer_stats_update.endpoint.ip(), - }; let event = SessionManagerEvent::connected_for_session( - context, - db_session.mfa_method.is_some(), + event_context_data.build_context( + peer_stats_update.latest_handshake, + peer_stats_update.endpoint.ip(), + ), + event_context_data.is_mfa_session, ); event_tx.send(event)?; - - // update local session state - self.state = VpnClientSessionState::Connected; } // get previous stats for a given gateway if available and calculate transfer change @@ -189,12 +197,18 @@ impl SessionState { } } -impl From<&VpnClientSession> for SessionState { - fn from(value: &VpnClientSession) -> Self { - Self { - session_id: value.id, - state: value.state.clone(), - last_stats_update: LastGatewayUpdate::new(), +impl SessionEventContextData { + fn build_context( + &self, + timestamp: NaiveDateTime, + public_ip: IpAddr, + ) -> SessionManagerEventContext { + SessionManagerEventContext { + timestamp, + location: self.location.clone(), + user: self.user.clone(), + device: self.device.clone(), + public_ip, } } } @@ -249,11 +263,22 @@ impl ActiveSessionsMap { device_pubkey: String, ) -> Result, SessionManagerError> { // translate pubkey into device ID - let device_id = self.get_device(&mut *transaction, device_pubkey).await?.id; + let device = self + .get_device(&mut *transaction, device_pubkey) + .await? + .clone(); + let device_id = device.id; // try to get session from current map - let session_map = self.get_or_create_location_session_map(location_id); - if session_map.0.contains_key(&device_id) { + let session_exists_in_batch = self + .sessions + .get(&location_id) + .is_some_and(|session_map| session_map.0.contains_key(&device_id)); + if session_exists_in_batch { + let session_map = self + .sessions + .get_mut(&location_id) + .expect("location session map should exist once checked"); return Ok(session_map.0.get_mut(&device_id)); } @@ -265,7 +290,25 @@ impl ActiveSessionsMap { match maybe_db_session { None => Ok(None), Some(db_session) => { - let mut session_state = SessionState::from(&db_session); + let event_context_data = if db_session.state == VpnClientSessionState::New { + let user = self + .get_user(&mut *transaction, device.user_id) + .await? + .clone(); + let location = self + .get_location(&mut *transaction, location_id) + .await? + .clone(); + Some(SessionEventContextData { + location, + user, + device, + is_mfa_session: db_session.mfa_method.is_some(), + }) + } else { + None + }; + let mut session_state = SessionState::new(&db_session, event_context_data); // fetch latest available stats for each gateway for a given session let latest_gateway_stats = db_session @@ -276,6 +319,7 @@ impl ActiveSessionsMap { } // put session state in map + let session_map = self.get_or_create_location_session_map(location_id); let maybe_existing_session = session_map.insert(device_id, session_state); // if a session exists already there was an error in earlier logic @@ -353,7 +397,15 @@ impl ActiveSessionsMap { .await?; // add to session map - let session_state = SessionState::from(&session); + let session_state = SessionState::new( + &session, + Some(SessionEventContextData { + location: location.clone(), + user: user.clone(), + device: device.clone(), + is_mfa_session: false, + }), + ); let session_map = self.get_or_create_location_session_map(location_id); let maybe_existing_session = session_map.insert(device.id, session_state); diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index ddd8181288..89b7050b20 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -85,6 +85,10 @@ impl SessionManagerHarness { .expect("failed to send peer stats update"); } + pub(crate) fn close_event_channel(&mut self) { + self.event_rx.close(); + } + pub(crate) async fn run_iteration(&mut self) -> IterationOutcome { let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); run_session_manager_iteration( diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index c55c880c7b..066ea355e2 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -364,6 +364,80 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( assert_no_gateway_events(&mut harness); } +#[sqlx::test] +async fn test_closed_event_channel_keeps_mfa_first_stats_upgrade_idempotent( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let location = create_location_with_mfa_mode(&pool, LocationMfaMode::Internal).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_location(&pool, location.id, device.id).await; + let gateway = create_gateway(&pool, location.id, user.fullname()).await; + let mut harness = SessionManagerHarness::new(pool.clone()); + + let session = create_session( + &pool, + location.id, + user.id, + device.id, + None, + Some(VpnClientMfaMethod::Totp), + ) + .await; + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let first_handshake = truncate_timestamp(Utc::now().naive_utc() - TimeDelta::seconds(30)); + let second_collected_at = first_handshake + TimeDelta::seconds(30); + let second_handshake = first_handshake + TimeDelta::seconds(20); + + harness.close_event_channel(); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + first_handshake, + endpoint, + 100, + 200, + first_handshake, + )); + harness.send_stats(build_stats_update( + location.id, + gateway.id, + &device.wireguard_pubkey, + second_collected_at, + endpoint, + 160, + 280, + second_handshake, + )); + + let _ = harness.run_iteration().await; + + let refreshed_session = VpnClientSession::find_by_id(&pool, session.id) + .await + .expect("failed to query session") + .expect("expected session"); + assert_eq!(refreshed_session.state, VpnClientSessionState::Connected); + assert_eq!(refreshed_session.connected_at, Some(first_handshake)); + + assert_eq!(count_session_stats(&pool, session.id).await, 1); + + let latest_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, location.id) + .await + .expect("failed to query latest stats") + .expect("expected latest stats"); + assert_eq!(latest_stats.session_id, session.id); + assert_eq!(latest_stats.total_upload, 160); + assert_eq!(latest_stats.total_download, 280); + assert_eq!(latest_stats.upload_diff, 0); + assert_eq!(latest_stats.download_diff, 0); + + assert_no_gateway_events(&mut harness); +} + #[sqlx::test] async fn test_inactive_mfa_connected_sessions_disconnect_and_clear_authorization( _: PgPoolOptions, From 19c4263350ab4f761eee9e4001b7d0002424c98e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 11:27:52 +0100 Subject: [PATCH 13/18] more review fixes --- crates/defguard_session_manager/src/lib.rs | 14 ++++++----- .../src/session_state.rs | 25 +++++++++++-------- .../tests/session_manager/mfa.rs | 20 +++++++++++++++ 3 files changed, 43 insertions(+), 16 deletions(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 27993de125..8d5db170d7 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -278,6 +278,8 @@ impl SessionManager { location: &WireguardNetwork, ) -> Result<(), SessionManagerError> { let disconnect_timestamp = Utc::now().naive_utc(); + let is_connected = session.connected_at.is_some(); + let is_mfa_session = session.mfa_method.is_some(); // update session record in DB session.disconnected_at = Some(disconnect_timestamp); @@ -305,7 +307,10 @@ impl SessionManager { device_network_info.preshared_key = None; device_network_info.update(&mut *transaction).await?; } - self.send_peer_disconnect_message(location, &device)?; + + if is_mfa_session { + self.send_peer_disconnect_message(location, &device)?; + } } // emit event @@ -317,11 +322,8 @@ impl SessionManager { // FIXME: this is a workaround since we require an IP for each audit log event public_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), }; - if session.connected_at.is_some() { - let event = SessionManagerEvent::disconnected_for_session( - context, - session.mfa_method.is_some(), - ); + if is_connected { + let event = SessionManagerEvent::disconnected_for_session(context, is_mfa_session); self.session_manager_event_tx.send(event)?; } diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 562528f45a..1eaa6c4567 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -130,9 +130,19 @@ impl SessionState { ) -> Result<(), SessionManagerError> { // mark new MFA session as connected if necessary if self.state == VpnClientSessionState::New { - let event_context_data = self.event_context_data.as_ref().ok_or( - SessionManagerError::MissingSessionEventContextError(self.session_id), - )?; + let (connected_context, is_mfa_session) = { + let event_context_data = self.event_context_data.as_ref().ok_or( + SessionManagerError::MissingSessionEventContextError(self.session_id), + )?; + + ( + event_context_data.build_context( + peer_stats_update.latest_handshake, + peer_stats_update.endpoint.ip(), + ), + event_context_data.is_mfa_session, + ) + }; // fetch DB session let mut db_session = VpnClientSession::find_by_id(&mut *transaction, self.session_id) @@ -149,13 +159,8 @@ impl SessionState { // even if the event channel is closed. self.state = VpnClientSessionState::Connected; - let event = SessionManagerEvent::connected_for_session( - event_context_data.build_context( - peer_stats_update.latest_handshake, - peer_stats_update.endpoint.ip(), - ), - event_context_data.is_mfa_session, - ); + let event = + SessionManagerEvent::connected_for_session(connected_context, is_mfa_session); event_tx.send(event)?; } diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index 066ea355e2..83f0c499d8 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -529,6 +529,7 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( let user = create_user(&pool).await; let device = create_device(&pool, user.id).await; attach_device_to_location(&pool, location.id, device.id).await; + authorize_device_in_location(&pool, location.id, device.id, "psk-before-timeout").await; let mut harness = SessionManagerHarness::new(pool.clone()); let session = create_session( @@ -554,5 +555,24 @@ async fn test_never_connected_mfa_new_sessions_disconnect_after_threshold( ); assert!(disconnected_session.disconnected_at.is_some()); + let network_device = WireguardNetworkDevice::find(&pool, device.id, location.id) + .await + .expect("failed to query network device") + .expect("expected network device"); + assert!(!network_device.is_authorized); + assert_eq!(network_device.preshared_key, None); + + let gateway_event = timeout(RECEIVE_TIMEOUT, harness.gateway_rx.recv()) + .await + .expect("timed out waiting for MFA disconnect gateway event for new session") + .expect("gateway event channel closed"); + match gateway_event { + GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + assert_eq!(location_id, location.id); + assert_eq!(disconnected_device.id, device.id); + } + other => panic!("unexpected gateway event: {other:?}"), + } + assert_no_session_manager_events(&mut harness); } From 2bfe6ee77fd5551f02065bb520a6508fb4b422f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 11:44:08 +0100 Subject: [PATCH 14/18] fix for new mfa sessions --- .../src/grpc/proxy/client_mfa.rs | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 8a72c5d94e..4808963d0e 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -844,18 +844,18 @@ impl ClientMfaServer { })?; } + // gateway update is only needed to remove peer for MFA sessions + // this is needed to remove peers for both Connected and New sessions + if is_mfa_session { + let gateway_event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + self.wireguard_tx.send(gateway_event).map_err(|err| { + error!("Error sending WireGuard event: {err}"); + Status::internal("unexpected error") + })?; + } + // only emit disconnect events if a session has actually been connected if is_connected { - // gateway update is only needed to remove peer for MFA sessions - if is_mfa_session { - let gateway_event = - GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); - self.wireguard_tx.send(gateway_event).map_err(|err| { - error!("Error sending WireGuard event: {err}"); - Status::internal("unexpected error") - })?; - } - let context = BidiRequestContext { timestamp: disconnect_timestamp, user_id: user.id, @@ -973,7 +973,7 @@ mod tests { } #[sqlx::test] - async fn test_replacing_new_mfa_session_marks_session_disconnected_without_disconnect_event( + async fn test_replacing_new_mfa_session_marks_session_disconnected_without_disconnect_audit_event( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -1007,14 +1007,21 @@ mod tests { .await .expect("should replace new MFA session"); + let gateway_event = gateway_rx + .try_recv() + .expect("expected MFA gateway disconnect event for replaced new session"); + match gateway_event { + GatewayEvent::MfaSessionDisconnected(location_id, disconnected_device) => { + assert_eq!(location_id, location.id); + assert_eq!(disconnected_device.id, device.id); + } + other => panic!("unexpected gateway event: {other:?}"), + } + assert!(matches!( event_rx.try_recv(), Err(tokio::sync::mpsc::error::TryRecvError::Empty) )); - assert!(matches!( - gateway_rx.try_recv(), - Err(broadcast::error::TryRecvError::Empty) - )); let old_session = VpnClientSession::find_by_id(&pool, old_session.id) .await From b74549431c9909f711d3f5b36b23ba5509b166db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 12:26:37 +0100 Subject: [PATCH 15/18] remove unnecessary conditional --- crates/defguard_session_manager/src/lib.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 8d5db170d7..31293888db 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -308,9 +308,7 @@ impl SessionManager { device_network_info.update(&mut *transaction).await?; } - if is_mfa_session { - self.send_peer_disconnect_message(location, &device)?; - } + self.send_peer_disconnect_message(location, &device)?; } // emit event From 44a1b467a71a7e9e9f3417fe3792195b60a2357b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 08:15:38 +0100 Subject: [PATCH 16/18] review fixes --- .../defguard_core/src/grpc/proxy/client_mfa.rs | 13 +++++-------- .../tests/integration/grpc/gateway.rs | 2 +- .../defguard_event_router/src/handlers/bidi.rs | 18 +++++++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 4808963d0e..3e04b715ed 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -1,5 +1,6 @@ use std::{ collections::HashMap, + net::{IpAddr, Ipv4Addr}, sync::{Arc, RwLock}, time::Duration, }; @@ -469,12 +470,8 @@ impl ClientMfaServer { // Prepare event context let (ip, _user_agent) = parse_client_ip_agent(&info).map_err(Status::internal)?; - let context = BidiRequestContext::new( - user.id, - user.username.clone(), - ip, - format!("{} (ID {})", device.name, device.id), - ); + let context = + BidiRequestContext::new(user.id, user.username.clone(), ip, format!("{}", device)); // validate code match method { @@ -860,8 +857,8 @@ impl ClientMfaServer { timestamp: disconnect_timestamp, user_id: user.id, username: user.username.clone(), - ip: std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), - device_name: format!("{} (ID {})", device.name, device.id), + ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + device_name: format!("{}", device), }; self.emit_event(BidiStreamEvent { context, diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index e5de5b8268..0934bd16a5 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -352,7 +352,7 @@ async fn test_vpn_client_disconnected(_: PgPoolOptions, options: PgConnectOption device_pubkey, &test_device, &test_user, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080), &stats, ) .expect("failed to insert connected client"); diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index a3155d98d3..247ded3b11 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -108,7 +108,7 @@ mod tests { }; use defguard_common::db::{ - NoId, + Id, NoId, models::{ Device, DeviceType, WireguardNetwork, wireguard::{LocationMfaMode, ServiceLocationMode}, @@ -203,12 +203,12 @@ mod tests { BidiRequestContext::new( 1, "alice".to_string(), - IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + IpAddr::V4(Ipv4Addr::LOCALHOST), "desktop-app".to_string(), ) } - fn sample_device() -> Device { + fn sample_device() -> Device { Device::new( "vpn-device".to_string(), "pubkey".to_string(), @@ -220,7 +220,7 @@ mod tests { .save_placeholder_id(20) } - fn sample_location() -> WireguardNetwork { + fn sample_location() -> WireguardNetwork { WireguardNetwork::new( "vpn-location".to_string(), vec!["10.0.0.0/24".parse().unwrap()], @@ -242,11 +242,11 @@ mod tests { } trait WithPlaceholderId { - fn save_placeholder_id(self, id: i64) -> T; + fn save_placeholder_id(self, id: Id) -> T; } - impl WithPlaceholderId> for Device { - fn save_placeholder_id(self, id: i64) -> Device { + impl WithPlaceholderId> for Device { + fn save_placeholder_id(self, id: Id) -> Device { Device { id, name: self.name, @@ -260,8 +260,8 @@ mod tests { } } - impl WithPlaceholderId> for WireguardNetwork { - fn save_placeholder_id(self, id: i64) -> WireguardNetwork { + impl WithPlaceholderId> for WireguardNetwork { + fn save_placeholder_id(self, id: Id) -> WireguardNetwork { WireguardNetwork { id, name: self.name, From ec38e55a2d28dcc6bc097318adf51df8dabd9b3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 09:53:00 +0100 Subject: [PATCH 17/18] post-merge fixes --- .../src/grpc/proxy/client_mfa.rs | 9 ++-- .../tests/integration/grpc/gateway.rs | 39 +++++++--------- crates/defguard_event_logger/src/lib.rs | 37 ++------------- .../src/handlers/bidi.rs | 45 ++++--------------- crates/defguard_gateway_manager/src/tests.rs | 9 ++-- 5 files changed, 33 insertions(+), 106 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 3e04b715ed..fb69014060 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -1152,21 +1152,18 @@ mod tests { async fn create_mfa_location(pool: &PgPool) -> WireguardNetwork { WireguardNetwork::new( "client-mfa-location".to_string(), - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 10, 0, 0)), 24).unwrap()], 51820, "vpn.example.com".to_string(), None, - 1420, - 0, - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], + [IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], true, - 25, - 300, false, false, LocationMfaMode::Internal, ServiceLocationMode::Disabled, ) + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 10, 0, 1)), 24).unwrap()]) + .expect("failed to set location address") .save(pool) .await .expect("failed to create location") diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index eec8615ada..14d66342df 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -49,24 +49,21 @@ async fn setup_test_server( let test_server = make_grpc_test_server(&pool).await; // create a test location - let location = WireguardNetwork::new( + let mut location = WireguardNetwork::new( "test location".to_string(), - Vec::new(), 1000, "endpoint1".to_string(), None, Vec::new(), false, - 100, - 100, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); + ); + location.keepalive_interval = 100; + location.peer_disconnect_threshold = 100; + let location = location.save(&pool).await.unwrap(); // set auth token for gateway let token = generate_gateway_token(&location); @@ -398,24 +395,21 @@ async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions setup_test_server(pool.clone()).await; // setup another test location & gateway - let test_location_2 = WireguardNetwork::new( + let mut test_location_2 = WireguardNetwork::new( "test location 2".to_string(), - Vec::new(), 1000, "endpoint2".to_string(), None, Vec::new(), false, - 100, - 100, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); + ); + test_location_2.keepalive_interval = 100; + test_location_2.peer_disconnect_threshold = 100; + let test_location_2 = test_location_2.save(&pool).await.unwrap(); // set auth token for gateway let token = generate_gateway_token(&test_location_2); @@ -516,24 +510,21 @@ async fn test_gateway_config(_: PgPoolOptions, options: PgConnectOptions) { // unset the license and create another location to exceed limits and disable enterprise features set_cached_license(None); - let _test_location_2 = WireguardNetwork::new( + let mut test_location_2 = WireguardNetwork::new( "test location 2".to_string(), - Vec::new(), 1000, "endpoint2".to_string(), None, Vec::new(), false, - 100, - 100, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); + ); + test_location_2.keepalive_interval = 100; + test_location_2.peer_disconnect_threshold = 100; + let _test_location_2 = test_location_2.save(&pool).await.unwrap(); update_counts(&pool).await.unwrap(); let config = gateway.get_gateway_config().await.unwrap().into_inner(); diff --git a/crates/defguard_event_logger/src/lib.rs b/crates/defguard_event_logger/src/lib.rs index 5332efccfe..45a3535b9b 100644 --- a/crates/defguard_event_logger/src/lib.rs +++ b/crates/defguard_event_logger/src/lib.rs @@ -652,22 +652,19 @@ mod tests { fn sample_location() -> WireguardNetwork { WireguardNetwork::new( "vpn-location".to_string(), - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], 51820, "vpn.example.com".to_string(), None, - 1420, - 0, - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], + [IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], true, - 25, - 300, false, false, LocationMfaMode::Internal, ServiceLocationMode::Disabled, ) - .save_placeholder_id(10) + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 24).unwrap()]) + .expect("sample location address should be valid") + .with_id(10) } #[test] @@ -704,30 +701,4 @@ mod tests { } } } - - impl WithPlaceholderId> for WireguardNetwork { - fn save_placeholder_id(self, id: i64) -> WireguardNetwork { - WireguardNetwork { - id, - name: self.name, - address: self.address, - port: self.port, - pubkey: self.pubkey, - prvkey: self.prvkey, - endpoint: self.endpoint, - dns: self.dns, - mtu: self.mtu, - fwmark: self.fwmark, - allowed_ips: self.allowed_ips, - allow_all_groups: self.allow_all_groups, - connected_at: self.connected_at, - acl_enabled: self.acl_enabled, - acl_default_allow: self.acl_default_allow, - keepalive_interval: self.keepalive_interval, - peer_disconnect_threshold: self.peer_disconnect_threshold, - location_mfa_mode: self.location_mfa_mode, - service_location_mode: self.service_location_mode, - } - } - } } diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 247ded3b11..6e8493b8f7 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -4,7 +4,7 @@ use defguard_core::events::{ use defguard_event_logger::message::{EnrollmentEvent, EventContext, LoggerEvent, VpnEvent}; use tracing::debug; -use crate::{EventRouter, error::EventRouterError}; +use crate::{error::EventRouterError, EventRouter}; impl EventRouter { pub(crate) fn handle_bidi_event(&self, event: BidiStreamEvent) -> Result<(), EventRouterError> { @@ -108,17 +108,17 @@ mod tests { }; use defguard_common::db::{ - Id, NoId, models::{ - Device, DeviceType, WireguardNetwork, wireguard::{LocationMfaMode, ServiceLocationMode}, + Device, DeviceType, WireguardNetwork, }, + Id, NoId, }; use defguard_core::{ events::{BidiRequestContext, BidiStreamEventType}, grpc::GatewayEvent, }; - use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; + use tokio::sync::{broadcast, mpsc::unbounded_channel, Notify}; use super::*; use crate::RouterReceiverSet; @@ -223,22 +223,19 @@ mod tests { fn sample_location() -> WireguardNetwork { WireguardNetwork::new( "vpn-location".to_string(), - vec!["10.0.0.0/24".parse().unwrap()], 51820, "vpn.example.com".to_string(), None, - 1420, - 0, - vec!["0.0.0.0/0".parse().unwrap()], + ["0.0.0.0/0".parse().expect("allowed IP should parse")], true, - 25, - 300, false, false, LocationMfaMode::Internal, ServiceLocationMode::Disabled, ) - .save_placeholder_id(10) + .set_address(["10.0.0.1/24".parse().expect("address should parse")]) + .expect("sample location address should be valid") + .with_id(10) } trait WithPlaceholderId { @@ -259,30 +256,4 @@ mod tests { } } } - - impl WithPlaceholderId> for WireguardNetwork { - fn save_placeholder_id(self, id: Id) -> WireguardNetwork { - WireguardNetwork { - id, - name: self.name, - address: self.address, - port: self.port, - pubkey: self.pubkey, - prvkey: self.prvkey, - endpoint: self.endpoint, - dns: self.dns, - mtu: self.mtu, - fwmark: self.fwmark, - allowed_ips: self.allowed_ips, - allow_all_groups: self.allow_all_groups, - connected_at: self.connected_at, - acl_enabled: self.acl_enabled, - acl_default_allow: self.acl_default_allow, - keepalive_interval: self.keepalive_interval, - peer_disconnect_threshold: self.peer_disconnect_threshold, - location_mfa_mode: self.location_mfa_mode, - service_location_mode: self.service_location_mode, - } - } - } } diff --git a/crates/defguard_gateway_manager/src/tests.rs b/crates/defguard_gateway_manager/src/tests.rs index 2257afcba0..e6fbe3cfc6 100644 --- a/crates/defguard_gateway_manager/src/tests.rs +++ b/crates/defguard_gateway_manager/src/tests.rs @@ -73,21 +73,18 @@ async fn test_gateway(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; let network = WireguardNetwork::new( "TestNet".to_string(), - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], 50051, "0.0.0.0".to_string(), None, - 1420, - 0, - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], + [IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], false, - 25, - 300, false, false, LocationMfaMode::default(), ServiceLocationMode::default(), ) + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()]) + .expect("test network address should be valid") .save(&pool) .await .unwrap(); From 77fcc40cc602da5a2c09aecc59d99f44027fdc62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 10:21:03 +0100 Subject: [PATCH 18/18] post-merge fixes --- crates/defguard_core/src/grpc/proxy/client_mfa.rs | 3 +-- crates/defguard_event_logger/src/lib.rs | 4 +++- crates/defguard_event_router/src/handlers/bidi.rs | 8 ++++---- crates/defguard_session_manager/src/session_state.rs | 2 +- .../defguard_session_manager/tests/session_manager/mfa.rs | 6 +++--- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index fb69014060..865713675b 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -1,6 +1,5 @@ use std::{ collections::HashMap, - net::{IpAddr, Ipv4Addr}, sync::{Arc, RwLock}, time::Duration, }; @@ -857,7 +856,7 @@ impl ClientMfaServer { timestamp: disconnect_timestamp, user_id: user.id, username: user.username.clone(), - ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + ip: None, device_name: format!("{}", device), }; self.emit_event(BidiStreamEvent { diff --git a/crates/defguard_event_logger/src/lib.rs b/crates/defguard_event_logger/src/lib.rs index e1ba13f17b..3daf2fd0f3 100644 --- a/crates/defguard_event_logger/src/lib.rs +++ b/crates/defguard_event_logger/src/lib.rs @@ -625,6 +625,7 @@ pub async fn run_event_logger( #[cfg(test)] mod tests { + use chrono::Utc; use defguard_common::db::{ NoId, models::{ @@ -633,6 +634,7 @@ mod tests { }, }; use ipnetwork::IpNetwork; + use serde_json::Value; use std::net::{IpAddr, Ipv4Addr}; use super::*; @@ -646,7 +648,7 @@ mod tests { None, true, ) - .save_placeholder_id(20) + .with_id(20) } fn sample_location() -> WireguardNetwork { diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 6e8493b8f7..441cc6c4f7 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -4,7 +4,7 @@ use defguard_core::events::{ use defguard_event_logger::message::{EnrollmentEvent, EventContext, LoggerEvent, VpnEvent}; use tracing::debug; -use crate::{error::EventRouterError, EventRouter}; +use crate::{EventRouter, error::EventRouterError}; impl EventRouter { pub(crate) fn handle_bidi_event(&self, event: BidiStreamEvent) -> Result<(), EventRouterError> { @@ -108,17 +108,17 @@ mod tests { }; use defguard_common::db::{ + Id, NoId, models::{ - wireguard::{LocationMfaMode, ServiceLocationMode}, Device, DeviceType, WireguardNetwork, + wireguard::{LocationMfaMode, ServiceLocationMode}, }, - Id, NoId, }; use defguard_core::{ events::{BidiRequestContext, BidiStreamEventType}, grpc::GatewayEvent, }; - use tokio::sync::{broadcast, mpsc::unbounded_channel, Notify}; + use tokio::sync::{Notify, broadcast, mpsc::unbounded_channel}; use super::*; use crate::RouterReceiverSet; diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 793003d38e..95bbf012f8 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -213,7 +213,7 @@ impl SessionEventContextData { location: self.location.clone(), user: self.user.clone(), device: self.device.clone(), - public_ip, + public_ip: Some(public_ip), } } } diff --git a/crates/defguard_session_manager/tests/session_manager/mfa.rs b/crates/defguard_session_manager/tests/session_manager/mfa.rs index 83f0c499d8..086d15fd2a 100644 --- a/crates/defguard_session_manager/tests/session_manager/mfa.rs +++ b/crates/defguard_session_manager/tests/session_manager/mfa.rs @@ -125,7 +125,7 @@ async fn test_mfa_new_session_upgrades_to_connected_on_stats( assert_eq!(connected_event.context.location.id, location.id); assert_eq!(connected_event.context.user.id, user.id); assert_eq!(connected_event.context.device.id, device.id); - assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_eq!(connected_event.context.public_ip, Some(endpoint.ip())); let second_collected_at = handshake + TimeDelta::seconds(30); let second_handshake = handshake + TimeDelta::seconds(25); @@ -254,7 +254,7 @@ async fn test_duplicate_first_stats_on_mfa_new_session_are_idempotent( assert_eq!(connected_event.context.location.id, location.id); assert_eq!(connected_event.context.user.id, user.id); assert_eq!(connected_event.context.device.id, device.id); - assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_eq!(connected_event.context.public_ip, Some(endpoint.ip())); assert_no_session_manager_events(&mut harness); assert_no_gateway_events(&mut harness); @@ -316,7 +316,7 @@ async fn test_repeated_later_stats_on_mfa_session_remain_idempotent( assert_eq!(connected_event.context.location.id, location.id); assert_eq!(connected_event.context.user.id, user.id); assert_eq!(connected_event.context.device.id, device.id); - assert_eq!(connected_event.context.public_ip, endpoint.ip()); + assert_eq!(connected_event.context.public_ip, Some(endpoint.ip())); assert_no_session_manager_events(&mut harness); assert_no_gateway_events(&mut harness);