diff --git a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs index fc975e5425..e23c402c20 100644 --- a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs +++ b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs @@ -42,7 +42,13 @@ impl ClientMfaServer { let pubkey = Self::parse_token(&token)?; // fetch login session - let Some(session) = self.sessions.get(&pubkey).cloned() else { + let Some(session) = self + .sessions + .read() + .expect("Failed to read-lock ClientMfaServer::sessions") + .get(&pubkey) + .cloned() + else { debug!("Client login session not found"); return Err(Status::invalid_argument("login session not found")); }; @@ -62,7 +68,10 @@ impl ClientMfaServer { if method != MfaMethod::Oidc { debug!("Invalid MFA method for OIDC authentication: {method:?}"); - self.sessions.remove(&pubkey); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .remove(&pubkey); return Err(Status::invalid_argument("invalid MFA method")); } @@ -81,7 +90,10 @@ impl ClientMfaServer { }) { Ok(url) => url, Err(status) => { - self.sessions.remove(&pubkey); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .remove(&pubkey); self.emit_event(BidiStreamEvent { context, event: BidiStreamEventType::DesktopClientMfa(Box::new( @@ -102,7 +114,10 @@ impl ClientMfaServer { // if thats not our user, prevent login if claims_user.id != user.id { info!("User {claims_user} tried to use OIDC MFA for another user: {user}"); - self.sessions.remove(&pubkey); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .remove(&pubkey); self.emit_event(BidiStreamEvent { context, event: BidiStreamEventType::DesktopClientMfa(Box::new( @@ -123,7 +138,10 @@ impl ClientMfaServer { } Err(err) => { info!("Failed to verify OIDC code: {err}"); - self.sessions.remove(&pubkey); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .remove(&pubkey); self.emit_event(BidiStreamEvent { context, event: BidiStreamEventType::DesktopClientMfa(Box::new( @@ -139,17 +157,20 @@ impl ClientMfaServer { } } - self.sessions.insert( - pubkey.clone(), - ClientLoginSession { - method, - device: device.clone(), - location: location.clone(), - user: user.clone(), - openid_auth_completed: true, - biometric_challenge: None, - }, - ); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .insert( + pubkey.clone(), + ClientLoginSession { + method, + device: device.clone(), + location: location.clone(), + user: user.clone(), + openid_auth_completed: true, + biometric_challenge: None, + }, + ); Ok(()) } diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 6d3432af46..667314eaf4 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -1,4 +1,8 @@ -use std::collections::HashMap; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + time::Duration, +}; use chrono::Utc; use defguard_common::{ @@ -15,15 +19,20 @@ use defguard_common::{ }; use defguard_mail::Mail; use defguard_proto::proxy::{ - self, ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, - ClientMfaStartResponse, ClientMfaTokenValidationRequest, ClientMfaTokenValidationResponse, - MfaMethod, + self, AwaitRemoteMfaFinishRequest, AwaitRemoteMfaFinishResponse, ClientMfaFinishRequest, + ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, + ClientMfaTokenValidationRequest, ClientMfaTokenValidationResponse, CoreResponse, MfaMethod, + core_response::Payload, }; use sqlx::PgPool; use thiserror::Error; -use tokio::sync::{ - broadcast::Sender, - mpsc::{UnboundedSender, error::SendError}, +use tokio::{ + sync::{ + broadcast::Sender, + mpsc::{UnboundedSender, error::SendError}, + oneshot, + }, + time, }; use tonic::{Code, Status}; @@ -36,6 +45,9 @@ use crate::{ const CLIENT_SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes +// How much time the user has to approve remote MFA with mobile device +const REMOTE_AUTH_TIMEOUT: Duration = Duration::from_secs(60); + #[derive(Debug, Error)] pub enum ClientMfaServerError { #[error("gRPC event channel error: {0}")] @@ -49,7 +61,7 @@ impl From for Status { } #[derive(Clone)] -pub(crate) struct ClientLoginSession { +pub struct ClientLoginSession { pub(crate) method: MfaMethod, pub(crate) location: WireguardNetwork, pub(crate) device: Device, @@ -62,7 +74,8 @@ pub struct ClientMfaServer { pub(crate) pool: PgPool, mail_tx: UnboundedSender, wireguard_tx: Sender, - pub(crate) sessions: HashMap, + pub(crate) sessions: Arc>>, + remote_mfa_responses: Arc>>>, bidi_event_tx: UnboundedSender, } @@ -73,13 +86,16 @@ impl ClientMfaServer { mail_tx: UnboundedSender, wireguard_tx: Sender, bidi_event_tx: UnboundedSender, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, ) -> Self { Self { pool, mail_tx, wireguard_tx, bidi_event_tx, - sessions: HashMap::new(), + remote_mfa_responses, + sessions, } } @@ -117,7 +133,11 @@ impl ClientMfaServer { request: ClientMfaTokenValidationRequest, ) -> Result { let pubkey = Self::parse_token(&request.token)?; - let session_active = self.sessions.contains_key(&pubkey); + let session_active = self + .sessions + .read() + .expect("Failed to read-lock ClientMfaServer::sessions") + .contains_key(&pubkey); Ok(ClientMfaTokenValidationResponse { token_valid: session_active, }) @@ -312,17 +332,20 @@ impl ClientMfaServer { .map(|challenge| challenge.challenge.clone()); // store login session - self.sessions.insert( - request.pubkey, - ClientLoginSession { - method: selected_method, - location, - device, - user, - openid_auth_completed: false, - biometric_challenge, - }, - ); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .insert( + request.pubkey, + ClientLoginSession { + method: selected_method, + location, + device, + user, + openid_auth_completed: false, + biometric_challenge, + }, + ); Ok(ClientMfaStartResponse { token, @@ -368,6 +391,45 @@ impl ClientMfaServer { Ok(()) } + #[instrument(skip_all)] + pub async fn await_remote_mfa_login( + &mut self, + request: AwaitRemoteMfaFinishRequest, + response_tx: UnboundedSender, + request_id: u64, + ) -> Result<(), Status> { + debug!("Finishing desktop client login: {request:?}"); + let (tx, rx) = oneshot::channel(); + self.remote_mfa_responses + .write() + .expect("Failed to write-lock ClientMfaServer::remote_mfa_responses") + .insert(request.token.clone(), tx); + + // Spawn a task that waits for remote MFA process to conclude to get the preshared key. + tokio::spawn(async move { + match time::timeout(REMOTE_AUTH_TIMEOUT, rx).await { + Ok(Ok(preshared_key)) => { + let req = CoreResponse { + id: request_id, + payload: Some(Payload::AwaitRemoteMfaFinish( + AwaitRemoteMfaFinishResponse { preshared_key }, + )), + }; + // Once the key is here, send it back to proxy. + let _ = response_tx.send(req); + } + Ok(Err(err)) => { + error!("Remote MFA response channel failed: {err:?}"); + } + Err(_) => { + warn!("Remote MFA process with request_id {request_id} timed out"); + } + } + }); + + Ok(()) + } + #[instrument(skip_all)] pub async fn finish_client_mfa_login( &mut self, @@ -379,7 +441,13 @@ impl ClientMfaServer { let pubkey = Self::parse_token(&request.token)?; // fetch login session - let Some(session) = self.sessions.get(&pubkey) else { + let Some(session) = self + .sessions + .read() + .expect("Failed to read-lock ClientMfaServer::sessions") + .get(&pubkey) + .cloned() + else { error!("Client login session not found"); return Err(Status::invalid_argument("login session not found")); }; @@ -436,7 +504,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "Signed challenge rejected".to_string(), }, )), @@ -471,7 +539,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "Signed challenge rejected".to_string(), }, )), @@ -491,7 +559,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "TOTP code not provided in request".to_string(), }, )), @@ -506,7 +574,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "invalid TOTP code".to_string(), }, )), @@ -525,7 +593,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "email MFA code not provided in request".to_string(), }, )), @@ -540,7 +608,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "invalid email MFA code".to_string(), }, )), @@ -549,7 +617,7 @@ impl ClientMfaServer { } } MfaMethod::Oidc => { - if !*openid_auth_completed { + if !openid_auth_completed { debug!( "User {user} tried to finish OIDC MFA login but they haven't completed \ the OIDC authentication yet." @@ -560,7 +628,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Failed { location: location.clone(), device: device.clone(), - method: *method, + method, message: "tried to finish OIDC MFA login but they haven't \ completed OIDC authentication yet" .to_string(), @@ -616,7 +684,7 @@ impl ClientMfaServer { network_info: vec![DeviceNetworkInfo { network_id: location.id, device_wireguard_ips: network_device.wireguard_ips, - preshared_key: network_device.preshared_key, + preshared_key: network_device.preshared_key.clone(), is_authorized: network_device.is_authorized, }], }; @@ -638,7 +706,7 @@ impl ClientMfaServer { DesktopClientMfaEvent::Connected { location: location.clone(), device: device.clone(), - method: *method, + method, }, )), })?; @@ -652,7 +720,10 @@ impl ClientMfaServer { }; // remove login session from map - self.sessions.remove(&pubkey); + self.sessions + .write() + .expect("Failed to write-lock ClientMfaServer::sessions") + .remove(&pubkey); // commit transaction transaction.commit().await.map_err(|_| { @@ -660,6 +731,17 @@ impl ClientMfaServer { Status::internal("unexpected error") })?; + // If there is a desktop client websocket waiting for the preshared key, send it. + if let (Some(tx), Some(ref preshared_key)) = ( + self.remote_mfa_responses + .write() + .expect("Failed to write-lock ClientMfaServer::remote_mfa_responses") + .remove(&request.token), + network_device.preshared_key, + ) { + let _ = tx.send(preshared_key.clone()); + } + Ok(response) } } diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index 921930af93..ddb680e1dd 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -30,7 +30,10 @@ use defguard_core::{ ldap::utils::ldap_update_user_state, }, events::BidiStreamEvent, - grpc::{gateway::events::GatewayEvent, proxy::client_mfa::ClientMfaServer}, + grpc::{ + gateway::events::GatewayEvent, + proxy::client_mfa::{ClientLoginSession, ClientMfaServer}, + }, version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, }; use defguard_mail::Mail; @@ -53,6 +56,7 @@ use tokio::{ Mutex, broadcast::Sender, mpsc::{self, Receiver, UnboundedSender}, + oneshot, }, task::JoinSet, time::sleep, @@ -119,60 +123,6 @@ pub enum ProxyError { ConnectionTimeout(String), } -/// Maintains routing state for proxy-specific responses by associating -/// correlation tokens with the proxy senders that should receive them. -#[derive(Default)] -struct ProxyRouter { - response_map: HashMap>>, -} - -impl ProxyRouter { - /// Records the proxy sender associated with a request that expects a routed response. - pub(crate) fn register_request( - &mut self, - request: &CoreRequest, - sender: &UnboundedSender, - ) { - match &request.payload { - // Mobile-assisted MFA completion responses must go to the proxy that owns the WebSocket - // so it can send the preshared key. - // Corresponds to the `core_response::Payload::ClientMfaFinish(response)` response. - // https://github.com/DefGuard/defguard/issues/1700 - Some(core_request::Payload::ClientMfaTokenValidation(request)) => { - self.response_map - .insert(request.token.clone(), vec![sender.clone()]); - } - Some(core_request::Payload::ClientMfaFinish(request)) => { - if let Some(senders) = self.response_map.get_mut(&request.token) { - senders.push(sender.clone()); - } - } - _ => {} - } - } - - /// Determines whether the given `CoreResponse` must be routed to a specific proxy instance. - pub(crate) fn route_response( - &mut self, - response: &CoreResponse, - ) -> Option>> { - #[allow(clippy::single_match)] - match &response.payload { - // Mobile-assisted MFA completion responses must go to the proxy that owns the WebSocket - // so it can send the preshared key. - // Corresponds to the `core_request::Payload::ClientMfaTokenValidation(request)` request. - // https://github.com/DefGuard/defguard/issues/1700 - Some(core_response::Payload::ClientMfaFinish(response)) => { - if let Some(ref token) = response.token { - return self.response_map.remove(token); - } - } - _ => {} - } - None - } -} - /// Coordinates communication between the Core and multiple proxy instances. /// /// Responsibilities include: @@ -183,7 +133,6 @@ pub struct ProxyManager { pool: PgPool, tx: ProxyTxSet, incompatible_components: Arc>, - router: Arc>, proxy_control: Receiver, } @@ -198,7 +147,6 @@ impl ProxyManager { pool, tx, incompatible_components, - router: Arc::default(), proxy_control: proxy_control_rx, } } @@ -209,6 +157,8 @@ impl ProxyManager { /// such as routing state and compatibility tracking. pub async fn run(mut self) -> Result<(), ProxyError> { debug!("ProxyManager starting"); + let remote_mfa_responses = Arc::default(); + let sessions = Arc::default(); // Retrieve proxies from DB. let mut shutdown_channels = HashMap::new(); let mut proxies: Vec = Proxy::all(&self.pool) @@ -221,7 +171,8 @@ impl ProxyManager { proxy, self.pool.clone(), &self.tx, - Arc::clone(&self.router), + Arc::clone(&remote_mfa_responses), + Arc::clone(&sessions), Arc::new(Mutex::new(Some(shutdown_rx))), ) }) @@ -232,12 +183,12 @@ impl ProxyManager { if let Some(ref url) = server_config().proxy_url { debug!("Adding proxy from cli arg: {url}"); let url = Url::from_str(url)?; - let proxy = ProxyServer::new( self.pool.clone(), url, &self.tx, - Arc::clone(&self.router), + Arc::clone(&remote_mfa_responses), + Arc::clone(&sessions), // Currently we can't shutdown this proxy since it was started via CLI arguments (no ID in DB) // This should be removed when we do a proper import of old proxies Arc::new(Mutex::new(None)), @@ -275,7 +226,8 @@ impl ProxyManager { &proxy_model, self.pool.clone(), &self.tx, - Arc::clone(&self.router), + Arc::clone(&remote_mfa_responses), + Arc::clone(&sessions), Arc::new(Mutex::new(Some(shutdown_rx))), ) { Ok(proxy) => { @@ -345,8 +297,6 @@ struct ProxyServer { pool: PgPool, /// gRPC servers services: ProxyServices, - /// Router shared between proxies and the proxy manager - router: Arc>, /// Proxy server gRPC URL url: Url, shutdown_signal: Arc>>, @@ -357,16 +307,16 @@ impl ProxyServer { pool: PgPool, url: Url, tx: &ProxyTxSet, - router: Arc>, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, shutdown_signal: Arc>>, ) -> Self { // Instantiate gRPC servers. - let services = ProxyServices::new(&pool, tx); + let services = ProxyServices::new(&pool, tx, remote_mfa_responses, sessions); Self { pool, services, - router, url, shutdown_signal, } @@ -376,11 +326,19 @@ impl ProxyServer { proxy: &Proxy, pool: PgPool, tx: &ProxyTxSet, - router: Arc>, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, shutdown_signal: Arc>>, ) -> Result { let url = Url::from_str(&format!("http://{}:{}", proxy.address, proxy.port))?; - Ok(Self::new(pool, url, tx, router, shutdown_signal)) + Ok(Self::new( + pool, + url, + tx, + remote_mfa_responses, + sessions, + shutdown_signal, + )) } fn endpoint(&self, scheme: Scheme) -> Result { @@ -545,10 +503,6 @@ impl ProxyServer { } Ok(Some(received)) => { debug!("Received message from proxy; ID={}", received.id); - self.router - .write() - .unwrap() - .register_request(&received, &tx); let payload = match received.payload { // rpc CodeMfaSetupStart return (CodeMfaSetupStartResponse) Some(core_request::Payload::CodeMfaSetupStart(request)) => { @@ -738,6 +692,21 @@ impl ProxyServer { } } } + // rpc ClientRemoteMfaFinish (ClientRemoteMfaFinishRequest) returns (ClientRemoteMfaFinishResponse) + Some(core_request::Payload::AwaitRemoteMfaFinish(request)) => { + match self + .services + .client_mfa + .await_remote_mfa_login(request, tx.clone(), received.id) + .await + { + Ok(()) => None, + Err(err) => { + error!("Client remote MFA finish error: {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } // rpc ClientMfaFinish (ClientMfaFinishRequest) returns (ClientMfaFinishResponse) Some(core_request::Payload::ClientMfaFinish(request)) => { match self @@ -962,15 +931,11 @@ impl ProxyServer { None => None, }; - let req = CoreResponse { - id: received.id, - payload, - }; - if let Some(txs) = self.router.write().unwrap().route_response(&req) { - for tx in txs { - let _ = tx.send(req.clone()); - } - } else { + if let Some(payload) = payload { + let req = CoreResponse { + id: received.id, + payload: Some(payload), + }; let _ = tx.send(req); } } @@ -1001,7 +966,12 @@ struct ProxyServices { } impl ProxyServices { - pub fn new(pool: &PgPool, tx: &ProxyTxSet) -> Self { + pub fn new( + pool: &PgPool, + tx: &ProxyTxSet, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, + ) -> Self { let enrollment = EnrollmentServer::new( pool.clone(), tx.wireguard.clone(), @@ -1015,6 +985,8 @@ impl ProxyServices { tx.mail.clone(), tx.wireguard.clone(), tx.bidi_events.clone(), + remote_mfa_responses, + sessions, ); let polling = PollingServer::new(pool.clone()); diff --git a/proto b/proto index 4134358160..0b982922c4 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 4134358160e4f819515c9a6e5c014434cfb46d74 +Subproject commit 0b982922c4dab3304a8cb01aed1d8cee806600b7