diff --git a/Cargo.lock b/Cargo.lock index 4bab995879..26c8d2cdff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,9 +160,9 @@ checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "ar_archive_writer" -version = "0.2.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0c269894b6fe5e9d7ada0cf69b5bf847ff35bc25fc271f08e1d080fce80339a" +checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" dependencies = [ "object", ] @@ -1278,6 +1278,7 @@ dependencies = [ "chrono", "defguard_common", "defguard_core", + "defguard_session_manager", "serde_json", "sqlx", "thiserror 2.0.18", @@ -1292,6 +1293,7 @@ dependencies = [ "defguard_core", "defguard_event_logger", "defguard_mail", + "defguard_session_manager", "thiserror 2.0.18", "tokio", "tracing", @@ -2839,9 +2841,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" @@ -3150,9 +3152,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" [[package]] name = "num-derive" @@ -3407,9 +3409,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.2" +version = "0.37.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" dependencies = [ "memchr", ] @@ -4086,9 +4088,9 @@ checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" [[package]] name = "psm" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d11f2fedc3b7dafdc2851bc52f277377c5473d378859be234bc7ebb593144d01" +checksum = "1fa96cb91275ed31d6da3e983447320c4eb219ac180fa1679a0889ff32861e2d" dependencies = [ "ar_archive_writer", "cc", @@ -5076,9 +5078,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" dependencies = [ "libc", "windows-sys 0.60.2", @@ -5606,9 +5608,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.45" +version = "0.3.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9e442fc33d7fdb45aa9bfeb312c095964abdf596f7567261062b2a7107aaabd" +checksum = "9da98b7d9b7dad93488a84b8248efc35352b0b2657397d4167e7ad67e5d535e5" dependencies = [ "deranged", "itoa", @@ -5623,15 +5625,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b36ee98fd31ec7426d599183e8fe26932a8dc1fb76ddb6214d05493377d34ca" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" [[package]] name = "time-macros" -version = "0.2.25" +version = "0.2.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e552d1249bf61ac2a52db88179fd0673def1e1ad8243a00d9ec9ed71fee3dd" +checksum = "78cc610bac2dcee56805c99642447d4c5dbde4d01f752ffea0199aee1f601dc4" dependencies = [ "num-conv", "time-core", @@ -6198,9 +6200,9 @@ checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" [[package]] name = "uuid" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" dependencies = [ "getrandom 0.3.4", "js-sys", @@ -6962,18 +6964,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.33" +version = "0.8.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" +checksum = "71ddd76bcebeed25db614f82bf31a9f4222d3fbba300e6fb6c00afa26cbd4d9d" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.33" +version = "0.8.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" +checksum = "d8187381b52e32220d50b255276aa16a084ec0a9017a0ca2152a1f55c539758d" dependencies = [ "proc-macro2", "quote", @@ -7076,9 +7078,9 @@ checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" [[package]] name = "zmij" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfcd145825aace48cff44a8844de64bf75feec3080e0aa5cdbde72961ae51a65" +checksum = "02aae0f83f69aafc94776e879363e9771d7ecbffe2c7fbb6c14c5e00dfe88439" [[package]] name = "zopfli" diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index ac6e2f75c1..6bc047ae7b 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -24,10 +24,10 @@ use defguard_core::{ license::{License, run_periodic_license_check, set_cached_license}, limits::update_counts, }, - events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, + events::{ApiEvent, BidiStreamEvent, InternalEvent}, grpc::{ WorkerState, - gateway::{client_state::ClientMap, events::GatewayEvent, run_grpc_gateway_stream}, + gateway::{events::GatewayEvent, run_grpc_gateway_stream}, run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, @@ -40,7 +40,7 @@ use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; use defguard_event_router::{RouterReceiverSet, run_event_router}; use defguard_mail::{Mail, run_mail_handler}; use defguard_proxy_manager::{ProxyManager, ProxyTxSet}; -use defguard_session_manager::run_session_manager; +use defguard_session_manager::{events::SessionManagerEvent, run_session_manager}; use secrecy::ExposeSecret; use tokio::sync::{broadcast, mpsc::unbounded_channel}; @@ -101,7 +101,8 @@ async fn main() -> Result<(), anyhow::Error> { let (api_event_tx, api_event_rx) = unbounded_channel::(); let (bidi_event_tx, bidi_event_rx) = unbounded_channel::(); let (internal_event_tx, internal_event_rx) = unbounded_channel::(); - let (grpc_event_tx, grpc_event_rx) = unbounded_channel::(); + let (session_manager_event_tx, session_manager_event_rx) = + unbounded_channel::(); // Activity log stream setup let (activity_log_messages_tx, activity_log_messages_rx) = broadcast::channel::(100); @@ -115,7 +116,6 @@ async fn main() -> Result<(), anyhow::Error> { let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(webhook_tx.clone()))); - let client_state = Arc::new(Mutex::new(ClientMap::new())); let incompatible_components: Arc> = Arc::default(); @@ -182,10 +182,8 @@ async fn main() -> Result<(), anyhow::Error> { res = proxy_manager.run() => error!("ProxyManager returned early: {res:?}"), res = run_grpc_gateway_stream( pool.clone(), - client_state, wireguard_tx.clone(), mail_tx.clone(), - grpc_event_tx, peer_stats_tx, ) => error!("Gateway gRPC stream returned early: {res:?}"), res = run_grpc_server( @@ -225,9 +223,9 @@ async fn main() -> Result<(), anyhow::Error> { res = run_event_router( RouterReceiverSet::new( api_event_rx, - grpc_event_rx, bidi_event_rx, - internal_event_rx + internal_event_rx, + session_manager_event_rx ), event_logger_tx, wireguard_tx, @@ -243,7 +241,8 @@ async fn main() -> Result<(), anyhow::Error> { ) => error!("Activity log stream manager returned early: {res:?}"), res = run_session_manager( pool.clone(), - peer_stats_rx + peer_stats_rx, + session_manager_event_tx ) => error!("VPN client session manager returned early: {res:?}"), } diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index 9a288fa280..70bae976cf 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -309,22 +309,8 @@ pub struct ApiEvent { /// Events from gRPC server #[derive(Debug)] pub enum GrpcEvent { - GatewayConnected { - location: WireguardNetwork, - }, - GatewayDisconnected { - location: WireguardNetwork, - }, - ClientConnected { - context: GrpcRequestContext, - location: WireguardNetwork, - device: Device, - }, - ClientDisconnected { - context: GrpcRequestContext, - location: WireguardNetwork, - device: Device, - }, + GatewayConnected { location: WireguardNetwork }, + GatewayDisconnected { location: WireguardNetwork }, } /// Shared context for every event generated from a user request in the bi-directional gRPC stream. diff --git a/crates/defguard_core/src/grpc/gateway/client_state.rs b/crates/defguard_core/src/grpc/gateway/client_state.rs deleted file mode 100644 index f75075f08e..0000000000 --- a/crates/defguard_core/src/grpc/gateway/client_state.rs +++ /dev/null @@ -1,207 +0,0 @@ -use std::{collections::HashMap, net::SocketAddr}; - -use chrono::{NaiveDateTime, TimeDelta, Utc}; -use defguard_common::db::{ - Id, - models::{Device, User, WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, -}; -use thiserror::Error; -use tonic::{Code, Status}; - -use crate::events::GrpcRequestContext; - -#[derive(Debug, Error)] -pub enum ClientMapError { - #[error("VPN client {public_key} is already connected to location {location_id}")] - ClientAlreadyConnected { public_key: String, location_id: Id }, - #[error("VPN client {public_key} is not connected to location {location_id}")] - ClientNotFound { public_key: String, location_id: Id }, - #[error("Client state for location {location_id} not found")] - LocationNotFound { location_id: Id }, -} - -impl From for Status { - fn from(value: ClientMapError) -> Self { - Self::new(Code::Internal, value.to_string()) - } -} - -/// Represents current information about a connected VPN client -#[derive(Debug, Serialize, Clone)] -pub struct ClientState { - pub device: Device, - pub user_id: Id, - pub username: String, - // current IP & port from which the client is connecting - pub endpoint: SocketAddr, - pub latest_handshake: NaiveDateTime, - // when last stats update was received - pub latest_update: NaiveDateTime, - // total bytes sent to peer - pub total_upload: i64, - // total bytes received from peer - pub total_download: i64, -} - -impl ClientState { - #[must_use] - pub fn new( - device: Device, - user: &User, - endpoint: SocketAddr, - latest_handshake: NaiveDateTime, - total_upload: i64, - total_download: i64, - ) -> Self { - let latest_update = Utc::now().naive_utc(); - Self { - device, - user_id: user.id, - username: user.username.clone(), - endpoint, - latest_handshake, - latest_update, - total_upload, - total_download, - } - } - - pub fn update_client_state( - &mut self, - current_device: Device, - current_endpoint: SocketAddr, - latest_handshake: NaiveDateTime, - upload: i64, - download: i64, - ) { - self.latest_update = Utc::now().naive_utc(); - self.device = current_device; - self.endpoint = current_endpoint; - self.latest_handshake = latest_handshake; - self.total_upload = upload; - self.total_download = download; - } -} - -/// Helper struct used to handle connected VPN clients state -/// Clients are grouped by location ID -type ClientPubKey = String; -#[derive(Debug, Default, Serialize, Clone)] -pub struct ClientMap(HashMap>); - -impl ClientMap { - #[must_use] - pub fn new() -> Self { - Self(HashMap::new()) - } - - pub fn get_vpn_client( - &mut self, - location_id: Id, - client_pubkey: &str, - ) -> Option<&mut ClientState> { - self.0 - .get_mut(&location_id) - .and_then(|location_map| location_map.get_mut(client_pubkey)) - } - - /// Adds newly connected VPN client to client state map - pub fn connect_vpn_client( - &mut self, - location_id: Id, - gateway_hostname: &str, - public_key: &str, - device: &Device, - user: &User, - endpoint: SocketAddr, - stats: &WireguardPeerStats, - ) -> Result<(), ClientMapError> { - info!( - "VPN client {} with public key {public_key} connected to location {location_id} \ - through Gateway {gateway_hostname}", - device.name - ); - - // initialize location map if it doesn't exist yet - let location_map = if let Some(location_map) = self.0.get_mut(&location_id) { - location_map - } else { - // initialize new map for location and immediately return a mutable reference - self.0.insert(location_id, HashMap::new()); - self.0.get_mut(&location_id).unwrap() - }; - - // check if client is already connected - if location_map.contains_key(public_key) { - return Err(ClientMapError::ClientAlreadyConnected { - public_key: public_key.to_string(), - location_id, - }); - } - - // add client state to location map - let client_state = ClientState::new( - device.clone(), - user, - endpoint, - stats.latest_handshake, - stats.upload, - stats.download, - ); - location_map.insert(public_key.to_string(), client_state); - - Ok(()) - } - - /// Removes all disconnected clients for a given location. - /// - /// A client is considered disconnected if there have not been any stats received for it in more than `peer_disconnect_threshold_secs`. - /// - /// Returns a list of devices. - pub fn disconnect_inactive_vpn_clients_for_location( - &mut self, - location: &WireguardNetwork, - ) -> Result, GrpcRequestContext)>, ClientMapError> { - debug!( - "Disconnecting inactive VPN clients for location {}", - location.id - ); - let peer_disconnect_threshold_secs = location.peer_disconnect_threshold; - - // initialize result - let mut disconnected_clients = Vec::new(); - - // get client state map for given location - if let Some(location_map) = self.0.get_mut(&location.id) { - let disconnect_threshold = TimeDelta::seconds(peer_disconnect_threshold_secs.into()); - - // remove clients which have been inactive longer than given location's `peer_disconnect_threshold` - location_map.retain(|public_key, client_state| { - let now = Utc::now().naive_utc(); - if (now - client_state.latest_handshake) > disconnect_threshold { - debug!("VPN client's {public_key} ({}, ID {}) latest handshake ({}) was more than {peer_disconnect_threshold_secs} seconds ago. Marking VPN client as disconnected", client_state.device.name, client_state.device.id, client_state.latest_handshake); - let disconnect_event_context = GrpcRequestContext::new( - client_state.user_id, - client_state.username.clone(), - client_state.endpoint.ip(), - client_state.device.id, - client_state.device.name.clone(), - location.clone() - ); - disconnected_clients - .push((client_state.device.clone(), disconnect_event_context)); - - return false; - } - true - }); - } - - Ok(disconnected_clients) - } - - #[must_use] - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } -} diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index ae2f6e86da..9167b96f1c 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -1,20 +1,16 @@ use std::{ - net::SocketAddr, str::FromStr, - sync::{ - Arc, Mutex, - atomic::{AtomicU64, Ordering}, - }, + sync::atomic::{AtomicU64, Ordering}, }; -use chrono::{DateTime, TimeDelta, Utc}; +use chrono::{DateTime, Utc}; use defguard_certs::{Csr, der_to_pem}; use defguard_common::{ VERSION, db::{ Id, NoId, models::{ - Device, Settings, User, WireguardNetwork, gateway::Gateway, + Device, Settings, WireguardNetwork, gateway::Gateway, wireguard_peer_stats::WireguardPeerStats, }, }, @@ -41,9 +37,8 @@ use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; use crate::{ enterprise::firewall::try_get_location_firewall_config, - events::GrpcRequestContext, grpc::{ - ClientMap, GrpcEvent, TEN_SECS, + TEN_SECS, gateway::{GatewayError, events::GatewayEvent, get_peers, try_protos_into_stats_message}, }, handlers::mail::send_gateway_disconnected_email, @@ -92,10 +87,8 @@ pub(crate) struct GatewayHandler { gateway: Gateway, message_id: AtomicU64, pool: PgPool, - client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, peer_stats_tx: UnboundedSender, } @@ -103,10 +96,8 @@ impl GatewayHandler { pub(crate) fn new( gateway: Gateway, pool: PgPool, - client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, peer_stats_tx: UnboundedSender, ) -> Result { let url = Url::from_str(&gateway.url).map_err(|err| { @@ -121,10 +112,8 @@ impl GatewayHandler { gateway, message_id: AtomicU64::new(0), pool, - client_state, events_tx, mail_tx, - grpc_event_tx, peer_stats_tx, }) } @@ -289,48 +278,6 @@ impl GatewayHandler { Ok(device) } - /// Helper method to fetch `WireguardNetwork` info from DB and return appropriate errors - async fn fetch_location_from_db( - &self, - location_id: Id, - ) -> Result, GatewayError> { - let location = match WireguardNetwork::find_by_id(&self.pool, location_id).await? { - Some(location) => location, - None => { - error!("Location {location_id} not found"); - return Err(GatewayError::NotFound(format!( - "Location {location_id} not found" - ))); - } - }; - Ok(location) - } - - /// Helper method to fetch `User` info from DB and return appropriate errors - async fn fetch_user_from_db( - &self, - user_id: Id, - public_key: &str, - ) -> Result, GatewayError> { - let user = match User::find_by_id(&self.pool, user_id).await? { - Some(user) => user, - None => { - error!("User {user_id} assigned to device with public key {public_key} not found"); - return Err(GatewayError::NotFound(format!( - "User assigned to device with public key {public_key} not found" - ))); - } - }; - - Ok(user) - } - - fn emit_event(&self, event: GrpcEvent) { - if self.grpc_event_tx.send(event).is_err() { - warn!("Failed to send gRPC event"); - } - } - pub(crate) async fn handle_setup(&mut self) -> Result<(), GatewayError> { debug!("Handling initial setup for Gateway {}", self.gateway); let endpoint = self.endpoint(Scheme::Http)?; @@ -543,126 +490,14 @@ impl GatewayHandler { // copy device ID for easier reference later let device_id = device.id; - // fetch user and location from DB for activity log - // TODO: cache usernames since they don't change - let Ok(user) = - self.fetch_user_from_db(device.user_id, &public_key).await - else { - continue; - }; - let Ok(location) = - self.fetch_location_from_db(self.gateway.network_id).await - else { - continue; - }; - // Convert stats to database storage format. + // FIXME: remove once legacy table is removed let stats = peer_stats_from_proto( peer_stats.clone(), self.gateway.network_id, device_id, ); - // Only perform client state update if stats include an endpoint IP. - // Otherwise, a peer was added to the gateway interface, but hasn't - // connected yet. - if let Some(endpoint) = &stats.endpoint { - // parse client endpoint IP - let Ok(socket_addr) = endpoint.clone().parse::() - else { - error!("Failed to parse VPN client endpoint"); - continue; - }; - - // Perform client state operations in a dedicated block to drop - // mutex guard. - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.client_state.lock().unwrap(); - - // update connected clients map - match client_map - .get_vpn_client(self.gateway.network_id, &public_key) - { - Some(client_state) => { - // update connected client state - client_state.update_client_state( - device, - socket_addr, - stats.latest_handshake, - stats.upload, - stats.download, - ); - } - None => { - // don't mark inactive peers as connected - if (Utc::now().naive_utc() - stats.latest_handshake) - < TimeDelta::seconds( - location.peer_disconnect_threshold.into(), - ) - { - // mark new VPN client as connected - if client_map - .connect_vpn_client( - self.gateway.network_id, - // Hostname is for logging only. - &self - .gateway - .hostname - .clone() - .unwrap_or_default(), - &public_key, - &device, - &user, - socket_addr, - &stats, - ) - .is_err() - { - // TODO: log message - continue; - } - - // emit connection event - let context = GrpcRequestContext::new( - user.id, - user.username.clone(), - socket_addr.ip(), - device.id, - device.name.clone(), - location.clone(), - ); - self.emit_event(GrpcEvent::ClientConnected { - context, - location: location.clone(), - device: device.clone(), - }); - } - } - } - - // disconnect inactive clients - let Ok(clients) = client_map - .disconnect_inactive_vpn_clients_for_location( - &location, - ) - else { - // TODO: log message - continue; - }; - clients - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - }); - } - } - // convert stats to DB storage format match try_protos_into_stats_message( peer_stats.clone(), @@ -686,6 +521,7 @@ impl GatewayHandler { }; // Save stats to database. + // FIXME: remove once legacy table is removed let stats = match stats.save(&self.pool).await { Ok(stats) => stats, Err(err) => { diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 60bfc863aa..73e95b9fa6 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,9 +1,4 @@ -use std::{ - collections::HashMap, - net::IpAddr, - sync::{Arc, Mutex}, - time::Duration, -}; +use std::{collections::HashMap, net::IpAddr, time::Duration}; use chrono::DateTime; use defguard_common::{ @@ -36,10 +31,9 @@ use tonic::{Code, Status}; use crate::{ enterprise::{firewall::FirewallError, is_enterprise_license_active}, events::GrpcEvent, - grpc::gateway::{client_state::ClientMap, events::GatewayEvent, handler::GatewayHandler}, + grpc::gateway::{events::GatewayEvent, handler::GatewayHandler}, }; -pub mod client_state; pub mod events; pub(crate) mod handler; // #[cfg(test)] @@ -68,32 +62,6 @@ pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender< } } -// Helper used to convert peer stats coming from gRPC client -// into an internal representation -// fn protos_into_internal_stats( -// proto_stats: PeerStats, -// location_id: Id, -// device_id: Id, -// ) -> WireguardPeerStats { -// let endpoint = match proto_stats.endpoint { -// endpoint if endpoint.is_empty() => None, -// _ => Some(proto_stats.endpoint), -// }; -// WireguardPeerStats { -// id: NoId, -// network: location_id, -// endpoint, -// device_id, -// collected_at: Utc::now().naive_utc(), -// upload: proto_stats.upload as i64, -// download: proto_stats.download as i64, -// latest_handshake: DateTime::from_timestamp(proto_stats.latest_handshake as i64, 0) -// .unwrap_or_default() -// .naive_utc(), -// allowed_ips: Some(proto_stats.allowed_ips), -// } -// } - /// Helper used to convert peer stats coming from gRPC client /// into an internal representation fn try_protos_into_stats_message( @@ -255,10 +223,8 @@ const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); /// Bi-directional gRPC stream for communication with Defguard Gateway. pub async fn run_grpc_gateway_stream( pool: PgPool, - client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, peer_stats_tx: UnboundedSender, ) -> Result<(), anyhow::Error> { let mut abort_handles = HashMap::new(); @@ -269,10 +235,8 @@ pub async fn run_grpc_gateway_stream( let mut gateway_handler = GatewayHandler::new( gateway, pool.clone(), - Arc::clone(&client_state), events_tx.clone(), mail_tx.clone(), - grpc_event_tx.clone(), peer_stats_tx.clone(), )?; let abort_handle = tasks.spawn(async move { diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 73edecd08d..db9308c44b 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -27,8 +27,6 @@ use crate::{ }, is_business_license_active, }, - events::GrpcEvent, - grpc::gateway::client_state::ClientMap, server_config, }; diff --git a/crates/defguard_event_logger/Cargo.toml b/crates/defguard_event_logger/Cargo.toml index 4f26e19fb5..6761794377 100644 --- a/crates/defguard_event_logger/Cargo.toml +++ b/crates/defguard_event_logger/Cargo.toml @@ -11,6 +11,7 @@ rust-version.workspace = true # internal crates defguard_common.workspace = true defguard_core.workspace = true +defguard_session_manager.workspace = true # external dependencies bytes.workspace = true diff --git a/crates/defguard_event_logger/src/message.rs b/crates/defguard_event_logger/src/message.rs index 1befe676c0..3d736281eb 100644 --- a/crates/defguard_event_logger/src/message.rs +++ b/crates/defguard_event_logger/src/message.rs @@ -19,6 +19,7 @@ use defguard_core::{ InternalEventContext, }, }; +use defguard_session_manager::events::SessionManagerEventContext; /// Messages that can be sent to the event logger pub struct EventLoggerMessage { @@ -58,7 +59,7 @@ impl EventContext { ) -> Self { let location = location.map(|location| location.name); - EventContext { + Self { timestamp: val.timestamp, user_id: val.user_id, username: val.username, @@ -75,7 +76,7 @@ impl EventContext { ) -> Self { let location = location.map(|location| location.name); - EventContext { + Self { timestamp: val.timestamp, user_id: val.user_id, username: val.username, @@ -92,7 +93,7 @@ impl EventContext { ) -> Self { let location = location.map(|location| location.name); - EventContext { + Self { timestamp: val.timestamp, user_id: val.user_id, username: val.username, @@ -101,6 +102,18 @@ impl EventContext { device: format!("{} (ID {})", val.device.name, val.device.id), } } + + #[must_use] + pub fn from_session_manager_context(val: SessionManagerEventContext) -> Self { + Self { + timestamp: val.timestamp, + user_id: val.user.id, + username: val.user.username, + location: Some(val.location.name), + ip: val.public_ip, + device: format!("{} (ID {})", val.device.name, val.device.id), + } + } } impl From for EventContext { diff --git a/crates/defguard_event_router/Cargo.toml b/crates/defguard_event_router/Cargo.toml index c4c9b3c150..389ddaf36e 100644 --- a/crates/defguard_event_router/Cargo.toml +++ b/crates/defguard_event_router/Cargo.toml @@ -12,6 +12,7 @@ rust-version.workspace = true defguard_core = { workspace = true } defguard_event_logger = { workspace = true } defguard_mail = { workspace = true } +defguard_session_manager = { workspace = true } # external dependencies thiserror = { workspace = true } diff --git a/crates/defguard_event_router/src/error.rs b/crates/defguard_event_router/src/error.rs index e8b86c277f..5dda3a09e7 100644 --- a/crates/defguard_event_router/src/error.rs +++ b/crates/defguard_event_router/src/error.rs @@ -4,8 +4,6 @@ use thiserror::Error; pub enum EventRouterError { #[error("API event channel closed")] ApiEventChannelClosed, - #[error("gRPC event channel closed")] - GrpcEventChannelClosed, #[error("Bidi gRPC stream event channel closed")] BidiEventChannelClosed, #[error("Internal event channel closed")] diff --git a/crates/defguard_event_router/src/events.rs b/crates/defguard_event_router/src/events.rs index d911b82586..ccbf7c794c 100644 --- a/crates/defguard_event_router/src/events.rs +++ b/crates/defguard_event_router/src/events.rs @@ -1,12 +1,14 @@ -use defguard_core::events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}; +use defguard_core::events::{ApiEvent, BidiStreamEvent, InternalEvent}; +use defguard_session_manager::events::SessionManagerEvent; /// Enum representing all possible events that can be generated in the system. /// /// System components can send events to the event router through their own event channels. /// The enum itself is organized based on event source to make splitting logic into smaller chunks easier. +#[derive(Debug)] pub enum Event { Api(ApiEvent), - Grpc(Box), Bidi(BidiStreamEvent), Internal(Box), + SessionManager(Box), } diff --git a/crates/defguard_event_router/src/handlers/grpc.rs b/crates/defguard_event_router/src/handlers/grpc.rs deleted file mode 100644 index 1fda6b4441..0000000000 --- a/crates/defguard_event_router/src/handlers/grpc.rs +++ /dev/null @@ -1,41 +0,0 @@ -use defguard_core::events::GrpcEvent; -use defguard_event_logger::message::{LoggerEvent, VpnEvent}; -use tracing::debug; - -use crate::{EventRouter, error::EventRouterError}; - -impl EventRouter { - pub(crate) fn handle_grpc_event(&self, event: GrpcEvent) -> Result<(), EventRouterError> { - debug!("Processing gRPC server event: {event:?}"); - - match event { - GrpcEvent::GatewayConnected { location: _ } => todo!(), - GrpcEvent::GatewayDisconnected { location: _ } => todo!(), - GrpcEvent::ClientConnected { - context, - location, - device, - } => { - self.log_event( - context.into(), - LoggerEvent::Vpn(Box::new(VpnEvent::ConnectedToLocation { location, device })), - )?; - } - GrpcEvent::ClientDisconnected { - context, - location, - device, - } => { - self.log_event( - context.into(), - LoggerEvent::Vpn(Box::new(VpnEvent::DisconnectedFromLocation { - location, - device, - })), - )?; - } - } - - Ok(()) - } -} diff --git a/crates/defguard_event_router/src/handlers/mod.rs b/crates/defguard_event_router/src/handlers/mod.rs index 133a412811..0c40a83878 100644 --- a/crates/defguard_event_router/src/handlers/mod.rs +++ b/crates/defguard_event_router/src/handlers/mod.rs @@ -1,4 +1,4 @@ pub(crate) mod api; pub(crate) mod bidi; -pub(crate) mod grpc; pub(crate) mod internal; +pub(crate) mod session_manager; diff --git a/crates/defguard_event_router/src/handlers/session_manager.rs b/crates/defguard_event_router/src/handlers/session_manager.rs new file mode 100644 index 0000000000..e818c4d0f4 --- /dev/null +++ b/crates/defguard_event_router/src/handlers/session_manager.rs @@ -0,0 +1,38 @@ +use defguard_event_logger::message::{EventContext, LoggerEvent, VpnEvent}; +use defguard_session_manager::events::{SessionManagerEvent, SessionManagerEventType}; +use tracing::debug; + +use crate::{EventRouter, error::EventRouterError}; + +impl EventRouter { + pub(crate) fn handle_session_manager_event( + &self, + event: SessionManagerEvent, + ) -> Result<(), EventRouterError> { + debug!("Processing session manager event: {event:?}"); + + let SessionManagerEvent { context, event } = event; + + // FIXME: consider if we actually need this as part of event since we have the context anyway + let location = context.location.clone(); + let device = context.device.clone(); + + let logger_event = match event { + SessionManagerEventType::ClientConnected => { + LoggerEvent::Vpn(Box::new(VpnEvent::ConnectedToLocation { location, device })) + } + SessionManagerEventType::ClientDisconnected => { + LoggerEvent::Vpn(Box::new(VpnEvent::DisconnectedFromLocation { + location, + device, + })) + } + SessionManagerEventType::MfaClientConnected => todo!(), + SessionManagerEventType::MfaClientDisconnected => todo!(), + }; + self.log_event( + EventContext::from_session_manager_context(context), + logger_event, + ) + } +} diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index 636d41de63..132f56d754 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -20,11 +20,12 @@ use std::sync::Arc; use defguard_core::{ - events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, + events::{ApiEvent, BidiStreamEvent, InternalEvent}, grpc::gateway::events::GatewayEvent, }; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use defguard_mail::Mail; +use defguard_session_manager::events::SessionManagerEvent; use error::EventRouterError; use events::Event; use tokio::sync::{ @@ -40,24 +41,24 @@ mod handlers; pub struct RouterReceiverSet { api: UnboundedReceiver, - grpc: UnboundedReceiver, bidi: UnboundedReceiver, internal: UnboundedReceiver, + session_manager: UnboundedReceiver, } impl RouterReceiverSet { #[must_use] pub fn new( api: UnboundedReceiver, - grpc: UnboundedReceiver, bidi: UnboundedReceiver, internal: UnboundedReceiver, + session_manager: UnboundedReceiver, ) -> Self { Self { api, - grpc, bidi, internal, + session_manager, } } } @@ -115,10 +116,6 @@ impl EventRouter { error!("API event channel closed"); return Err(EventRouterError::ApiEventChannelClosed); }, - event = self.receivers.grpc.recv() => if let Some(grpc_event) = event { Event::Grpc(Box::new(grpc_event)) } else { - error!("gRPC event channel closed"); - return Err(EventRouterError::GrpcEventChannelClosed); - }, event = self.receivers.bidi.recv() => if let Some(bidi_event) = event { Event::Bidi(bidi_event) } else { error!("Bidi gRPC stream event channel closed"); return Err(EventRouterError::BidiEventChannelClosed); @@ -127,16 +124,22 @@ impl EventRouter { error!("Internal event channel closed"); return Err(EventRouterError::InternalEventChannelClosed); }, + event = self.receivers.session_manager.recv() => if let Some(session_manager_event) = event { Event::SessionManager(Box::new(session_manager_event)) } else { + error!("Internal event channel closed"); + return Err(EventRouterError::InternalEventChannelClosed); + }, }; - debug!("Received event"); + debug!("Received event: {event:?}"); // Route the event to the appropriate handler match event { Event::Api(api_event) => self.handle_api_event(api_event)?, - Event::Grpc(grpc_event) => self.handle_grpc_event(*grpc_event)?, Event::Bidi(bidi_event) => self.handle_bidi_event(bidi_event)?, Event::Internal(internal_event) => self.handle_internal_event(*internal_event)?, + Event::SessionManager(session_manager_event) => { + self.handle_session_manager_event(*session_manager_event)? + } } } } diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 95bf4f5496..e7fa251966 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,5 +1,8 @@ use defguard_common::db::Id; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; + +use crate::events::SessionManagerEvent; #[derive(Debug, Error)] pub enum SessionManagerError { @@ -21,4 +24,12 @@ pub enum SessionManagerError { LocationDoesNotExistError(Id), #[error("Received out of order peer stats update")] PeerStatsUpdateOutOfOrderError, + #[error("Failed to send session manager event: {0}")] + SessionManagerEventError(Box>), +} + +impl From> for SessionManagerError { + fn from(error: SendError) -> Self { + Self::SessionManagerEventError(Box::new(error)) + } } diff --git a/crates/defguard_session_manager/src/events.rs b/crates/defguard_session_manager/src/events.rs new file mode 100644 index 0000000000..e22495fa41 --- /dev/null +++ b/crates/defguard_session_manager/src/events.rs @@ -0,0 +1,30 @@ +use std::net::IpAddr; + +use chrono::NaiveDateTime; +use defguard_common::db::{ + Id, + models::{Device, User, WireguardNetwork}, +}; + +#[derive(Debug)] +pub struct SessionManagerEvent { + pub context: SessionManagerEventContext, + pub event: SessionManagerEventType, +} + +#[derive(Debug)] +pub struct SessionManagerEventContext { + pub timestamp: NaiveDateTime, + pub location: WireguardNetwork, + pub user: User, + pub device: Device, + pub public_ip: IpAddr, +} + +#[derive(Debug)] +pub enum SessionManagerEventType { + ClientConnected, + ClientDisconnected, + MfaClientConnected, + MfaClientDisconnected, +} diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 93ee0ce504..31563e1f6c 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -1,21 +1,28 @@ +use std::net::{IpAddr, Ipv4Addr}; + use chrono::Utc; use defguard_common::{ db::{ Id, - models::{WireguardNetwork, vpn_client_session::VpnClientSession}, + models::{Device, User, WireguardNetwork, vpn_client_session::VpnClientSession}, }, messages::peer_stats_update::PeerStatsUpdate, }; use sqlx::{PgConnection, PgPool}; use tokio::{ - sync::mpsc::UnboundedReceiver, + sync::mpsc::{UnboundedReceiver, UnboundedSender}, time::{Duration, interval}, }; use tracing::{debug, error, info, trace, warn}; -use crate::{error::SessionManagerError, session_state::ActiveSessionsMap}; +use crate::{ + error::SessionManagerError, + events::{SessionManagerEvent, SessionManagerEventContext, SessionManagerEventType}, + session_state::ActiveSessionsMap, +}; pub mod error; +pub mod events; pub mod session_state; const MESSAGE_LIMIT: usize = 100; @@ -24,12 +31,13 @@ const SESSION_UPDATE_INTERVAL: u64 = 60; pub async fn run_session_manager( pool: PgPool, mut peer_stats_rx: UnboundedReceiver, + session_manager_event_tx: UnboundedSender, ) -> Result<(), SessionManagerError> { info!("Starting VPN client session manager service"); let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); // initialize session manager - let mut session_manager = SessionManager::new(pool).await?; + let mut session_manager = SessionManager::new(pool, session_manager_event_tx); loop { // receive next batch of peer stats messages @@ -60,18 +68,15 @@ pub async fn run_session_manager( struct SessionManager { pool: PgPool, - // active_sessions: LocationSessionsMap, + session_manager_event_tx: UnboundedSender, } impl SessionManager { - async fn new(pool: PgPool) -> Result { - // initialize active sessions state based on DB content - // let active_sessions = LocationSessionsMap::initialize_from_db(&pool).await?; - - Ok(Self { + fn new(pool: PgPool, session_manager_event_tx: UnboundedSender) -> Self { + Self { pool, - // active_sessions, - }) + session_manager_event_tx, + } } /// Helper function for processing all messages read from the channel in a single batch @@ -129,7 +134,7 @@ impl SessionManager { message.device_id, message.location_id ); active_sessions - .try_add_new_session(transaction, &message) + .try_add_new_session(transaction, &message, &self.session_manager_event_tx) .await? } }; @@ -177,7 +182,8 @@ impl SessionManager { "Disconnecting inactive session for user {}, device {} in location {location}", session.user_id, session.device_id ); - Self::disconnect_session(&mut transaction, session).await?; + self.disconnect_session(&mut transaction, session, &location) + .await?; } // get all sessions which were created but have never connected @@ -195,7 +201,8 @@ impl SessionManager { "Disconnecting never connected session for user {}, device {} in location {location}", session.user_id, session.device_id ); - Self::disconnect_session(&mut transaction, session).await?; + self.disconnect_session(&mut transaction, session, &location) + .await?; } } @@ -209,13 +216,44 @@ impl SessionManager { /// Helper user to mark session as disconnected and trigger necessary sideffects async fn disconnect_session( + &self, transaction: &mut PgConnection, mut session: VpnClientSession, + location: &WireguardNetwork, ) -> Result<(), SessionManagerError> { - session.disconnected_at = Some(Utc::now().naive_utc()); + let disconnect_timestamp = Utc::now().naive_utc(); + + // update session record in DB + session.disconnected_at = Some(disconnect_timestamp); session.state = defguard_common::db::models::vpn_client_session::VpnClientSessionState::Disconnected; session.save(&mut *transaction).await?; + + // fetch related objects necessary for event context + let user = User::find_by_id(&mut *transaction, session.user_id) + .await? + .ok_or(SessionManagerError::UserDoesNotExistError(session.user_id))?; + let device = Device::find_by_id(&mut *transaction, session.device_id) + .await? + .ok_or(SessionManagerError::DeviceDoesNotExistError( + session.device_id, + ))?; + + // emit event + let context = SessionManagerEventContext { + timestamp: disconnect_timestamp, + location: location.clone(), + user, + device, + // FIXME: this is a workaround since we require an IP for each audit log event + public_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + }; + let event = SessionManagerEvent { + context, + event: SessionManagerEventType::ClientDisconnected, + }; + self.session_manager_event_tx.send(event)?; + Ok(()) } } diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 66981e08ff..c11815c5cf 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -12,9 +12,13 @@ use defguard_common::{ messages::peer_stats_update::PeerStatsUpdate, }; use sqlx::{PgConnection, types::chrono::Utc}; +use tokio::sync::mpsc::UnboundedSender; use tracing::{debug, warn}; -use crate::error::SessionManagerError; +use crate::{ + error::SessionManagerError, + events::{SessionManagerEvent, SessionManagerEventContext, SessionManagerEventType}, +}; struct LastStatsUpdate { collected_at: NaiveDateTime, @@ -214,32 +218,36 @@ impl ActiveSessionsMap { &mut self, transaction: &mut PgConnection, stats_update: &PeerStatsUpdate, + event_tx: &UnboundedSender, ) -> Result, SessionManagerError> { // fetch location let location_id = stats_update.location_id; - // wrap in block to avoid multiple mutable borrows - let (location_name, mfa_mode) = { - let location = self.get_location(&mut *transaction, location_id).await?; - // check if a given peer is considered active and should be added to active sessions - if Utc::now().naive_utc() - stats_update.latest_handshake - > TimeDelta::seconds(location.peer_disconnect_threshold.into()) - { - warn!( - "Received peer stats update for an inactive peer. Skipping creating a new session..." - ); - return Ok(None); - }; - (location.name.clone(), location.location_mfa_mode.clone()) + let location = self + .get_location(&mut *transaction, location_id) + .await? + .clone(); + + // check if a given peer is considered active and should be added to active sessions + if Utc::now().naive_utc() - stats_update.latest_handshake + > TimeDelta::seconds(location.peer_disconnect_threshold.into()) + { + warn!( + "Received peer stats update for an inactive peer. Skipping creating a new session..." + ); + return Ok(None); }; // fetch other related objects from DB + // clone them because we'll need those for event context let device_id = stats_update.device_id; - // wrap in block to avoid multiple mutable borrows - let user_id = { self.get_device(&mut *transaction, device_id).await?.user_id }; - let user = self.get_user(&mut *transaction, user_id).await?; + let device = self.get_device(&mut *transaction, device_id).await?.clone(); + let user = self + .get_user(&mut *transaction, device.user_id) + .await? + .clone(); - debug!("Adding new VPN client session for location {location_name}"); + debug!("Adding new VPN client session for location {location}"); // create a client session object and save it to DB let session = VpnClientSession::new( @@ -247,7 +255,7 @@ impl ActiveSessionsMap { user.id, device_id, Some(stats_update.latest_handshake), - mfa_mode, + location.location_mfa_mode.clone(), ) .save(transaction) .await?; @@ -255,10 +263,26 @@ impl ActiveSessionsMap { // add to session map let session_state = SessionState::new(session.id); let session_map = self.get_or_create_location_session_map(location_id); - let maybe_existing_session = session_map.insert(device_id, session_state); + let maybe_existing_session = session_map.insert(device.id, session_state); + // if a session exists already there was an error in earlier logic assert!(maybe_existing_session.is_none()); + // emit event + let public_ip = stats_update.endpoint.ip(); + let context = SessionManagerEventContext { + timestamp: stats_update.latest_handshake, + location, + user, + device, + public_ip, + }; + let event = SessionManagerEvent { + context, + event: SessionManagerEventType::ClientConnected, + }; + event_tx.send(event)?; + Ok(session_map.0.get_mut(&device_id)) } diff --git a/flake.lock b/flake.lock index 023f36f5b1..96f53666ca 100644 --- a/flake.lock +++ b/flake.lock @@ -32,11 +32,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1768564909, - "narHash": "sha256-Kell/SpJYVkHWMvnhqJz/8DqQg2b6PguxVWOuadbHCc=", + "lastModified": 1769170682, + "narHash": "sha256-oMmN1lVQU0F0W2k6OI3bgdzp2YOHWYUAw79qzDSjenU=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "e4bae1bd10c9c57b2cf517953ab70060a828ee6f", + "rev": "c5296fdd05cfa2c187990dd909864da9658df755", "type": "github" }, "original": { @@ -74,11 +74,11 @@ ] }, "locked": { - "lastModified": 1768877311, - "narHash": "sha256-abSDl0cNr0B+YCsIDpO1SjXD9JMxE4s8EFnhLEFVovI=", + "lastModified": 1769396217, + "narHash": "sha256-YNzh46h8fby49yOIB40lNoQ9ucVoXe1bHVwkZ4AwGe0=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "59e4ab96304585fde3890025fd59bd2717985cc1", + "rev": "e9bcd12156a577ac4e47d131c14dc0293cc9c8c2", "type": "github" }, "original": {