Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 10 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 19 additions & 2 deletions crates/defguard_common/src/db/models/vpn_client_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl VpnClientSession<Id> {
}

/// Fetch active sessions which have become inactive for a specific location
pub async fn get_inactive<'e, E: sqlx::PgExecutor<'e>>(
pub async fn get_all_inactive_for_location<'e, E: sqlx::PgExecutor<'e>>(
executor: E,
location: &WireguardNetwork<Id>,
) -> Result<Vec<Self>, SqlxError> {
Expand All @@ -126,7 +126,7 @@ impl VpnClientSession<Id> {
SELECT latest_handshake \
FROM vpn_session_stats \
WHERE session_id = s.id \
ORDER BY collected_at DESC \
ORDER BY latest_handshake DESC \
LIMIT 1 \
) ss ON true \
WHERE location_id = $1 AND state = 'connected' \
Expand All @@ -152,4 +152,21 @@ impl VpnClientSession<Id> {
f64::from(location.peer_disconnect_threshold)
).fetch_all(executor).await
}

/// Fetch all active sessions for a given device in a given location
pub async fn get_all_active_device_sessions_in_location<'e, E: sqlx::PgExecutor<'e>>(
executor: E,
location_id: Id,
device_id: Id,
) -> Result<Vec<Self>, SqlxError> {
query_as!(
Self,
"SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \
mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" \
FROM vpn_client_session \
WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')",
location_id,
device_id,
).fetch_all(executor).await
}
}
111 changes: 102 additions & 9 deletions crates/defguard_core/src/grpc/proxy/client_mfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use defguard_common::{
Id,
models::{
BiometricAuth, BiometricChallenge, Device, User, WireguardNetwork,
device::WireguardNetworkDevice, vpn_client_session::VpnClientSession,
device::WireguardNetworkDevice,
vpn_client_session::{VpnClientMfaMethod, VpnClientSession, VpnClientSessionState},
wireguard::LocationMfaMode,
},
},
Expand All @@ -24,7 +25,7 @@ use defguard_proto::proxy::{
ClientMfaTokenValidationRequest, ClientMfaTokenValidationResponse, CoreResponse, MfaMethod,
core_response::Payload,
};
use sqlx::PgPool;
use sqlx::{PgConnection, PgPool};
use thiserror::Error;
use tokio::{
sync::{
Expand Down Expand Up @@ -703,14 +704,13 @@ impl ClientMfaServer {
})?;

// create new VPN client session
let vpn_client_session = VpnClientSession::new(
location.id,
user.id,
device.id,
None,
Some(method.into()),
let vpn_client_session = self.create_new_mfa_session(
&mut transaction,
&location,
&user,
&device,
method.into(),
)
.save(&mut *transaction)
.await
.map_err(|err| {
error!("Failed to create new VPN client session for device {device} in location {location}: {err}");
Expand Down Expand Up @@ -750,4 +750,97 @@ impl ClientMfaServer {

Ok(response)
}

/// Helper used to close all existing active sessions while creating a new MFA session
/// and send relevant gateway updates
async fn create_new_mfa_session(
&self,
conn: &mut PgConnection,
location: &WireguardNetwork<Id>,
user: &User<Id>,
device: &Device<Id>,
mfa_method: VpnClientMfaMethod,
) -> Result<VpnClientSession<Id>, Status> {
debug!(
"Creating new VPN session for device {device} of user {user} in location {location} after successful MFA authorization."
);

// find all active sessions for a given device and location
let active_sessions = VpnClientSession::get_all_active_device_sessions_in_location(&mut *conn, location.id, device.id).await
.map_err(|err| {
error!("Failed to fetch active VPN sessions for device {device} in location {location}: {err}");
Status::internal("unexpected error")
})?;
if !active_sessions.is_empty() {
info!(
"Found {} active sessions for device {device} in location {location}. Disconnecting them before creating a new MFA session",
active_sessions.len()
);
}

// disconnect all active sessions
for session in active_sessions {
debug!("Disconnecting previous active MFA VPN session {session:?}.");
self.disconnect_session(&mut *conn, session, location, device)
.await?;
}

// create new MFA session
VpnClientSession::new(location.id, user.id, device.id, None, Some(mfa_method)).save(conn).await
.map_err(|err| {
error!("Failed to create new VPN client session for device {device} in location {location}: {err}");
Status::internal("unexpected error")
})
}

/// Update session state as disconnected and send relevant gateway update
async fn disconnect_session(
&self,
conn: &mut PgConnection,
mut session: VpnClientSession<Id>,
location: &WireguardNetwork<Id>,
device: &Device<Id>,
) -> Result<(), Status> {
// update session state in DB
let disconnect_timestamp = Utc::now().naive_utc();
session.disconnected_at = Some(disconnect_timestamp);
session.state = VpnClientSessionState::Disconnected;
session.save(&mut *conn).await.map_err(|err| {
error!("Failed to update VPN session {session:?}: {err}");
Status::internal("unexpected error")
})?;

// FIXME: remove once MFA-related data is no longer stored here
// update device network config
if let Some(mut device_network_info) = WireguardNetworkDevice::find(
&mut *conn,
device.id,
location.id,
)
.await
.map_err(|err| {
error!(
"Failed to fetch WireGuard config for device {device} in location {location}: {err}"
);
Status::internal("unexpected error")
})? {
device_network_info.is_authorized = false;
device_network_info.preshared_key = None;
device_network_info.update(&mut *conn).await.map_err(|err| {
error!(
"Failed to update WireGuard config for device {device} in location {location}: {err}"
);
Status::internal("unexpected error")
})?;
};
let event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone());
self.wireguard_tx.send(event).map_err(|err| {
error!("Error sending WireGuard event: {err}");
Status::internal("unexpected error")
})?;

// FIXME: add audit log event

Ok(())
}
}
13 changes: 7 additions & 6 deletions crates/defguard_session_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use defguard_common::{
db::{
Id,
models::{
Device, User, WireguardNetwork, device::WireguardNetworkDevice,
vpn_client_session::VpnClientSession,
Device, User, WireguardNetwork,
device::WireguardNetworkDevice,
vpn_client_session::{VpnClientSession, VpnClientSessionState},
},
},
messages::peer_stats_update::PeerStatsUpdate,
Expand Down Expand Up @@ -193,7 +194,8 @@ impl SessionManager {

// get all connected sessions which have become inactive
let inactive_sessions =
VpnClientSession::get_inactive(&mut *transaction, &location).await?;
VpnClientSession::get_all_inactive_for_location(&mut *transaction, &location)
.await?;

debug!(
"Found {} inactive VPN sessions in location {location}",
Expand Down Expand Up @@ -250,8 +252,7 @@ impl SessionManager {

// update session record in DB
session.disconnected_at = Some(disconnect_timestamp);
session.state =
defguard_common::db::models::vpn_client_session::VpnClientSessionState::Disconnected;
session.state = VpnClientSessionState::Disconnected;
session.save(&mut *transaction).await?;

// fetch related objects necessary for event context
Expand All @@ -266,7 +267,7 @@ impl SessionManager {

// remove peers from GW for MFA locations
if location.mfa_enabled() {
// FIXME: remove one MFA-related data is no longer stored here
// FIXME: remove once MFA-related data is no longer stored here
// update device network config
if let Some(mut device_network_info) =
WireguardNetworkDevice::find(&mut *transaction, device.id, location.id).await?
Expand Down
12 changes: 6 additions & 6 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading