diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index b38f497619..b4287f01d0 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -30,7 +30,7 @@ use defguard_core::{ }; use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; use defguard_event_router::{RouterReceiverSet, run_event_router}; -use defguard_gateway_manager::GatewayManager; +use defguard_gateway_manager::{GatewayManager, GatewayTxSet}; use defguard_proxy_manager::{ProxyManager, ProxyTxSet}; use defguard_session_manager::{events::SessionManagerEvent, run_session_manager}; use defguard_setup::setup::run_setup_web_server; @@ -172,24 +172,22 @@ async fn main() -> Result<(), anyhow::Error> { } let (proxy_control_tx, proxy_control_rx) = channel::(100); - let proxy_tx = ProxyTxSet::new(gateway_tx.clone(), bidi_event_tx.clone()); let proxy_manager = ProxyManager::new( pool.clone(), - proxy_tx, + ProxyTxSet::new(gateway_tx.clone(), bidi_event_tx.clone()), Arc::clone(&incompatible_components), proxy_control_rx, ); - let mut gateway_manager = GatewayManager::default(); + let mut gateway_manager = GatewayManager::new( + pool.clone(), + GatewayTxSet::new(gateway_tx.clone(), peer_stats_tx), + ); // run services tokio::select! { res = proxy_manager.run() => error!("ProxyManager returned early: {res:?}"), - res = gateway_manager.run( - pool.clone(), - gateway_tx.clone(), - peer_stats_tx, - ) => error!("Gateway gRPC stream returned early: {res:?}"), + res = gateway_manager.run() => error!("GatewayManager returned early: {res:?}"), res = run_grpc_server( Arc::clone(&worker_state), pool.clone(), diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index b100560eb1..5cd8755aef 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -15,12 +15,12 @@ use defguard_proto::gateway::gateway_client::GatewayClient; use defguard_version::client::ClientVersionInterceptor; use sqlx::{PgPool, postgres::PgListener}; use tokio::{ - sync::{broadcast::Sender, mpsc::UnboundedSender}, + sync::{broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}, task::{AbortHandle, JoinSet}, }; use tonic::{Request, service::interceptor::InterceptedService, transport::Channel}; -use crate::handler::GatewayHandler; +use crate::{error::GatewayError, handler::GatewayHandler}; #[macro_use] extern crate tracing; @@ -39,22 +39,29 @@ const TEN_SECS: Duration = Duration::from_secs(10); type Client = GatewayClient>; -#[derive(Default)] pub struct GatewayManager { clients: Arc>>, + pool: PgPool, + handlers: JoinSet>, + tx: GatewayTxSet, } impl GatewayManager { + #[must_use] + pub fn new(pool: PgPool, tx: GatewayTxSet) -> Self { + Self { + clients: Arc::default(), + handlers: JoinSet::new(), + pool, + tx, + } + } + /// Bi-directional gRPC stream for communication with Defguard Gateway. - pub async fn run( - &mut self, - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, - ) -> Result<(), anyhow::Error> { + pub async fn run(&mut self) -> Result<(), anyhow::Error> { let (certs_tx, certs_rx) = tokio::sync::watch::channel(Arc::new(HashMap::new())); - certs::refresh_certs(&pool, &certs_tx).await; - let refresh_pool = pool.clone(); + certs::refresh_certs(&self.pool, &certs_tx).await; + let refresh_pool = self.pool.clone(); tokio::spawn(async move { loop { certs::refresh_certs(&refresh_pool, &certs_tx).await; @@ -62,41 +69,15 @@ impl GatewayManager { } }); let mut abort_handles = HashMap::new(); - - let mut tasks = JoinSet::new(); - // Helper closure to launch `GatewayHandler`. - // TODO: Store arguments in GatewayManager and rewrite this to method - let mut launch_gateway_handler = |gateway: Gateway, - clients: Arc>>| - -> Result { - let mut gateway_handler = GatewayHandler::new( - gateway, - pool.clone(), - events_tx.clone(), - peer_stats_tx.clone(), - certs_rx.clone(), - )?; - let abort_handle = tasks.spawn(async move { - loop { - if let Err(err) = gateway_handler - .handle_connection(Arc::clone(&clients)) - .await - { - error!("Gateway connection error: {err}, retrying in 5 seconds..."); - tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await; - } - } - }); - Ok(abort_handle) - }; - for gateway in Gateway::all(&pool).await? { + for gateway in Gateway::all(&self.pool).await? { let id = gateway.id; - let abort_handle = launch_gateway_handler(gateway, Arc::clone(&self.clients))?; + let abort_handle = + self.run_handler(gateway, Arc::clone(&self.clients), certs_rx.clone())?; abort_handles.insert(id, abort_handle); } // Observe gateway URL changes. - let mut listener = PgListener::connect_with(&pool).await?; + let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen(GATEWAY_TABLE_TRIGGER).await?; while let Ok(notification) = listener.recv().await { let payload = notification.payload(); @@ -106,7 +87,7 @@ impl GatewayManager { if let Some(new) = gateway_notification.new { let id = new.id; let abort_handle = - launch_gateway_handler(new, Arc::clone(&self.clients))?; + self.run_handler(new, Arc::clone(&self.clients), certs_rx.clone())?; abort_handles.insert(id, abort_handle); } } @@ -124,8 +105,11 @@ impl GatewayManager { ); abort_handle.abort(); let id = new.id; - let abort_handle = - launch_gateway_handler(new, Arc::clone(&self.clients))?; + let abort_handle = self.run_handler( + new, + Arc::clone(&self.clients), + certs_rx.clone(), + )?; abort_handles.insert(id, abort_handle); } else { warn!("Cannot find {old} on the list of connected gateways"); @@ -173,10 +157,55 @@ impl GatewayManager { } } - while let Some(Ok(_result)) = tasks.join_next().await { + while let Some(Ok(_result)) = self.handlers.join_next().await { debug!("Gateway gRPC task has ended"); } Ok(()) } + + fn run_handler( + &mut self, + gateway: Gateway, + clients: Arc>>, + certs_rx: Receiver>>, + ) -> Result { + let mut gateway_handler = GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx.clone(), + )?; + let abort_handle = self.handlers.spawn(async move { + loop { + if let Err(err) = gateway_handler + .handle_connection(Arc::clone(&clients)) + .await + { + error!("Gateway connection error: {err}, retrying in 5 seconds..."); + tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await; + } + } + }); + Ok(abort_handle) + } +} + +/// Shared set of outbound channels that gateway instances use to forward +/// events, notifications, and side effects to Core components. +#[derive(Clone)] +pub struct GatewayTxSet { + events: Sender, + peer_stats: UnboundedSender, +} + +impl GatewayTxSet { + #[must_use] + pub const fn new( + events: Sender, + peer_stats: UnboundedSender, + ) -> Self { + Self { events, peer_stats } + } } diff --git a/migrations/20260218054705_[2.0.0]_gateway_cascade_delete.down.sql b/migrations/20260218054705_[2.0.0]_gateway_cascade_delete.down.sql new file mode 100644 index 0000000000..d438bcd26b --- /dev/null +++ b/migrations/20260218054705_[2.0.0]_gateway_cascade_delete.down.sql @@ -0,0 +1,7 @@ +ALTER TABLE gateway +DROP CONSTRAINT gateway_network_id_fkey; + +ALTER TABLE gateway +ADD CONSTRAINT gateway_network_id_fkey +FOREIGN KEY (network_id) +REFERENCES wireguard_network(id); diff --git a/migrations/20260218054705_[2.0.0]_gateway_cascade_delete.up.sql b/migrations/20260218054705_[2.0.0]_gateway_cascade_delete.up.sql new file mode 100644 index 0000000000..bf0b321e3b --- /dev/null +++ b/migrations/20260218054705_[2.0.0]_gateway_cascade_delete.up.sql @@ -0,0 +1,8 @@ +ALTER TABLE gateway +DROP CONSTRAINT gateway_network_id_fkey; + +ALTER TABLE gateway +ADD CONSTRAINT gateway_network_id_fkey +FOREIGN KEY (network_id) +REFERENCES wireguard_network(id) +ON DELETE CASCADE;