diff --git a/.sqlx/query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json b/.sqlx/query-743a5ec9c0ef4e1f92465ad287dc211dedd7b66f89cdf11b36e4a5d7306258be.json similarity index 87% rename from .sqlx/query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json rename to .sqlx/query-743a5ec9c0ef4e1f92465ad287dc211dedd7b66f89cdf11b36e4a5d7306258be.json index c398106510..b98cbb6ba2 100644 --- a/.sqlx/query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json +++ b/.sqlx/query-743a5ec9c0ef4e1f92465ad287dc211dedd7b66f89cdf11b36e4a5d7306258be.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT s.id, location_id, user_id, device_id, created_at, s.connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" FROM vpn_client_session s LEFT JOIN LATERAL ( SELECT latest_handshake FROM vpn_session_stats WHERE session_id = s.id ORDER BY collected_at DESC LIMIT 1 ) ss ON true WHERE location_id = $1 AND state = 'connected' AND (NOW() - ss.latest_handshake) > $2 * interval '1 second'", + "query": "SELECT s.id, location_id, user_id, device_id, created_at, s.connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" FROM vpn_client_session s LEFT JOIN LATERAL ( SELECT latest_handshake FROM vpn_session_stats WHERE session_id = s.id ORDER BY latest_handshake DESC LIMIT 1 ) ss ON true WHERE location_id = $1 AND state = 'connected' AND (NOW() - ss.latest_handshake) > $2 * interval '1 second'", "describe": { "columns": [ { @@ -91,5 +91,5 @@ false ] }, - "hash": "3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781" + "hash": "743a5ec9c0ef4e1f92465ad287dc211dedd7b66f89cdf11b36e4a5d7306258be" } diff --git a/Cargo.lock b/Cargo.lock index cdccff64e9..f511c7a28c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1846,9 +1846,9 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flate2" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b375d6465b98090a5f25b1c7703f3859783755aa9a80433b36e0379a3ec2f369" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", "miniz_oxide", @@ -2087,9 +2087,9 @@ dependencies = [ [[package]] name = "git2" -version = "0.20.3" +version = "0.20.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e2b37e2f62729cdada11f0e6b3b6fe383c69c29fc619e391223e12856af308c" +checksum = "7b88256088d75a56f8ecfa070513a775dd9107f6530ef14919dac831af9cfe2b" dependencies = [ "bitflags 2.10.0", "libc", @@ -2427,14 +2427,13 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.19" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ "base64 0.22.1", "bytes", "futures-channel", - "futures-core", "futures-util", "http", "http-body", @@ -5504,9 +5503,9 @@ dependencies = [ [[package]] name = "system-configuration" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" dependencies = [ "bitflags 2.10.0", "core-foundation 0.9.4", @@ -7079,9 +7078,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.5" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" +checksum = "a7948af682ccbc3342b6e9420e8c51c1fe5d7bf7756002b4a3c6cabfe96a7e3c" [[package]] name = "zmij" diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index 9b5c98d2c1..6bf17ec77a 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -113,7 +113,7 @@ impl VpnClientSession { } /// Fetch active sessions which have become inactive for a specific location - pub async fn get_inactive<'e, E: sqlx::PgExecutor<'e>>( + pub async fn get_all_inactive_for_location<'e, E: sqlx::PgExecutor<'e>>( executor: E, location: &WireguardNetwork, ) -> Result, SqlxError> { @@ -126,7 +126,7 @@ impl VpnClientSession { SELECT latest_handshake \ FROM vpn_session_stats \ WHERE session_id = s.id \ - ORDER BY collected_at DESC \ + ORDER BY latest_handshake DESC \ LIMIT 1 \ ) ss ON true \ WHERE location_id = $1 AND state = 'connected' \ @@ -152,4 +152,21 @@ impl VpnClientSession { f64::from(location.peer_disconnect_threshold) ).fetch_all(executor).await } + + /// Fetch all active sessions for a given device in a given location + pub async fn get_all_active_device_sessions_in_location<'e, E: sqlx::PgExecutor<'e>>( + executor: E, + location_id: Id, + device_id: Id, + ) -> Result, SqlxError> { + query_as!( + Self, + "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ + mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" \ + FROM vpn_client_session \ + WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')", + location_id, + device_id, + ).fetch_all(executor).await + } } diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 0ba29f102a..44beade104 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -11,7 +11,8 @@ use defguard_common::{ Id, models::{ BiometricAuth, BiometricChallenge, Device, User, WireguardNetwork, - device::WireguardNetworkDevice, vpn_client_session::VpnClientSession, + device::WireguardNetworkDevice, + vpn_client_session::{VpnClientMfaMethod, VpnClientSession, VpnClientSessionState}, wireguard::LocationMfaMode, }, }, @@ -24,7 +25,7 @@ use defguard_proto::proxy::{ ClientMfaTokenValidationRequest, ClientMfaTokenValidationResponse, CoreResponse, MfaMethod, core_response::Payload, }; -use sqlx::PgPool; +use sqlx::{PgConnection, PgPool}; use thiserror::Error; use tokio::{ sync::{ @@ -703,14 +704,13 @@ impl ClientMfaServer { })?; // create new VPN client session - let vpn_client_session = VpnClientSession::new( - location.id, - user.id, - device.id, - None, - Some(method.into()), + let vpn_client_session = self.create_new_mfa_session( + &mut transaction, + &location, + &user, + &device, + method.into(), ) - .save(&mut *transaction) .await .map_err(|err| { error!("Failed to create new VPN client session for device {device} in location {location}: {err}"); @@ -750,4 +750,97 @@ impl ClientMfaServer { Ok(response) } + + /// Helper used to close all existing active sessions while creating a new MFA session + /// and send relevant gateway updates + async fn create_new_mfa_session( + &self, + conn: &mut PgConnection, + location: &WireguardNetwork, + user: &User, + device: &Device, + mfa_method: VpnClientMfaMethod, + ) -> Result, Status> { + debug!( + "Creating new VPN session for device {device} of user {user} in location {location} after successful MFA authorization." + ); + + // find all active sessions for a given device and location + let active_sessions = VpnClientSession::get_all_active_device_sessions_in_location(&mut *conn, location.id, device.id).await + .map_err(|err| { + error!("Failed to fetch active VPN sessions for device {device} in location {location}: {err}"); + Status::internal("unexpected error") + })?; + if !active_sessions.is_empty() { + info!( + "Found {} active sessions for device {device} in location {location}. Disconnecting them before creating a new MFA session", + active_sessions.len() + ); + } + + // disconnect all active sessions + for session in active_sessions { + debug!("Disconnecting previous active MFA VPN session {session:?}."); + self.disconnect_session(&mut *conn, session, location, device) + .await?; + } + + // create new MFA session + VpnClientSession::new(location.id, user.id, device.id, None, Some(mfa_method)).save(conn).await + .map_err(|err| { + error!("Failed to create new VPN client session for device {device} in location {location}: {err}"); + Status::internal("unexpected error") + }) + } + + /// Update session state as disconnected and send relevant gateway update + async fn disconnect_session( + &self, + conn: &mut PgConnection, + mut session: VpnClientSession, + location: &WireguardNetwork, + device: &Device, + ) -> Result<(), Status> { + // update session state in DB + let disconnect_timestamp = Utc::now().naive_utc(); + session.disconnected_at = Some(disconnect_timestamp); + session.state = VpnClientSessionState::Disconnected; + session.save(&mut *conn).await.map_err(|err| { + error!("Failed to update VPN session {session:?}: {err}"); + Status::internal("unexpected error") + })?; + + // FIXME: remove once MFA-related data is no longer stored here + // update device network config + if let Some(mut device_network_info) = WireguardNetworkDevice::find( + &mut *conn, + device.id, + location.id, + ) + .await + .map_err(|err| { + error!( + "Failed to fetch WireGuard config for device {device} in location {location}: {err}" + ); + Status::internal("unexpected error") + })? { + device_network_info.is_authorized = false; + device_network_info.preshared_key = None; + device_network_info.update(&mut *conn).await.map_err(|err| { + error!( + "Failed to update WireGuard config for device {device} in location {location}: {err}" + ); + Status::internal("unexpected error") + })?; + }; + let event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + self.wireguard_tx.send(event).map_err(|err| { + error!("Error sending WireGuard event: {err}"); + Status::internal("unexpected error") + })?; + + // FIXME: add audit log event + + Ok(()) + } } diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 451e90c0e1..0d0308f25b 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -5,8 +5,9 @@ use defguard_common::{ db::{ Id, models::{ - Device, User, WireguardNetwork, device::WireguardNetworkDevice, - vpn_client_session::VpnClientSession, + Device, User, WireguardNetwork, + device::WireguardNetworkDevice, + vpn_client_session::{VpnClientSession, VpnClientSessionState}, }, }, messages::peer_stats_update::PeerStatsUpdate, @@ -193,7 +194,8 @@ impl SessionManager { // get all connected sessions which have become inactive let inactive_sessions = - VpnClientSession::get_inactive(&mut *transaction, &location).await?; + VpnClientSession::get_all_inactive_for_location(&mut *transaction, &location) + .await?; debug!( "Found {} inactive VPN sessions in location {location}", @@ -250,8 +252,7 @@ impl SessionManager { // update session record in DB session.disconnected_at = Some(disconnect_timestamp); - session.state = - defguard_common::db::models::vpn_client_session::VpnClientSessionState::Disconnected; + session.state = VpnClientSessionState::Disconnected; session.save(&mut *transaction).await?; // fetch related objects necessary for event context @@ -266,7 +267,7 @@ impl SessionManager { // remove peers from GW for MFA locations if location.mfa_enabled() { - // FIXME: remove one MFA-related data is no longer stored here + // FIXME: remove once MFA-related data is no longer stored here // update device network config if let Some(mut device_network_info) = WireguardNetworkDevice::find(&mut *transaction, device.id, location.id).await? diff --git a/flake.lock b/flake.lock index b3064ed172..49cfdf332e 100644 --- a/flake.lock +++ b/flake.lock @@ -32,11 +32,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1769789167, - "narHash": "sha256-kKB3bqYJU5nzYeIROI82Ef9VtTbu4uA3YydSk/Bioa8=", + "lastModified": 1770019141, + "narHash": "sha256-VKS4ZLNx4PNrABoB0L8KUpc1fE7CLpQXQs985tGfaCU=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "62c8382960464ceb98ea593cb8321a2cf8f9e3e5", + "rev": "cb369ef2efd432b3cdf8622b0ffc0a97a02f3137", "type": "github" }, "original": { @@ -74,11 +74,11 @@ ] }, "locked": { - "lastModified": 1770001842, - "narHash": "sha256-ZAyTeILfdWwDp1nuF0RK3McBduMi49qnJvrS+3Ezpac=", + "lastModified": 1770088046, + "narHash": "sha256-4hfYDnUTvL1qSSZEA4CEThxfz+KlwSFQ30Z9jgDguO0=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "5018343419ea808f8a413241381976b7e60951f2", + "rev": "71f9daa4e05e49c434d08627e755495ae222bc34", "type": "github" }, "original": {