Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions crates/defguard/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use defguard_core::{
};
use defguard_event_logger::{message::EventLoggerMessage, run_event_logger};
use defguard_event_router::{RouterReceiverSet, run_event_router};
use defguard_gateway_manager::GatewayManager;
use defguard_gateway_manager::{GatewayManager, GatewayTxSet};
use defguard_proxy_manager::{ProxyManager, ProxyTxSet};
use defguard_session_manager::{events::SessionManagerEvent, run_session_manager};
use defguard_setup::setup::run_setup_web_server;
Expand Down Expand Up @@ -172,24 +172,22 @@ async fn main() -> Result<(), anyhow::Error> {
}

let (proxy_control_tx, proxy_control_rx) = channel::<ProxyControlMessage>(100);
let proxy_tx = ProxyTxSet::new(gateway_tx.clone(), bidi_event_tx.clone());
let proxy_manager = ProxyManager::new(
pool.clone(),
proxy_tx,
ProxyTxSet::new(gateway_tx.clone(), bidi_event_tx.clone()),
Arc::clone(&incompatible_components),
proxy_control_rx,
);

let mut gateway_manager = GatewayManager::default();
let mut gateway_manager = GatewayManager::new(
pool.clone(),
GatewayTxSet::new(gateway_tx.clone(), peer_stats_tx),
);

// run services
tokio::select! {
res = proxy_manager.run() => error!("ProxyManager returned early: {res:?}"),
res = gateway_manager.run(
pool.clone(),
gateway_tx.clone(),
peer_stats_tx,
) => error!("Gateway gRPC stream returned early: {res:?}"),
res = gateway_manager.run() => error!("GatewayManager returned early: {res:?}"),
res = run_grpc_server(
Arc::clone(&worker_state),
pool.clone(),
Expand Down
119 changes: 74 additions & 45 deletions crates/defguard_gateway_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use defguard_proto::gateway::gateway_client::GatewayClient;
use defguard_version::client::ClientVersionInterceptor;
use sqlx::{PgPool, postgres::PgListener};
use tokio::{
sync::{broadcast::Sender, mpsc::UnboundedSender},
sync::{broadcast::Sender, mpsc::UnboundedSender, watch::Receiver},
task::{AbortHandle, JoinSet},
};
use tonic::{Request, service::interceptor::InterceptedService, transport::Channel};

use crate::handler::GatewayHandler;
use crate::{error::GatewayError, handler::GatewayHandler};

#[macro_use]
extern crate tracing;
Expand All @@ -39,64 +39,45 @@ const TEN_SECS: Duration = Duration::from_secs(10);

type Client = GatewayClient<InterceptedService<Channel, ClientVersionInterceptor>>;

#[derive(Default)]
pub struct GatewayManager {
clients: Arc<Mutex<HashMap<Id, Client>>>,
pool: PgPool,
handlers: JoinSet<Result<(), GatewayError>>,
tx: GatewayTxSet,
}

impl GatewayManager {
#[must_use]
pub fn new(pool: PgPool, tx: GatewayTxSet) -> Self {
Self {
clients: Arc::default(),
handlers: JoinSet::new(),
pool,
tx,
}
}

/// Bi-directional gRPC stream for communication with Defguard Gateway.
pub async fn run(
&mut self,
pool: PgPool,
events_tx: Sender<GatewayEvent>,
peer_stats_tx: UnboundedSender<PeerStatsUpdate>,
) -> Result<(), anyhow::Error> {
pub async fn run(&mut self) -> Result<(), anyhow::Error> {
let (certs_tx, certs_rx) = tokio::sync::watch::channel(Arc::new(HashMap::new()));
certs::refresh_certs(&pool, &certs_tx).await;
let refresh_pool = pool.clone();
certs::refresh_certs(&self.pool, &certs_tx).await;
let refresh_pool = self.pool.clone();
tokio::spawn(async move {
loop {
certs::refresh_certs(&refresh_pool, &certs_tx).await;
tokio::time::sleep(TEN_SECS).await;
}
});
let mut abort_handles = HashMap::new();

let mut tasks = JoinSet::new();
// Helper closure to launch `GatewayHandler`.
// TODO: Store arguments in GatewayManager and rewrite this to method
let mut launch_gateway_handler = |gateway: Gateway<Id>,
clients: Arc<Mutex<HashMap<Id, Client>>>|
-> Result<AbortHandle, anyhow::Error> {
let mut gateway_handler = GatewayHandler::new(
gateway,
pool.clone(),
events_tx.clone(),
peer_stats_tx.clone(),
certs_rx.clone(),
)?;
let abort_handle = tasks.spawn(async move {
loop {
if let Err(err) = gateway_handler
.handle_connection(Arc::clone(&clients))
.await
{
error!("Gateway connection error: {err}, retrying in 5 seconds...");
tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await;
}
}
});
Ok(abort_handle)
};
for gateway in Gateway::all(&pool).await? {
for gateway in Gateway::all(&self.pool).await? {
let id = gateway.id;
let abort_handle = launch_gateway_handler(gateway, Arc::clone(&self.clients))?;
let abort_handle =
self.run_handler(gateway, Arc::clone(&self.clients), certs_rx.clone())?;
abort_handles.insert(id, abort_handle);
}

// Observe gateway URL changes.
let mut listener = PgListener::connect_with(&pool).await?;
let mut listener = PgListener::connect_with(&self.pool).await?;
listener.listen(GATEWAY_TABLE_TRIGGER).await?;
while let Ok(notification) = listener.recv().await {
let payload = notification.payload();
Expand All @@ -106,7 +87,7 @@ impl GatewayManager {
if let Some(new) = gateway_notification.new {
let id = new.id;
let abort_handle =
launch_gateway_handler(new, Arc::clone(&self.clients))?;
self.run_handler(new, Arc::clone(&self.clients), certs_rx.clone())?;
abort_handles.insert(id, abort_handle);
}
}
Expand All @@ -124,8 +105,11 @@ impl GatewayManager {
);
abort_handle.abort();
let id = new.id;
let abort_handle =
launch_gateway_handler(new, Arc::clone(&self.clients))?;
let abort_handle = self.run_handler(
new,
Arc::clone(&self.clients),
certs_rx.clone(),
)?;
abort_handles.insert(id, abort_handle);
} else {
warn!("Cannot find {old} on the list of connected gateways");
Expand Down Expand Up @@ -173,10 +157,55 @@ impl GatewayManager {
}
}

while let Some(Ok(_result)) = tasks.join_next().await {
while let Some(Ok(_result)) = self.handlers.join_next().await {
debug!("Gateway gRPC task has ended");
}

Ok(())
}

fn run_handler(
&mut self,
gateway: Gateway<Id>,
clients: Arc<Mutex<HashMap<Id, Client>>>,
certs_rx: Receiver<Arc<HashMap<Id, String>>>,
) -> Result<AbortHandle, GatewayError> {
let mut gateway_handler = GatewayHandler::new(
gateway,
self.pool.clone(),
self.tx.events.clone(),
self.tx.peer_stats.clone(),
certs_rx.clone(),
)?;
let abort_handle = self.handlers.spawn(async move {
loop {
if let Err(err) = gateway_handler
.handle_connection(Arc::clone(&clients))
.await
{
error!("Gateway connection error: {err}, retrying in 5 seconds...");
tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await;
}
}
});
Ok(abort_handle)
}
}

/// Shared set of outbound channels that gateway instances use to forward
/// events, notifications, and side effects to Core components.
#[derive(Clone)]
pub struct GatewayTxSet {
events: Sender<GatewayEvent>,
peer_stats: UnboundedSender<PeerStatsUpdate>,
}

impl GatewayTxSet {
#[must_use]
pub const fn new(
events: Sender<GatewayEvent>,
peer_stats: UnboundedSender<PeerStatsUpdate>,
) -> Self {
Self { events, peer_stats }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
ALTER TABLE gateway
DROP CONSTRAINT gateway_network_id_fkey;

ALTER TABLE gateway
ADD CONSTRAINT gateway_network_id_fkey
FOREIGN KEY (network_id)
REFERENCES wireguard_network(id);
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ALTER TABLE gateway
DROP CONSTRAINT gateway_network_id_fkey;

ALTER TABLE gateway
ADD CONSTRAINT gateway_network_id_fkey
FOREIGN KEY (network_id)
REFERENCES wireguard_network(id)
ON DELETE CASCADE;
Loading