From fe1fa6229365eb7d109625dd098fba430129cc5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 13 Mar 2026 16:00:38 +0100 Subject: [PATCH 01/36] add tests for worker grpc service --- Cargo.lock | 1 + crates/defguard_core/Cargo.toml | 1 + crates/defguard_core/src/grpc/mod.rs | 5 +- .../integration/grpc/common/mock_gateway.rs | 159 ---- .../tests/integration/grpc/common/mod.rs | 201 +++--- .../tests/integration/grpc/gateway.rs | 677 ------------------ .../tests/integration/grpc/health.rs | 25 + .../tests/integration/grpc/mod.rs | 3 +- .../tests/integration/grpc/worker.rs | 419 +++++++++++ .../defguard_core/tests/integration/main.rs | 2 +- 10 files changed, 557 insertions(+), 936 deletions(-) delete mode 100644 crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs delete mode 100644 crates/defguard_core/tests/integration/grpc/gateway.rs create mode 100644 crates/defguard_core/tests/integration/grpc/health.rs create mode 100644 crates/defguard_core/tests/integration/grpc/worker.rs diff --git a/Cargo.lock b/Cargo.lock index fc5217cc04..b5a5d69e3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1386,6 +1386,7 @@ dependencies = [ "defguard_web_ui", "futures", "humantime", + "hyper-util", "ipnetwork", "jsonwebkey", "jsonwebtoken", diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 6a4c176b61..119d4a8c7c 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -87,6 +87,7 @@ async-stream = "0.3" [dev-dependencies] claims.workspace = true +hyper-util = "0.1" matches.workspace = true reqwest = { version = "0.12", features = [ "cookies", diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index beb66fd8a1..a4ec3e97aa 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -96,7 +96,7 @@ pub async fn run_grpc_server( Ok(()) } -pub(crate) async fn build_grpc_service_router( +pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, @@ -114,6 +114,9 @@ pub(crate) async fn build_grpc_service_router( health_reporter .set_serving::>() .await; + health_reporter + .set_serving::>() + .await; let router = server .http2_keepalive_interval(Some(TEN_SECS)) diff --git a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs b/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs deleted file mode 100644 index 11bcdafbfd..0000000000 --- a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::time::Duration; - -use defguard_core::grpc::{AUTHORIZATION_HEADER, HOSTNAME_HEADER}; -use defguard_proto::gateway::{ - Configuration, ConfigurationRequest, Update, - -}; -use defguard_version::{Version, client::ClientVersionInterceptor}; -use tokio::{ - sync::mpsc::{UnboundedSender, unbounded_channel}, - task::JoinHandle, - time::timeout, -}; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tonic::{ - Request, Response, Status, Streaming, - metadata::MetadataValue, - service::{Interceptor, InterceptorLayer, interceptor::InterceptedService}, - transport::Channel, -}; -use tower::ServiceBuilder; - -pub(crate) struct MockGateway { - client: GatewayServiceClient< - InterceptedService, ClientVersionInterceptor>, - >, - hostname: Option, - stats_update_thread_handle: Option>, - updates_stream: Option>, -} - -impl Drop for MockGateway { - fn drop(&mut self) { - if let Some(handle) = &self.stats_update_thread_handle { - handle.abort(); - } - } -} - -#[derive(Clone)] -struct AuthInterceptor { - auth_token: Option, - hostname: Option, -} - -impl AuthInterceptor { - pub(crate) fn new(auth_token: Option, hostname: Option) -> Self { - Self { - auth_token, - hostname, - } - } -} - -impl Interceptor for AuthInterceptor { - fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { - // add authorization token - if let Some(token) = &self.auth_token { - request.metadata_mut().insert( - AUTHORIZATION_HEADER, - MetadataValue::try_from(token).expect("failed to convert token into metadata"), - ); - }; - - // add gateway hostname - if let Some(hostname) = &self.hostname { - request.metadata_mut().insert( - HOSTNAME_HEADER, - MetadataValue::try_from(hostname) - .expect("failed to convert hostname into metadata"), - ); - }; - - Ok(request) - } -} - -impl MockGateway { - #[must_use] - pub(crate) async fn new( - client_channel: Channel, - version: Version, - auth_token: Option, - hostname: Option, - ) -> Self { - let intercepted_channel = ServiceBuilder::new() - .layer(InterceptorLayer::new(ClientVersionInterceptor::new( - version, - ))) - .layer(InterceptorLayer::new(AuthInterceptor::new( - auth_token, - hostname.clone(), - ))) - .service(client_channel); - - let client = GatewayServiceClient::new(intercepted_channel); - - Self { - client, - hostname, - stats_update_thread_handle: None, - updates_stream: None, - } - } - - // Fetch gateway config from core - pub(crate) async fn get_gateway_config(&mut self) -> Result, Status> { - let request = Request::new(ConfigurationRequest { - name: self.hostname.clone(), - }); - - self.client.config(request).await - } - - pub(crate) async fn connect_to_updates_stream(&mut self) { - let request = Request::new(()); - - let updates_stream = self.client.updates(request).await.unwrap().into_inner(); - - self.updates_stream = Some(updates_stream); - } - - pub(crate) fn disconnect_from_updates_stream(&mut self) { - self.updates_stream = None; - } - - #[must_use] - pub(crate) async fn receive_next_update(&mut self) -> Option { - match &mut self.updates_stream { - Some(stream) => match timeout(Duration::from_millis(100), stream.message()).await { - Ok(result) => result.expect("failed to reveive update message"), - Err(_) => None, - }, - None => None, - } - } - - // Connect to interface stats update endpoint - // and return a tx which can be used to send stats updates to test gRPC server - #[must_use] - pub(crate) async fn setup_stats_update_stream(&mut self) -> UnboundedSender { - let (tx, rx) = unbounded_channel(); - - let request = Request::new(UnboundedReceiverStream::new(rx)); - - let mut client = self.client.clone(); - let task_handle = tokio::spawn(async move { - client.stats(request).await.expect("stats stream closed"); - }); - - self.stats_update_thread_handle = Some(task_handle); - - tx - } - - pub(crate) fn hostname(&self) -> String { - self.hostname.clone().unwrap_or_default() - } -} diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index d5d3794685..ee3fe24645 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -1,45 +1,38 @@ -use std::sync::{Arc, Mutex}; +use std::{ + env, + sync::{Arc, Mutex, Once}, +}; -use axum::http::Uri; use defguard_common::{ - db::models::settings::initialize_current_settings, messages::peer_stats_update::PeerStatsUpdate, + auth::claims::{ + AUTH_SECRET_ENV, Claims, ClaimsType, GATEWAY_SECRET_ENV, YUBIBRIDGE_SECRET_ENV, + }, + db::setup_pool, }; use defguard_core::{ auth::failed_login::FailedLoginMap, db::AppEvent, - enterprise::license::{License, LicenseTier, set_cached_license}, - events::GrpcEvent, - grpc::{ - WorkerState, build_grpc_service_router, - gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, - }, + grpc::{AUTHORIZATION_HEADER, WorkerState, build_grpc_service_router}, }; -use defguard_mail::Mail; use hyper_util::rt::TokioIo; -use sqlx::PgPool; +use sqlx::{PgPool, postgres::{PgConnectOptions, PgPoolOptions}}; use tokio::{ io::DuplexStream, - sync::{ - broadcast::{self, Sender}, - mpsc::{UnboundedReceiver, unbounded_channel}, - }, + sync::mpsc::{UnboundedReceiver, unbounded_channel}, task::JoinHandle, }; -use tonic::transport::{Channel, Endpoint, Server, server::Router}; +use tonic::{Request, transport::{Channel, Endpoint, Server, Uri, server::Router}}; use tower::service_fn; -use crate::common::{init_config, initialize_users}; +use crate::common::initialize_users; -pub mod mock_gateway; +static JWT_SECRETS: Once = Once::new(); pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, - pub grpc_event_rx: UnboundedReceiver, - wireguard_tx: Sender, - client_state: Arc>, + pub worker_state: Arc>, pub client_channel: Channel, - #[allow(dead_code)] - peer_stats_rx: UnboundedReceiver, + pub app_event_rx: UnboundedReceiver, } impl TestGrpcServer { @@ -47,13 +40,10 @@ impl TestGrpcServer { pub async fn new( server_stream: DuplexStream, grpc_router: Router, - grpc_event_rx: UnboundedReceiver, - wireguard_tx: Sender, - client_state: Arc>, + worker_state: Arc>, client_channel: Channel, - peer_stats_rx: UnboundedReceiver, + app_event_rx: UnboundedReceiver, ) -> Self { - // spawn test gRPC server let grpc_server_task_handle = tokio::spawn(async move { grpc_router .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server_stream))) @@ -64,103 +54,120 @@ impl TestGrpcServer { Self { grpc_server_task_handle, - grpc_event_rx, - wireguard_tx, - client_state, + worker_state, client_channel, - peer_stats_rx, + app_event_rx, } } - - pub fn get_client_map(&self) -> std::sync::MutexGuard<'_, ClientMap> { - self.client_state - .lock() - .expect("failed to acquire lock on client state") - } - - pub fn send_wireguard_event(&self, event: GatewayEvent) { - self.wireguard_tx - .send(event) - .expect("failed to send gateway event"); - } } impl Drop for TestGrpcServer { fn drop(&mut self) { - // explicitly stop spawned gRPC server task self.grpc_server_task_handle.abort(); } } +pub(crate) async fn setup_grpc_pool(_: PgPoolOptions, options: PgConnectOptions) -> PgPool { + setup_pool(options).await +} + pub(crate) async fn create_client_channel(client_stream: DuplexStream) -> Channel { - // Move client to an option so we can _move_ the inner value - // on the first attempt to connect. All other attempts will fail. - // reference: https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs#L31 let mut client = Some(client_stream); + let connector = service_fn(move |_: Uri| { + let client = client.take(); + + async move { + if let Some(client) = client { + Ok::<_, std::io::Error>(TokioIo::new(client)) + } else { + Err(std::io::Error::other("Client already taken")) + } + } + }); + Endpoint::try_from("http://[::]:50051") .expect("Failed to create channel") - .connect_with_connector(service_fn(move |_: Uri| { - let client = client.take(); - - async move { - if let Some(client) = client { - Ok(TokioIo::new(client)) - } else { - Err(std::io::Error::other("Client already taken")) - } - } - })) + .connect_with_connector(connector) .await .expect("Failed to create client channel") } pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { - // create communication channel for clients - let (client_stream, server_stream) = tokio::io::duplex(1024); - let client_channel = create_client_channel(client_stream).await; - - // setup helper structs - let (grpc_event_tx, grpc_event_rx) = unbounded_channel::(); - let (app_event_tx, _app_event_rx) = unbounded_channel::(); - let worker_state = Arc::new(Mutex::new(WorkerState::new(app_event_tx.clone()))); - let (wg_tx, _wg_rx) = broadcast::channel::(16); - let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); - let gateway_state = Arc::new(Mutex::new(GatewayMap::new())); - let client_state = Arc::new(Mutex::new(ClientMap::new())); - - let failed_logins = FailedLoginMap::new(); - let failed_logins = Arc::new(Mutex::new(failed_logins)); - + initialize_jwt_secrets(); initialize_users(pool).await; - initialize_current_settings(pool) - .await - .expect("Could not initialize settings"); - - let license = License::new( - "test_customer".to_string(), - false, - // Permanent license - None, - None, - None, - LicenseTier::Business, - ); - set_cached_license(Some(license)); - let server = Server::builder(); + let (client_stream, server_stream) = tokio::io::duplex(1024); + let client_channel = create_client_channel(client_stream).await; - let grpc_router = build_grpc_service_router(server, pool.clone(), worker_state, failed_logins) - .await - .unwrap(); + let (app_event_tx, app_event_rx) = unbounded_channel::(); + let worker_state = Arc::new(Mutex::new(WorkerState::new(app_event_tx))); + let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); + let grpc_router = build_grpc_service_router( + Server::builder(), + pool.clone(), + worker_state.clone(), + failed_logins, + ) + .await + .expect("failed to build gRPC router"); TestGrpcServer::new( server_stream, grpc_router, - grpc_event_rx, - wg_tx, - client_state, + worker_state, client_channel, - peer_stats_rx, + app_event_rx, ) .await } + +pub(crate) fn create_yubibridge_jwt(username: &str) -> String { + initialize_jwt_secrets(); + Claims::new( + ClaimsType::YubiBridge, + username.to_string(), + String::new(), + u32::MAX.into(), + ) + .to_jwt() + .expect("failed to generate YubiBridge token") +} + +pub(crate) fn create_gateway_jwt(username: &str, client_id: &str) -> String { + initialize_jwt_secrets(); + Claims::new( + ClaimsType::Gateway, + username.to_string(), + client_id.to_string(), + u32::MAX.into(), + ) + .to_jwt() + .expect("failed to generate gateway token") +} + +pub(crate) fn add_authorization_metadata(request: &mut Request, token: &str) { + request.metadata_mut().insert( + AUTHORIZATION_HEADER, + token.parse().expect("failed to encode authorization token"), + ); +} + +pub(crate) fn add_worker_auth_metadata(request: &mut Request, username: &str) { + add_authorization_metadata(request, &create_yubibridge_jwt(username)); +} + +pub(crate) fn worker_request(message: T, username: &str) -> Request { + let mut request = Request::new(message); + add_worker_auth_metadata(&mut request, username); + request +} + +fn initialize_jwt_secrets() { + JWT_SECRETS.call_once(|| { + unsafe { + env::set_var(AUTH_SECRET_ENV, "defguard-test-auth-secret"); + env::set_var(GATEWAY_SECRET_ENV, "defguard-test-gateway-secret"); + env::set_var(YUBIBRIDGE_SECRET_ENV, "defguard-test-yubibridge-secret"); + } + }); +} diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs deleted file mode 100644 index 5564f02c43..0000000000 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ /dev/null @@ -1,677 +0,0 @@ -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - time::Duration, -}; - -use chrono::{Days, Utc}; -use claims::{assert_err_eq, assert_matches}; -use defguard_common::db::{ - Id, NoId, - models::{ - Device, DeviceType, User, WireguardNetwork, - wireguard::{LocationMfaMode, ServiceLocationMode}, - wireguard_peer_stats::WireguardPeerStats, - }, - setup_pool, -}; -use defguard_core::{ - enterprise::{license::set_cached_license, limits::update_counts}, - events::GrpcEvent, - grpc::{MIN_GATEWAY_VERSION, gateway::events::GatewayEvent}, -}; -use defguard_proto::{ - enterprise::firewall::FirewallPolicy, - gateway::{Configuration, PeerStats, Update, stats_update::Payload, update}, -}; -use semver::Version; -use sqlx::{ - PgPool, - postgres::{PgConnectOptions, PgPoolOptions}, -}; -use tokio::{sync::mpsc::error::TryRecvError, time::sleep}; -use tonic::Code; - -use crate::grpc::common::{TestGrpcServer, make_grpc_test_server, mock_gateway::MockGateway}; - -fn generate_gateway_token(location: &Location) -> String { - Claims::new( - ClaimsType::Gateway, - format!("DEFGUARD-NETWORK-{location_id}"), - location.id.to_string(), - u32::MAX.into(), - ) - .to_jwt() - .expect("failed to generate gateway token") -} -async fn setup_test_server( - pool: PgPool, -) -> (TestGrpcServer, MockGateway, WireguardNetwork, User) { - let test_server = make_grpc_test_server(&pool).await; - - // create a test location - let location = WireguardNetwork::new( - "test location".to_string(), - Vec::new(), - 1000, - "endpoint1".to_string(), - None, - Vec::new(), - 100, - 100, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); - - // set auth token for gateway - let token = generate_gateway_token(&location); - - // setup mock gateway - let gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test gateway".into()), - ) - .await; - - // get test user - let test_user = User::find_by_username(&pool, "hpotter") - .await - .unwrap() - .unwrap(); - - (test_server, gateway, location, test_user) -} - -#[sqlx::test] -async fn test_gateway_authorization(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, _gateway, test_location, _test_user) = setup_test_server(pool).await; - - // setup another test gateway without a token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - None, - Some("test gateway".into()), - ) - .await; - - // make a request without auth token - let response = test_gateway.get_gateway_config().await; - - // check that response code is `Code::Unauthenticated` - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::Unauthenticated); - - // setup another test gateway with an invalid token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some("invalid_token".into()), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::Unauthenticated); - - // use valid token and retry - let token = generate_gateway_token(&test_location); - // setup another test gateway without a token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - assert!(response.is_ok()); -} - -#[sqlx::test] -async fn test_gateway_hostname_is_required(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, _gateway, test_location, _test_user) = setup_test_server(pool).await; - - // setup gateway without hostname - let token = generate_gateway_token(&test_location); - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token.clone()), - None, - ) - .await; - - // make a request without hostname - let response = test_gateway.get_gateway_config().await; - - // check that response code is `Code::Internal` - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::Internal); - - // set hostname and retry - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - assert!(response.is_ok()); -} - -#[sqlx::test] -async fn test_gateway_status(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, mut gateway, test_location, _test_user) = setup_test_server(pool).await; - - // initial gateway map is empty - { - let gateway_map = test_server.get_gateway_map(); - assert!(gateway_map.is_empty()) - } - - // gateway request initial config - // it should be added to status map as disconnected - let response = gateway.get_gateway_config().await; - assert!(response.is_ok()); - { - let gateway_map = test_server.get_gateway_map(); - let location_gateways = gateway_map.get_network_gateway_status(test_location.id); - assert_eq!(location_gateways.len(), 1); - let gateway_state = location_gateways.first().unwrap(); - assert!(!gateway_state.connected); - assert!(gateway_state.connected_at.is_none()); - assert!(gateway_state.disconnected_at.is_none()); - assert_eq!(gateway_state.hostname, gateway.hostname()); - } - - // gateway connects to updates stream - // it should be marked as connected - gateway.connect_to_updates_stream().await; - { - let gateway_map = test_server.get_gateway_map(); - let location_gateways = gateway_map.get_network_gateway_status(test_location.id); - assert_eq!(location_gateways.len(), 1); - let gateway_state = location_gateways.first().unwrap(); - assert!(gateway_state.connected); - assert!(gateway_state.connected_at.is_some()); - assert!(gateway_state.disconnected_at.is_none()); - assert_eq!(gateway_state.hostname, gateway.hostname()); - } - - // gateway disconnect from updates stream - // it should be marked as disconnected - gateway.disconnect_from_updates_stream(); - // wait for the background thread to handle the disconnect - sleep(Duration::from_millis(100)).await; - - { - let gateway_map = test_server.get_gateway_map(); - let location_gateways = gateway_map.get_network_gateway_status(test_location.id); - assert_eq!(location_gateways.len(), 1); - let gateway_state = location_gateways.first().unwrap(); - assert!(!gateway_state.connected); - assert!(gateway_state.connected_at.is_some()); - assert!(gateway_state.disconnected_at.is_some()); - assert_eq!(gateway_state.hostname, gateway.hostname()); - } -} - -#[sqlx::test] -async fn test_vpn_client_connected(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (mut test_server, mut gateway, test_location, test_user) = - setup_test_server(pool.clone()).await; - - // initial client map is empty - { - let client_map = test_server.get_client_map(); - assert!(client_map.is_empty()) - } - - // connect stats stream - let stats_tx = gateway.setup_stats_update_stream().await; - let mut update_id = 1; - - // add user device - let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg="; - let test_device = Device::new( - "test device".into(), - device_pubkey.into(), - test_user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // send stats update for existing device with old handshake - // and verify no gRPC event is emitted - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: 0, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - assert_err_eq!(test_server.grpc_event_rx.try_recv(), TryRecvError::Empty); - - // send stats update with current handshake - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event to be emitted - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - - assert_matches!( - grpc_event, - GrpcEvent::ClientConnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); -} - -#[sqlx::test] -async fn test_vpn_client_disconnected(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (mut test_server, mut gateway, test_location, test_user) = - setup_test_server(pool.clone()).await; - - // add user device - let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg="; - let test_device = Device::new( - "test device".into(), - device_pubkey.into(), - test_user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // insert device into client map with an old handshake - { - let mut client_map = test_server.get_client_map(); - let now = Utc::now().naive_utc(); - let stats = WireguardPeerStats { - id: NoId, - device_id: test_device.id, - collected_at: now, - network: test_location.id, - endpoint: None, - upload: 0, - download: 0, - latest_handshake: now.checked_sub_days(Days::new(1)).unwrap(), - allowed_ips: None, - }; - client_map - .connect_vpn_client( - test_location.id, - &gateway.hostname(), - device_pubkey, - &test_device, - &test_user, - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080), - &stats, - ) - .expect("failed to insert connected client"); - } - - // connect stats stream - let stats_tx = gateway.setup_stats_update_stream().await; - let mut update_id = 1; - - // send stats update with old handshake - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: 0, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event to be emitted - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - - assert_matches!( - grpc_event, - GrpcEvent::ClientDisconnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); -} - -#[sqlx::test] -async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, mut gateway_1, test_location, _test_user) = - setup_test_server(pool.clone()).await; - - // setup another test location & gateway - let test_location_2 = WireguardNetwork::new( - "test location 2".to_string(), - Vec::new(), - 1000, - "endpoint2".to_string(), - None, - Vec::new(), - 100, - 100, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); - - // set auth token for gateway - let token = generate_gateway_token(&test_location_2); - let mut gateway_2 = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test_gateway_2".into()), - ) - .await; - - // register gateways with core - let _config_1 = gateway_1.get_gateway_config().await; - let _config_2 = gateway_2.get_gateway_config().await; - - // connect gateways to the updates stream - gateway_1.connect_to_updates_stream().await; - gateway_2.connect_to_updates_stream().await; - - // send update for location 1 - test_server.send_wireguard_event(GatewayEvent::NetworkDeleted( - test_location.id, - "network name".into(), - )); - - // only one gateway should receive this update - assert!(gateway_2.receive_next_update().await.is_none()); - let update = gateway_1.receive_next_update().await.unwrap(); - let expected_update = Update { - update_type: 2, - update: Some(update::Update::Network(Configuration { - name: "network name".into(), - prvkey: String::new(), - addresses: Vec::new(), - port: 0, - peers: Vec::new(), - firewall_config: None, - })), - }; - assert_eq!(update, expected_update); - - // send update for location 2 - test_server.send_wireguard_event(GatewayEvent::NetworkDeleted( - test_location_2.id, - "network name 2".into(), - )); - - // only one gateway should receive this update - assert!(gateway_1.receive_next_update().await.is_none()); - let update = gateway_2.receive_next_update().await.unwrap(); - let expected_update = Update { - update_type: 2, - update: Some(update::Update::Network(Configuration { - name: "network name 2".into(), - prvkey: String::new(), - addresses: Vec::new(), - port: 0, - peers: Vec::new(), - firewall_config: None, - })), - }; - assert_eq!(update, expected_update); - - // send update for location which does not exist - test_server.send_wireguard_event(GatewayEvent::NetworkDeleted(1234, "does not exist".into())); - - // no gateway should receive this update - assert!(gateway_1.receive_next_update().await.is_none()); - assert!(gateway_2.receive_next_update().await.is_none()); -} - -#[sqlx::test] -async fn test_gateway_config(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (_test_server, mut gateway, mut test_location, _test_user) = - setup_test_server(pool.clone()).await; - - // get gateway config - let config = gateway.get_gateway_config().await.unwrap().into_inner(); - - assert_eq!(config.name, test_location.name); - assert!(config.firewall_config.is_none()); - - // enable ACL for test location - test_location.acl_enabled = true; - test_location - .save(&pool) - .await - .expect("failed to update location"); - - // get gateway config - let config = gateway.get_gateway_config().await.unwrap().into_inner(); - assert!(config.firewall_config.is_some()); - assert_eq!( - config.firewall_config.unwrap().default_policy == i32::from(FirewallPolicy::Allow), - test_location.acl_default_allow - ); - - // unset the license and create another location to exceed limits and disable enterprise features - set_cached_license(None); - let _test_location_2 = WireguardNetwork::new( - "test location 2".to_string(), - Vec::new(), - 1000, - "endpoint2".to_string(), - None, - Vec::new(), - 100, - 100, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); - update_counts(&pool).await.unwrap(); - - let config = gateway.get_gateway_config().await.unwrap().into_inner(); - assert!(config.firewall_config.is_none()); -} - -#[sqlx::test] -async fn test_gateway_version_validation(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, _gateway, test_location, _test_user) = setup_test_server(pool.clone()).await; - - // setup gateway with unsupported version - let unsupported_version = - Version::new(MIN_GATEWAY_VERSION.major, MIN_GATEWAY_VERSION.minor - 1, 0); - let token = generate_gateway_token(&test_location); - // setup another test gateway without a token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - unsupported_version, - Some(token), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - - // check that response code is `Code::FailedPrecondition` - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::FailedPrecondition); -} - -// https://github.com/DefGuard/defguard/issues/1671 -#[sqlx::test] -async fn test_device_pubkey_change(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (mut test_server, mut gateway, test_location, test_user) = - setup_test_server(pool.clone()).await; - - // initial client map is empty - { - let client_map = test_server.get_client_map(); - assert!(client_map.is_empty()) - } - - // connect stats stream - let stats_tx = gateway.setup_stats_update_stream().await; - let mut update_id = 1; - - // add user device - let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg="; - let mut test_device = Device::new( - "test device".into(), - device_pubkey.into(), - test_user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // send stats update for existing device - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event to be emitted - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - assert_matches!( - grpc_event, - GrpcEvent::ClientConnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); - - // change device pubkey - let new_device_pubkey = "TJG2T6rhndZtk06KnIIOlD6hhd7wpVkBss8sfyvMCAA="; - test_device.wireguard_pubkey = new_device_pubkey.to_owned(); - test_device.save(&pool).await.unwrap(); - - // send stats update with old pubkey - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // no event should be emitted - sleep(Duration::from_millis(100)).await; - assert_err_eq!(test_server.grpc_event_rx.try_recv(), TryRecvError::Empty); - - // send stats update with new pubkey - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: new_device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event - // FIXME: ideally this should not be emitted; we'll fix it once we implement a more robust VPN session logic - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - - assert_matches!( - grpc_event, - GrpcEvent::ClientConnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); -} diff --git a/crates/defguard_core/tests/integration/grpc/health.rs b/crates/defguard_core/tests/integration/grpc/health.rs new file mode 100644 index 0000000000..975ccb8319 --- /dev/null +++ b/crates/defguard_core/tests/integration/grpc/health.rs @@ -0,0 +1,25 @@ +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tonic_health::pb::{ + HealthCheckRequest, + health_check_response::ServingStatus, + health_client::HealthClient, +}; + +use super::common::{make_grpc_test_server, setup_grpc_pool}; + +#[sqlx::test] +async fn worker_service_health_is_serving(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = HealthClient::new(server.client_channel.clone()); + + let response = client + .check(HealthCheckRequest { + service: "worker.WorkerService".into(), + }) + .await + .expect("health check should succeed") + .into_inner(); + + assert_eq!(response.status, ServingStatus::Serving as i32); +} diff --git a/crates/defguard_core/tests/integration/grpc/mod.rs b/crates/defguard_core/tests/integration/grpc/mod.rs index 5b53a1b0d6..d8d759d58f 100644 --- a/crates/defguard_core/tests/integration/grpc/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/mod.rs @@ -1,2 +1,3 @@ mod common; -mod gateway; +mod health; +mod worker; diff --git a/crates/defguard_core/tests/integration/grpc/worker.rs b/crates/defguard_core/tests/integration/grpc/worker.rs new file mode 100644 index 0000000000..90856d2d9e --- /dev/null +++ b/crates/defguard_core/tests/integration/grpc/worker.rs @@ -0,0 +1,419 @@ +use claims::assert_matches; +use defguard_common::db::models::{AuthenticationKey, AuthenticationKeyType, User, YubiKey}; +use defguard_proto::worker::{ + JobStatus, Worker, + worker_service_client::WorkerServiceClient, +}; +use defguard_core::db::AppEvent; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::sync::mpsc::error::TryRecvError; +use tonic::Code; + +use super::common::{ + add_authorization_metadata, create_gateway_jwt, make_grpc_test_server, setup_grpc_pool, + worker_request, +}; + +#[sqlx::test] +async fn register_worker_success_and_duplicate(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + let response = client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await; + + assert!(response.is_ok()); + + let status = client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect_err("duplicate worker should fail"); + + assert_eq!(status.code(), Code::AlreadyExists); + assert_eq!(status.message(), "Worker already registered"); +} + +#[sqlx::test] +async fn get_job_returns_not_found_for_unknown_worker(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + let status = client + .get_job(worker_request( + Worker { + id: "missing-worker".into(), + }, + "admin", + )) + .await + .expect_err("missing worker should not have jobs"); + + assert_eq!(status.code(), Code::NotFound); + assert_eq!(status.message(), "No more jobs"); +} + +#[sqlx::test] +async fn get_job_returns_not_found_for_registered_worker_without_jobs( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let status = client + .get_job(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect_err("worker without jobs should get not found"); + + assert_eq!(status.code(), Code::NotFound); + assert_eq!(status.message(), "No more jobs"); +} + +#[sqlx::test] +async fn get_job_returns_seeded_payload(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let job_id = { + let mut state = server.worker_state.lock().unwrap(); + state.create_job( + "worker-1", + "Minerva".into(), + "McGonagall".into(), + "minerva@hogwart.edu.uk".into(), + "hpotter".into(), + ) + }; + + let response = client + .get_job(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("seeded job should be returned") + .into_inner(); + + assert_eq!(response.job_id, job_id); + assert_eq!(response.first_name, "Minerva"); + assert_eq!(response.last_name, "McGonagall"); + assert_eq!(response.email, "minerva@hogwart.edu.uk"); +} + +#[sqlx::test] +async fn set_job_done_success_removes_job_and_stores_status( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let mut server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let job_id = { + let mut state = server.worker_state.lock().unwrap(); + state.create_job( + "worker-1", + "Harry".into(), + "Potter".into(), + "h.potter@hogwart.edu.uk".into(), + "hpotter".into(), + ) + }; + + client + .set_job_done(worker_request( + JobStatus { + id: "worker-1".into(), + job_id, + success: true, + public_key: "gpg-public-key".into(), + ssh_key: "ssh-public-key".into(), + yubikey_serial: "yk-serial-1".into(), + error: String::new(), + }, + "admin", + )) + .await + .expect("job completion should succeed"); + + { + let mut state = server.worker_state.lock().unwrap(); + let status = state + .get_job_status(job_id) + .expect("job status should be recorded"); + assert!(status.success); + assert_eq!(status.serial, "yk-serial-1"); + assert_eq!(status.error, ""); + assert!(state.get_job("worker-1", std::net::IpAddr::from([127, 0, 0, 1])).is_none()); + } + + let user = User::find_by_username(&pool, "hpotter") + .await + .expect("user query should succeed") + .expect("user should exist"); + let yubikeys = YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed"); + let auth_keys = AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed"); + + assert_eq!(yubikeys.len(), 1); + assert_eq!(yubikeys[0].serial, "yk-serial-1"); + assert_eq!(auth_keys.len(), 2); + assert!(auth_keys.iter().any(|key| { + key.key_type == AuthenticationKeyType::Ssh + && key.key == "ssh-public-key" + && key.yubikey_id == Some(yubikeys[0].id) + })); + assert!(auth_keys.iter().any(|key| { + key.key_type == AuthenticationKeyType::Gpg + && key.key == "gpg-public-key" + && key.yubikey_id == Some(yubikeys[0].id) + })); + + let event = server + .app_event_rx + .try_recv() + .expect("success should emit an app event"); + assert_matches!( + event, + AppEvent::HWKeyProvision(data) + if data.username == "hpotter" + && data.email == "h.potter@hogwart.edu.uk" + && data.ssh_key == "ssh-public-key" + && data.pgp_key == "gpg-public-key" + && data.serial == "yk-serial-1" + ); + assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); +} + +#[sqlx::test] +async fn set_job_done_failure_stores_status_without_keys_or_event( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let mut server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let job_id = { + let mut state = server.worker_state.lock().unwrap(); + state.create_job( + "worker-1", + "Harry".into(), + "Potter".into(), + "h.potter@hogwart.edu.uk".into(), + "hpotter".into(), + ) + }; + + client + .set_job_done(worker_request( + JobStatus { + id: "worker-1".into(), + job_id, + success: false, + public_key: "gpg-public-key".into(), + ssh_key: "ssh-public-key".into(), + yubikey_serial: "yk-serial-1".into(), + error: "worker failed".into(), + }, + "admin", + )) + .await + .expect("failed completion should still return ok"); + + { + let state = server.worker_state.lock().unwrap(); + let status = state + .get_job_status(job_id) + .expect("failure status should be recorded"); + assert!(!status.success); + assert_eq!(status.error, "worker failed"); + } + + let user = User::find_by_username(&pool, "hpotter") + .await + .expect("user query should succeed") + .expect("user should exist"); + assert!(YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed") + .is_empty()); + assert!(AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed") + .is_empty()); + assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); +} + +#[sqlx::test] +async fn set_job_done_unknown_job_is_ignored(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let mut server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + client + .set_job_done(worker_request( + JobStatus { + id: "worker-1".into(), + job_id: 999, + success: true, + public_key: "gpg-public-key".into(), + ssh_key: "ssh-public-key".into(), + yubikey_serial: "yk-serial-1".into(), + error: String::new(), + }, + "admin", + )) + .await + .expect("unknown jobs should be ignored"); + + { + let state = server.worker_state.lock().unwrap(); + assert!(state.get_job_status(999).is_none()); + } + + let user = User::find_by_username(&pool, "hpotter") + .await + .expect("user query should succeed") + .expect("user should exist"); + assert!(YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed") + .is_empty()); + assert!(AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed") + .is_empty()); + assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); +} + +#[sqlx::test] +async fn worker_interceptor_requires_valid_yubibridge_token( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + let missing_auth = client + .register_worker(tonic::Request::new(Worker { + id: "worker-missing-auth".into(), + })) + .await + .expect_err("missing auth should fail"); + assert_eq!(missing_auth.code(), Code::Unauthenticated); + assert_eq!(missing_auth.message(), "Missing authorization header"); + + let mut invalid_request = tonic::Request::new(Worker { + id: "worker-invalid-token".into(), + }); + add_authorization_metadata(&mut invalid_request, "not-a-jwt"); + let invalid_token = client + .register_worker(invalid_request) + .await + .expect_err("invalid token should fail"); + assert_eq!(invalid_token.code(), Code::Unauthenticated); + assert_eq!(invalid_token.message(), "Invalid token"); + + let mut wrong_claims_request = tonic::Request::new(Worker { + id: "worker-wrong-claims".into(), + }); + add_authorization_metadata( + &mut wrong_claims_request, + &create_gateway_jwt("admin", "gateway-network-1"), + ); + let wrong_claims = client + .register_worker(wrong_claims_request) + .await + .expect_err("wrong claims type should fail"); + assert_eq!(wrong_claims.code(), Code::Unauthenticated); + assert_eq!(wrong_claims.message(), "Invalid token"); + + let allowed = client + .get_job(worker_request( + Worker { + id: "worker-valid-token".into(), + }, + "admin", + )) + .await + .expect_err("valid token should reach service logic"); + assert_eq!(allowed.code(), Code::NotFound); + assert_eq!(allowed.message(), "No more jobs"); +} diff --git a/crates/defguard_core/tests/integration/main.rs b/crates/defguard_core/tests/integration/main.rs index 43ffcc1303..f0ca4e17ed 100644 --- a/crates/defguard_core/tests/integration/main.rs +++ b/crates/defguard_core/tests/integration/main.rs @@ -1,4 +1,4 @@ mod api; mod common; -// mod grpc; +mod grpc; mod ldap; From b91450fe481b8a18a27115b31b2096f4105e9862 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 13:20:30 +0100 Subject: [PATCH 02/36] add basic gateway manager tests --- .../defguard_gateway_manager/src/handler.rs | 336 ++++++++----- crates/defguard_gateway_manager/src/lib.rs | 6 +- .../src/tests/handler.rs | 274 ++++++++++ .../defguard_gateway_manager/src/tests/mod.rs | 2 + .../src/tests/support.rs | 475 ++++++++++++++++++ 5 files changed, 966 insertions(+), 127 deletions(-) create mode 100644 crates/defguard_gateway_manager/src/tests/handler.rs create mode 100644 crates/defguard_gateway_manager/src/tests/mod.rs create mode 100644 crates/defguard_gateway_manager/src/tests/support.rs diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index ce705fdb5c..da921230b7 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -8,6 +8,9 @@ use std::{ }, }; +#[cfg(test)] +use std::path::{Path, PathBuf}; + use chrono::DateTime; #[cfg(not(test))] use defguard_common::db::models::Settings; @@ -52,6 +55,29 @@ use tonic::{Code, Status, transport::Endpoint}; use crate::{Client, TEN_SECS, error::GatewayError}; +#[cfg(test)] +#[derive(Debug, Default)] +struct GatewayTestTransport { + socket_path: Option, +} + +#[cfg(test)] +impl GatewayTestTransport { + fn with_socket_path(socket_path: PathBuf) -> Self { + Self { + socket_path: Some(socket_path), + } + } + + fn socket_path(&self) -> Result<&Path, GatewayError> { + self.socket_path.as_deref().ok_or_else(|| { + GatewayError::EndpointError( + "Missing test gateway transport socket path for GatewayHandler".to_string(), + ) + }) + } +} + /// One instance per connected Gateway. pub(super) struct GatewayHandler { // Gateway server endpoint URL. @@ -62,6 +88,8 @@ pub(super) struct GatewayHandler { events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, + #[cfg(test)] + test_transport: GatewayTestTransport, } impl GatewayHandler { @@ -87,9 +115,25 @@ impl GatewayHandler { events_tx, peer_stats_tx, certs_rx, + #[cfg(test)] + test_transport: GatewayTestTransport::default(), }) } + #[cfg(test)] + pub(super) fn new_with_test_socket( + gateway: Gateway, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: watch::Receiver>>, + socket_path: PathBuf, + ) -> Result { + let mut handler = Self::new(gateway, pool, events_tx, peer_stats_tx, certs_rx)?; + handler.test_transport = GatewayTestTransport::with_socket_path(socket_path); + Ok(handler) + } + fn endpoint(&self) -> Result { let mut url = self.url.clone(); @@ -211,167 +255,211 @@ impl GatewayHandler { } } - /// Connect to Gateway and handle its messages through gRPC. - pub(super) async fn handle_connection( + async fn mark_disconnected(&mut self) { + if let Err(err) = self.gateway.touch_disconnected(&self.pool).await { + error!( + "Failed to update disconnection time for {} in the database: {err}", + self.gateway + ); + } + } + + async fn handle_disconnection_error(&mut self) { + if self.gateway.is_connected() { + self.send_disconnect_notification().await; + } + + self.mark_disconnected().await; + } + + async fn handle_connection_iteration( &mut self, clients: Arc>>, + retry_on_connect_failure: bool, ) -> Result<(), GatewayError> { #[cfg(test)] let _ = &self.certs_rx; let endpoint = self.endpoint()?; let uri = endpoint.uri().to_string(); - loop { - #[cfg(not(test))] - let channel = { - let settings = Settings::get_current_settings(); - let Some(ca_cert_der) = settings.ca_cert_der else { - return Err(GatewayError::EndpointError( - "Core CA is not setup, can't create a Gateway endpoint.".to_string(), - )); - }; - let tls_config = - tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) - .map_err(|err| GatewayError::EndpointError(err.to_string()))?; - let connector = HttpsConnectorBuilder::new() - .with_tls_config(tls_config) - .https_only() - .enable_http2() - .build(); - let connector = HttpsSchemeConnector::new(connector); - endpoint.connect_with_connector_lazy(connector) + + #[cfg(not(test))] + let channel = { + let settings = Settings::get_current_settings(); + let Some(ca_cert_der) = settings.ca_cert_der else { + return Err(GatewayError::EndpointError( + "Core CA is not setup, can't create a Gateway endpoint.".to_string(), + )); }; - #[cfg(test)] - let channel = endpoint.connect_with_connector_lazy(tower::service_fn( - |_: tonic::transport::Uri| async { - Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(super::TONIC_SOCKET).await?, - )) + let tls_config = + tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) + .map_err(|err| GatewayError::EndpointError(err.to_string()))?; + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http2() + .build(); + let connector = HttpsSchemeConnector::new(connector); + endpoint.connect_with_connector_lazy(connector) + }; + #[cfg(test)] + let channel = { + let socket_path = self.test_transport.socket_path()?.to_path_buf(); + endpoint.connect_with_connector_lazy(tower::service_fn( + move |_: tonic::transport::Uri| { + let socket_path = socket_path.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(socket_path).await?, + )) + } }, - )); + )) + }; - debug!("Connecting to Gateway {uri}"); - let interceptor = ClientVersionInterceptor::new( - Version::parse(VERSION).expect("failed to parse self version"), - ); - let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); - clients - .lock() - .expect("GatewayHandler failed to lock clients") - .insert(self.gateway.id, client.clone()); - let (tx, rx) = mpsc::unbounded_channel(); - let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { - Ok(response) => response, - Err(err) => { - error!("Failed to connect to Gateway {uri}, retrying: {err}"); + debug!("Connecting to Gateway {uri}"); + let interceptor = ClientVersionInterceptor::new( + Version::parse(VERSION).expect("failed to parse self version"), + ); + let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); + clients + .lock() + .expect("GatewayHandler failed to lock clients") + .insert(self.gateway.id, client.clone()); + let (tx, rx) = mpsc::unbounded_channel(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + error!("Failed to connect to Gateway {uri}, retrying: {err}"); + if retry_on_connect_failure { sleep(TEN_SECS).await; - continue; + return Ok(()); } - }; - info!("Connected to Defguard Gateway {uri}"); - - let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); - let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); - if let Some(mut gateway) = Gateway::find_by_id(&self.pool, self.gateway.id).await? { - gateway.version = Some(version.to_string()); - gateway.save(&self.pool).await?; + return Err(err.into()); } + }; + info!("Connected to Defguard Gateway {uri}"); - let mut resp_stream = response.into_inner(); - let mut config_sent = false; + let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); + let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); - 'message: loop { - match resp_stream.message().await { - Ok(None) => { - info!("Stream was closed by the sender."); - break 'message; - } - Ok(Some(received)) => { - info!("Received message from Gateway."); - debug!("Message from Gateway {uri}"); + if let Some(mut gateway) = Gateway::find_by_id(&self.pool, self.gateway.id).await? { + gateway.version = Some(version.to_string()); + gateway.save(&self.pool).await?; + } - match received.payload { - Some(core_request::Payload::ConfigRequest(_config_request)) => { - if config_sent { - warn!( - "Ignoring repeated configuration request from {}", - self.gateway + let mut resp_stream = response.into_inner(); + let mut config_sent = false; + + loop { + match resp_stream.message().await { + Ok(None) => { + info!("Stream was closed by the sender."); + self.mark_disconnected().await; + return Ok(()); + } + Ok(Some(received)) => { + info!("Received message from Gateway."); + debug!("Message from Gateway {uri}"); + + match received.payload { + Some(core_request::Payload::ConfigRequest(_config_request)) => { + if config_sent { + warn!( + "Ignoring repeated configuration request from {}", + self.gateway + ); + continue; + } + + match self.send_configuration(&tx).await { + Ok(network) => { + info!("Sent configuration to {}", self.gateway); + config_sent = true; + let _ = self.gateway.touch_connected(&self.pool).await; + let mut updates_handler = GatewayUpdatesHandler::new( + self.gateway.location_id, + network, + self.gateway.name.clone(), + self.events_tx.subscribe(), + tx.clone(), ); - continue; + tokio::spawn(async move { + updates_handler.run().await; + }); } - - // Send network configuration to Gateway. - match self.send_configuration(&tx).await { - Ok(network) => { - info!("Sent configuration to {}", self.gateway); - config_sent = true; - let _ = self.gateway.touch_connected(&self.pool).await; - let mut updates_handler = GatewayUpdatesHandler::new( - self.gateway.location_id, - network, - self.gateway.name.clone(), - self.events_tx.subscribe(), - tx.clone(), - ); - tokio::spawn(async move { - updates_handler.run().await; - }); - } - Err(err) => { - error!( - "Failed to send configuration to {}: {err}", - self.gateway - ); - } + Err(err) => { + error!( + "Failed to send configuration to {}: {err}", + self.gateway + ); } } - Some(core_request::Payload::PeerStats(peer_stats)) => { - if !config_sent { + } + Some(core_request::Payload::PeerStats(peer_stats)) => { + if !config_sent { + warn!( + "Ignoring peer statistics from {} because it hasn't \ + authorized itself", + self.gateway + ); + continue; + } + + match try_protos_into_stats_message( + peer_stats.clone(), + self.gateway.location_id, + self.gateway.id, + ) { + None => { warn!( - "Ignoring peer statistics from {} because it hasn't \ - authorized itself", - self.gateway + "Failed to parse peer stats update. Skipping sending \ + message to session manager." ); - continue; } - - // convert stats to DB storage format - match try_protos_into_stats_message( - peer_stats.clone(), - self.gateway.location_id, - self.gateway.id, - ) { - None => { - warn!( - "Failed to parse peer stats update. Skipping sending \ - message to session manager." + Some(message) => { + if let Err(err) = self.peer_stats_tx.send(message) { + error!( + "Failed to send peers stats update to session manager: {err}" ); } - Some(message) => { - if let Err(err) = self.peer_stats_tx.send(message) { - error!( - "Failed to send peers stats update to session manager: {err}" - ); - } - } } } - None => (), } + None => (), } - Err(err) => { - error!("Disconnected from Gateway at {uri}, error: {err}"); - // Important: call this funtion before setting disconnection time. - self.send_disconnect_notification().await; - let _ = self.gateway.touch_disconnected(&self.pool).await; + } + Err(err) => { + error!("Disconnected from Gateway at {uri}, error: {err}"); + self.handle_disconnection_error().await; + if retry_on_connect_failure { debug!("Waiting 10s to re-establish the connection"); sleep(TEN_SECS).await; - break 'message; } + return Ok(()); } } } } + + /// Connect to Gateway and handle its messages through gRPC. + pub(super) async fn handle_connection( + &mut self, + clients: Arc>>, + ) -> Result<(), GatewayError> { + loop { + self.handle_connection_iteration(Arc::clone(&clients), true) + .await?; + } + } + + #[cfg(test)] + pub(super) async fn handle_connection_once( + &mut self, + clients: Arc>>, + ) -> Result<(), GatewayError> { + self.handle_connection_iteration(clients, false).await + } } /// Helper struct for handling gateway events. diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 495138560d..d0eb05ec6e 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -26,11 +26,11 @@ extern crate tracing; mod certs; mod error; mod handler; -// #[cfg(test)] -// mod tests; #[cfg(test)] -static TONIC_SOCKET: &str = "tonic.sock"; +#[path = "tests/mod.rs"] +mod tests; + const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); const TEN_SECS: Duration = Duration::from_secs(10); diff --git a/crates/defguard_gateway_manager/src/tests/handler.rs b/crates/defguard_gateway_manager/src/tests/handler.rs new file mode 100644 index 0000000000..072d9f7dff --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/handler.rs @@ -0,0 +1,274 @@ +use defguard_proto::gateway::{ + CoreResponse, Update, UpdateType, + core_response, + update::{self}, +}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tonic::Status; + +use super::support::{HandlerTestContext, build_peer_stats, reload_gateway}; +use defguard_core::grpc::GatewayEvent; + +macro_rules! assert_send_ok { + ($result:expr, $message:literal) => { + match $result { + Ok(value) => value, + Err(_) => panic!($message), + } + }; +} + +macro_rules! panic_unexpected { + ($message:literal) => { + panic!($message) + }; +} + +#[sqlx::test] +async fn test_sends_configuration_on_first_config_request( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let outbound = context.mock_gateway_mut().recv_outbound().await; + + match outbound.payload { + Some(core_response::Payload::Config(config)) => { + assert_eq!(config.name, context.network.name); + assert_eq!(config.port, context.network.port as u32); + assert_eq!(config.peers, Vec::new()); + } + _ => panic_unexpected!("expected configuration response"), + } + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_does_not_send_configuration_before_gateway_requests_it( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let gateway_before = context.reload_gateway().await; + assert!(!gateway_before.is_connected()); + + context.mock_gateway_mut().expect_no_outbound().await; + + let gateway_after = context.reload_gateway().await; + assert!(!gateway_after.is_connected()); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_ignores_repeated_config_request(_: PgPoolOptions, options: PgConnectOptions) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let first_outbound = context.mock_gateway_mut().recv_outbound().await; + assert!(matches!( + first_outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + context.mock_gateway().send_config_request(); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_ignores_peer_stats_before_config_handshake( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context + .mock_gateway() + .send_peer_stats(build_peer_stats("203.0.113.10:51820")); + + context.expect_no_peer_stats().await; + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_forwards_valid_peer_stats_after_config( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let _ = context.mock_gateway_mut().recv_outbound().await; + context + .mock_gateway() + .send_peer_stats(build_peer_stats("203.0.113.10:51820")); + + let forwarded = context.recv_peer_stats().await; + assert_eq!(forwarded.location_id, context.network.id); + assert_eq!(forwarded.gateway_id, context.gateway.id); + assert_eq!(forwarded.device_pubkey, "peer-public-key"); + assert_eq!(forwarded.endpoint.to_string(), "203.0.113.10:51820"); + assert_eq!(forwarded.upload, 123); + assert_eq!(forwarded.download, 456); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_drops_malformed_or_missing_endpoint_peer_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let _ = context.mock_gateway_mut().recv_outbound().await; + + context.mock_gateway().send_peer_stats(build_peer_stats("")); + context.expect_no_peer_stats().await; + + context + .mock_gateway() + .send_peer_stats(build_peer_stats("not-a-socket-address")); + context.expect_no_peer_stats().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_network_event_produces_outbound_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + context.network.id, + context.network.name.clone(), + )), + "failed to broadcast gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_delete_update(outbound, &context.network.name); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_different_location_network_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + other_network.id, + other_network.name.clone(), + )), + "failed to broadcast unrelated gateway event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + context.network.id, + context.network.name.clone(), + )), + "failed to broadcast owned gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_delete_update(outbound, &context.network.name); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_connected_after_successful_config_handshake( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let gateway_before = context.reload_gateway().await; + assert!(!gateway_before.is_connected()); + + let gateway_after = context.complete_config_handshake().await; + assert!(gateway_after.is_connected()); + assert!(gateway_after.connected_at.is_some()); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_disconnected_when_stream_closes( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let connected_gateway = context.complete_config_handshake().await; + assert!(connected_gateway.is_connected()); + + let pool = context.pool.clone(); + let gateway_id = context.gateway.id; + let mock_gateway = context.finish().await; + let disconnected_gateway = reload_gateway(&pool, gateway_id).await; + assert!(!disconnected_gateway.is_connected()); + assert!(disconnected_gateway.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_disconnected_when_stream_errors( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + context + .mock_gateway() + .send_stream_error(Status::internal("mock gateway stream failure")); + + let pool = context.pool.clone(); + let gateway_id = context.gateway.id; + let mock_gateway = context.finish_after_error().await; + let disconnected_gateway = reload_gateway(&pool, gateway_id).await; + assert!(!disconnected_gateway.is_connected()); + assert!(disconnected_gateway.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; +} + +fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: &str) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Delete as i32); + assert_eq!(network.name, expected_network_name); + } + _ => panic_unexpected!("expected network delete update"), + } +} diff --git a/crates/defguard_gateway_manager/src/tests/mod.rs b/crates/defguard_gateway_manager/src/tests/mod.rs new file mode 100644 index 0000000000..f763168bef --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/mod.rs @@ -0,0 +1,2 @@ +mod handler; +mod support; diff --git a/crates/defguard_gateway_manager/src/tests/support.rs b/crates/defguard_gateway_manager/src/tests/support.rs new file mode 100644 index 0000000000..06178722f0 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/support.rs @@ -0,0 +1,475 @@ +use std::{ + collections::HashMap, + io, + path::PathBuf, + sync::{ + Arc, Mutex, + atomic::{AtomicU64, Ordering}, + }, + time::Duration, +}; + +use defguard_common::{ + db::{ + Id, + models::{gateway::Gateway, wireguard::WireguardNetwork}, + setup_pool, + }, + messages::peer_stats_update::PeerStatsUpdate, +}; +use defguard_core::grpc::GatewayEvent; +use defguard_proto::gateway::{ + ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, + core_request, gateway_server, +}; +use sqlx::{PgPool, postgres::PgConnectOptions}; +use tokio::{ + net::UnixListener, + sync::{ + broadcast, + mpsc::{self, UnboundedReceiver, UnboundedSender}, + oneshot, watch, + }, + task::JoinHandle, + time::timeout, +}; +use tokio_stream::{once, wrappers::UnboundedReceiverStream}; +use tonic::{Request, Response, Status, Streaming, transport::Server}; + +use crate::{Client, error::GatewayError, handler::GatewayHandler}; + +const TEST_TIMEOUT: Duration = Duration::from_secs(2); + +macro_rules! assert_some { + ($expr:expr, $message:literal) => { + match $expr { + Some(value) => value, + None => panic!($message), + } + }; +} + +static TEST_ID: AtomicU64 = AtomicU64::new(0); + +fn next_test_id() -> u64 { + TEST_ID.fetch_add(1, Ordering::Relaxed) +} + +fn unique_name(prefix: &str) -> String { + format!("{prefix}-{}", next_test_id()) +} + +fn unique_socket_path() -> PathBuf { + PathBuf::from(format!( + "/tmp/defguard-gateway-manager-{}-{}.sock", + std::process::id(), + next_test_id() + )) +} + +#[derive(Clone)] +struct MockGatewayService { + state: Arc, +} + +struct MockGatewayState { + outbound_tx: UnboundedSender, + inbound_rx: Mutex>>>, + connected_tx: Mutex>>, +} + +impl MockGatewayState { + fn notify_connected(&self) { + if let Some(tx) = self + .connected_tx + .lock() + .expect("failed to lock connected notifier") + .take() + { + let _ = tx.send(()); + } + } + + fn take_inbound_rx( + &self, + ) -> Result>, Status> { + self.inbound_rx + .lock() + .expect("failed to lock inbound receiver") + .take() + .ok_or_else(|| Status::failed_precondition("mock gateway already connected")) + } +} + +#[tonic::async_trait] +impl gateway_server::Gateway for MockGatewayService { + type BidiStream = UnboundedReceiverStream>; + + async fn bidi( + &self, + request: Request>, + ) -> Result, Status> { + let inbound_rx = self.state.take_inbound_rx()?; + self.state.notify_connected(); + + let mut outbound_stream = request.into_inner(); + let outbound_tx = self.state.outbound_tx.clone(); + tokio::spawn(async move { + while let Ok(Some(response)) = outbound_stream.message().await { + if outbound_tx.send(response).is_err() { + break; + } + } + }); + + Ok(Response::new(UnboundedReceiverStream::new(inbound_rx))) + } + + async fn purge(&self, _request: Request<()>) -> Result, Status> { + Ok(Response::new(())) + } +} + +pub(super) struct MockGatewayHarness { + socket_path: PathBuf, + inbound_tx: Option>>, + outbound_rx: UnboundedReceiver, + connected_rx: oneshot::Receiver<()>, + server_task: Option>>, + next_message_id: AtomicU64, +} + +impl MockGatewayHarness { + pub(super) async fn start() -> Self { + let socket_path = unique_socket_path(); + let _ = std::fs::remove_file(&socket_path); + + let listener = + UnixListener::bind(&socket_path).expect("failed to bind mock gateway unix socket"); + let (outbound_tx, outbound_rx) = mpsc::unbounded_channel(); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel(); + let (connected_tx, connected_rx) = oneshot::channel(); + let service = MockGatewayService { + state: Arc::new(MockGatewayState { + outbound_tx, + inbound_rx: Mutex::new(Some(inbound_rx)), + connected_tx: Mutex::new(Some(connected_tx)), + }), + }; + + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + Server::builder() + .add_service(gateway_server::GatewayServer::new(service)) + .serve_with_incoming(once(Ok::<_, io::Error>(stream))) + .await + .map_err(io::Error::other) + }); + + Self { + socket_path, + inbound_tx: Some(inbound_tx), + outbound_rx, + connected_rx, + server_task: Some(server_task), + next_message_id: AtomicU64::new(1), + } + } + + pub(super) fn socket_path(&self) -> PathBuf { + self.socket_path.clone() + } + + pub(super) async fn wait_connected(&mut self) { + timeout(TEST_TIMEOUT, &mut self.connected_rx) + .await + .expect("timed out waiting for mock gateway connection") + .expect("mock gateway connection notifier dropped"); + } + + pub(super) fn send_config_request(&self) { + let request = ConfigurationRequest { + hostname: "mock-gateway".to_string(), + ..Default::default() + }; + self.send_request(CoreRequest { + id: self.next_message_id.fetch_add(1, Ordering::Relaxed), + payload: Some(core_request::Payload::ConfigRequest(request)), + }); + } + + pub(super) fn send_peer_stats(&self, peer_stats: PeerStats) { + self.send_request(CoreRequest { + id: self.next_message_id.fetch_add(1, Ordering::Relaxed), + payload: Some(core_request::Payload::PeerStats(peer_stats)), + }); + } + + pub(super) fn send_stream_error(&self, status: Status) { + self.inbound_tx + .as_ref() + .expect("mock gateway inbound channel already closed") + .send(Err(status)) + .expect("failed to inject inbound stream error"); + } + + fn send_request(&self, request: CoreRequest) { + self.inbound_tx + .as_ref() + .expect("mock gateway inbound channel already closed") + .send(Ok(request)) + .expect("failed to inject mock gateway request"); + } + + pub(super) fn close_stream(&mut self) { + self.inbound_tx.take(); + } + + pub(super) async fn recv_outbound(&mut self) -> CoreResponse { + timeout(TEST_TIMEOUT, self.outbound_rx.recv()) + .await + .expect("timed out waiting for outbound response") + .expect("mock gateway outbound response channel closed unexpectedly") + } + + pub(super) async fn expect_no_outbound(&mut self) { + if let Ok(Some(_message)) = timeout(Duration::from_millis(200), self.outbound_rx.recv()).await { + panic!("unexpected outbound response"); + } + } + + pub(super) async fn expect_server_finished(mut self) { + let server_task = assert_some!( + self.server_task.take(), + "mock gateway server task already taken" + ); + let server_result = timeout(TEST_TIMEOUT, server_task) + .await + .expect("timed out waiting for mock gateway server to finish") + .expect("mock gateway server task panicked"); + server_result.expect("mock gateway server exited with error"); + } +} + +impl Drop for MockGatewayHarness { + fn drop(&mut self) { + if let Some(server_task) = self.server_task.take() { + server_task.abort(); + } + let _ = std::fs::remove_file(&self.socket_path); + } +} + +pub(super) struct HandlerTestContext { + pub(super) pool: PgPool, + pub(super) network: WireguardNetwork, + pub(super) gateway: Gateway, + pub(super) peer_stats_rx: UnboundedReceiver, + events_tx: Option>, + pub(super) mock_gateway: Option, + handler_task: Option>>, +} + +impl HandlerTestContext { + pub(super) async fn new(options: PgConnectOptions) -> Self { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let gateway = create_gateway(&pool, network.id).await; + let (events_tx, _) = broadcast::channel(16); + let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); + let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); + let mut mock_gateway = MockGatewayHarness::start().await; + let mut handler = GatewayHandler::new_with_test_socket( + gateway.clone(), + pool.clone(), + events_tx.clone(), + peer_stats_tx, + certs_rx, + mock_gateway.socket_path(), + ) + .expect("failed to create gateway handler"); + let clients = Arc::>>::default(); + let handler_task = tokio::spawn(async move { handler.handle_connection_once(clients).await }); + + mock_gateway.wait_connected().await; + + Self { + pool, + network, + gateway, + peer_stats_rx, + events_tx: Some(events_tx), + mock_gateway: Some(mock_gateway), + handler_task: Some(handler_task), + } + } + + pub(super) fn events_tx(&self) -> &broadcast::Sender { + self.events_tx + .as_ref() + .expect("events sender already taken from context") + } + + pub(super) fn mock_gateway(&self) -> &MockGatewayHarness { + self.mock_gateway + .as_ref() + .expect("mock gateway already taken from context") + } + + pub(super) fn mock_gateway_mut(&mut self) -> &mut MockGatewayHarness { + self.mock_gateway + .as_mut() + .expect("mock gateway already taken from context") + } + + pub(super) async fn reload_gateway(&self) -> Gateway { + Gateway::find_by_id(&self.pool, self.gateway.id) + .await + .expect("failed to query gateway from database") + .expect("expected gateway in database") + } + + pub(super) async fn create_other_network(&self) -> WireguardNetwork { + create_network(&self.pool).await + } + + pub(super) async fn expect_no_peer_stats(&mut self) { + if let Ok(Some(message)) = timeout(Duration::from_millis(200), self.peer_stats_rx.recv()).await { + panic!("unexpected peer stats update: {message:?}"); + } + } + + pub(super) async fn recv_peer_stats(&mut self) -> PeerStatsUpdate { + timeout(TEST_TIMEOUT, self.peer_stats_rx.recv()) + .await + .expect("timed out waiting for peer stats update") + .expect("peer stats channel unexpectedly closed") + } + + pub(super) async fn complete_config_handshake(&mut self) -> Gateway { + self.mock_gateway().send_config_request(); + let _ = self.mock_gateway_mut().recv_outbound().await; + let connected_gateway = wait_for_gateway_connection_state( + &self.pool, + self.gateway.id, + true, + ) + .await; + timeout(TEST_TIMEOUT, async { + while self.events_tx().receiver_count() == 0 { + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("timed out waiting for gateway updates handler subscription"); + tokio::task::yield_now().await; + connected_gateway + } + + pub(super) async fn finish(mut self) -> MockGatewayHarness { + let mut mock_gateway = + assert_some!(self.mock_gateway.take(), "mock gateway already taken from context"); + mock_gateway.close_stream(); + let handler_task = self + .handler_task + .as_mut() + .expect("handler task already taken from context"); + let result = timeout(TEST_TIMEOUT, handler_task) + .await + .expect("timed out waiting for handler task to finish") + .expect("gateway handler task panicked"); + result.expect("gateway handler returned an unexpected error"); + self.handler_task.take(); + self.events_tx.take(); + mock_gateway + } + + pub(super) async fn finish_after_error(mut self) -> MockGatewayHarness { + let mock_gateway = + assert_some!(self.mock_gateway.take(), "mock gateway already taken from context"); + let handler_task = self + .handler_task + .as_mut() + .expect("handler task already taken from context"); + let result = timeout(TEST_TIMEOUT, handler_task) + .await + .expect("timed out waiting for handler task to finish after stream error") + .expect("gateway handler task panicked after stream error"); + result.expect("gateway handler returned an unexpected error after stream error"); + self.handler_task.take(); + self.events_tx.take(); + mock_gateway + } +} + +impl Drop for HandlerTestContext { + fn drop(&mut self) { + if let Some(handler_task) = self.handler_task.take() { + handler_task.abort(); + } + } +} + +pub(super) async fn reload_gateway(pool: &PgPool, gateway_id: Id) -> Gateway { + Gateway::find_by_id(pool, gateway_id) + .await + .expect("failed to query gateway from database") + .expect("expected gateway in database") +} + +pub(super) async fn wait_for_gateway_connection_state( + pool: &PgPool, + gateway_id: Id, + expected_connected: bool, +) -> Gateway { + timeout(TEST_TIMEOUT, async { + loop { + let gateway = reload_gateway(pool, gateway_id).await; + if gateway.is_connected() == expected_connected { + return gateway; + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("timed out waiting for gateway connection state change") +} + +pub(super) fn build_peer_stats(endpoint: &str) -> PeerStats { + PeerStats { + public_key: "peer-public-key".to_string(), + endpoint: endpoint.to_string(), + upload: 123, + download: 456, + keepalive_interval: 25, + latest_handshake: 1_700_000_000, + allowed_ips: "10.10.0.2/32".to_string(), + } +} + +async fn create_network(pool: &PgPool) -> WireguardNetwork { + let mut network = WireguardNetwork { + name: unique_name("network"), + endpoint: "198.51.100.10".to_string(), + port: 51820, + ..Default::default() + }; + network + .try_set_address("10.10.0.1/24") + .expect("failed to set network address"); + network.save(pool).await.expect("failed to create test network") +} + +async fn create_gateway(pool: &PgPool, location_id: Id) -> Gateway { + Gateway::new( + location_id, + unique_name("gateway"), + "127.0.0.1".to_string(), + 51820, + "test-admin".to_string(), + ) + .save(pool) + .await + .expect("failed to create test gateway") +} From 501293d09f7c29d134abfa0af21241b553ba9562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 16:51:48 +0100 Subject: [PATCH 03/36] formatting --- .../tests/integration/grpc/common/mod.rs | 20 +++++--- .../tests/integration/grpc/health.rs | 4 +- .../tests/integration/grpc/worker.rs | 51 +++++++++++-------- .../src/tests/handler.rs | 8 +-- .../src/tests/support.rs | 43 +++++++++------- 5 files changed, 69 insertions(+), 57 deletions(-) diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index ee3fe24645..3d477cfdd1 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -15,13 +15,19 @@ use defguard_core::{ grpc::{AUTHORIZATION_HEADER, WorkerState, build_grpc_service_router}, }; use hyper_util::rt::TokioIo; -use sqlx::{PgPool, postgres::{PgConnectOptions, PgPoolOptions}}; +use sqlx::{ + PgPool, + postgres::{PgConnectOptions, PgPoolOptions}, +}; use tokio::{ io::DuplexStream, sync::mpsc::{UnboundedReceiver, unbounded_channel}, task::JoinHandle, }; -use tonic::{Request, transport::{Channel, Endpoint, Server, Uri, server::Router}}; +use tonic::{ + Request, + transport::{Channel, Endpoint, Server, Uri, server::Router}, +}; use tower::service_fn; use crate::common::initialize_users; @@ -163,11 +169,9 @@ pub(crate) fn worker_request(message: T, username: &str) -> Request { } fn initialize_jwt_secrets() { - JWT_SECRETS.call_once(|| { - unsafe { - env::set_var(AUTH_SECRET_ENV, "defguard-test-auth-secret"); - env::set_var(GATEWAY_SECRET_ENV, "defguard-test-gateway-secret"); - env::set_var(YUBIBRIDGE_SECRET_ENV, "defguard-test-yubibridge-secret"); - } + JWT_SECRETS.call_once(|| unsafe { + env::set_var(AUTH_SECRET_ENV, "defguard-test-auth-secret"); + env::set_var(GATEWAY_SECRET_ENV, "defguard-test-gateway-secret"); + env::set_var(YUBIBRIDGE_SECRET_ENV, "defguard-test-yubibridge-secret"); }); } diff --git a/crates/defguard_core/tests/integration/grpc/health.rs b/crates/defguard_core/tests/integration/grpc/health.rs index 975ccb8319..c2e1e5a348 100644 --- a/crates/defguard_core/tests/integration/grpc/health.rs +++ b/crates/defguard_core/tests/integration/grpc/health.rs @@ -1,8 +1,6 @@ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic_health::pb::{ - HealthCheckRequest, - health_check_response::ServingStatus, - health_client::HealthClient, + HealthCheckRequest, health_check_response::ServingStatus, health_client::HealthClient, }; use super::common::{make_grpc_test_server, setup_grpc_pool}; diff --git a/crates/defguard_core/tests/integration/grpc/worker.rs b/crates/defguard_core/tests/integration/grpc/worker.rs index 90856d2d9e..0a319c71b4 100644 --- a/crates/defguard_core/tests/integration/grpc/worker.rs +++ b/crates/defguard_core/tests/integration/grpc/worker.rs @@ -1,10 +1,7 @@ use claims::assert_matches; use defguard_common::db::models::{AuthenticationKey, AuthenticationKeyType, User, YubiKey}; -use defguard_proto::worker::{ - JobStatus, Worker, - worker_service_client::WorkerServiceClient, -}; use defguard_core::db::AppEvent; +use defguard_proto::worker::{JobStatus, Worker, worker_service_client::WorkerServiceClient}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::mpsc::error::TryRecvError; use tonic::Code; @@ -196,7 +193,11 @@ async fn set_job_done_success_removes_job_and_stores_status( assert!(status.success); assert_eq!(status.serial, "yk-serial-1"); assert_eq!(status.error, ""); - assert!(state.get_job("worker-1", std::net::IpAddr::from([127, 0, 0, 1])).is_none()); + assert!( + state + .get_job("worker-1", std::net::IpAddr::from([127, 0, 0, 1])) + .is_none() + ); } let user = User::find_by_username(&pool, "hpotter") @@ -299,14 +300,18 @@ async fn set_job_done_failure_stores_status_without_keys_or_event( .await .expect("user query should succeed") .expect("user should exist"); - assert!(YubiKey::find_by_user_id(&pool, user.id) - .await - .expect("yubikey query should succeed") - .is_empty()); - assert!(AuthenticationKey::find_by_user_id(&pool, user.id, None) - .await - .expect("auth key query should succeed") - .is_empty()); + assert!( + YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed") + .is_empty() + ); + assert!( + AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed") + .is_empty() + ); assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); } @@ -351,14 +356,18 @@ async fn set_job_done_unknown_job_is_ignored(_: PgPoolOptions, options: PgConnec .await .expect("user query should succeed") .expect("user should exist"); - assert!(YubiKey::find_by_user_id(&pool, user.id) - .await - .expect("yubikey query should succeed") - .is_empty()); - assert!(AuthenticationKey::find_by_user_id(&pool, user.id, None) - .await - .expect("auth key query should succeed") - .is_empty()); + assert!( + YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed") + .is_empty() + ); + assert!( + AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed") + .is_empty() + ); assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); } diff --git a/crates/defguard_gateway_manager/src/tests/handler.rs b/crates/defguard_gateway_manager/src/tests/handler.rs index 072d9f7dff..35c4cd546a 100644 --- a/crates/defguard_gateway_manager/src/tests/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/handler.rs @@ -1,6 +1,5 @@ use defguard_proto::gateway::{ - CoreResponse, Update, UpdateType, - core_response, + CoreResponse, Update, UpdateType, core_response, update::{self}, }; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; @@ -99,10 +98,7 @@ async fn test_ignores_peer_stats_before_config_handshake( } #[sqlx::test] -async fn test_forwards_valid_peer_stats_after_config( - _: PgPoolOptions, - options: PgConnectOptions, -) { +async fn test_forwards_valid_peer_stats_after_config(_: PgPoolOptions, options: PgConnectOptions) { let mut context = HandlerTestContext::new(options).await; context.mock_gateway().send_config_request(); diff --git a/crates/defguard_gateway_manager/src/tests/support.rs b/crates/defguard_gateway_manager/src/tests/support.rs index 06178722f0..8ca5513a33 100644 --- a/crates/defguard_gateway_manager/src/tests/support.rs +++ b/crates/defguard_gateway_manager/src/tests/support.rs @@ -19,8 +19,7 @@ use defguard_common::{ }; use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::{ - ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, - core_request, gateway_server, + ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, core_request, gateway_server, }; use sqlx::{PgPool, postgres::PgConnectOptions}; use tokio::{ @@ -90,9 +89,7 @@ impl MockGatewayState { } } - fn take_inbound_rx( - &self, - ) -> Result>, Status> { + fn take_inbound_rx(&self) -> Result>, Status> { self.inbound_rx .lock() .expect("failed to lock inbound receiver") @@ -233,7 +230,9 @@ impl MockGatewayHarness { } pub(super) async fn expect_no_outbound(&mut self) { - if let Ok(Some(_message)) = timeout(Duration::from_millis(200), self.outbound_rx.recv()).await { + if let Ok(Some(_message)) = + timeout(Duration::from_millis(200), self.outbound_rx.recv()).await + { panic!("unexpected outbound response"); } } @@ -289,7 +288,8 @@ impl HandlerTestContext { ) .expect("failed to create gateway handler"); let clients = Arc::>>::default(); - let handler_task = tokio::spawn(async move { handler.handle_connection_once(clients).await }); + let handler_task = + tokio::spawn(async move { handler.handle_connection_once(clients).await }); mock_gateway.wait_connected().await; @@ -334,7 +334,9 @@ impl HandlerTestContext { } pub(super) async fn expect_no_peer_stats(&mut self) { - if let Ok(Some(message)) = timeout(Duration::from_millis(200), self.peer_stats_rx.recv()).await { + if let Ok(Some(message)) = + timeout(Duration::from_millis(200), self.peer_stats_rx.recv()).await + { panic!("unexpected peer stats update: {message:?}"); } } @@ -349,12 +351,8 @@ impl HandlerTestContext { pub(super) async fn complete_config_handshake(&mut self) -> Gateway { self.mock_gateway().send_config_request(); let _ = self.mock_gateway_mut().recv_outbound().await; - let connected_gateway = wait_for_gateway_connection_state( - &self.pool, - self.gateway.id, - true, - ) - .await; + let connected_gateway = + wait_for_gateway_connection_state(&self.pool, self.gateway.id, true).await; timeout(TEST_TIMEOUT, async { while self.events_tx().receiver_count() == 0 { tokio::time::sleep(Duration::from_millis(20)).await; @@ -367,8 +365,10 @@ impl HandlerTestContext { } pub(super) async fn finish(mut self) -> MockGatewayHarness { - let mut mock_gateway = - assert_some!(self.mock_gateway.take(), "mock gateway already taken from context"); + let mut mock_gateway = assert_some!( + self.mock_gateway.take(), + "mock gateway already taken from context" + ); mock_gateway.close_stream(); let handler_task = self .handler_task @@ -385,8 +385,10 @@ impl HandlerTestContext { } pub(super) async fn finish_after_error(mut self) -> MockGatewayHarness { - let mock_gateway = - assert_some!(self.mock_gateway.take(), "mock gateway already taken from context"); + let mock_gateway = assert_some!( + self.mock_gateway.take(), + "mock gateway already taken from context" + ); let handler_task = self .handler_task .as_mut() @@ -458,7 +460,10 @@ async fn create_network(pool: &PgPool) -> WireguardNetwork { network .try_set_address("10.10.0.1/24") .expect("failed to set network address"); - network.save(pool).await.expect("failed to create test network") + network + .save(pool) + .await + .expect("failed to create test network") } async fn create_gateway(pool: &PgPool, location_id: Id) -> Gateway { From 618176a680d7b8b4b20f115dafed5a9ace0c80a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 16:52:24 +0100 Subject: [PATCH 04/36] fix test layout --- crates/defguard_gateway_manager/src/lib.rs | 1 - crates/defguard_gateway_manager/src/tests.rs | 108 ------------------- 2 files changed, 109 deletions(-) delete mode 100644 crates/defguard_gateway_manager/src/tests.rs diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index d0eb05ec6e..27ed4a61a8 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -28,7 +28,6 @@ mod error; mod handler; #[cfg(test)] -#[path = "tests/mod.rs"] mod tests; const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; diff --git a/crates/defguard_gateway_manager/src/tests.rs b/crates/defguard_gateway_manager/src/tests.rs deleted file mode 100644 index 2257afcba0..0000000000 --- a/crates/defguard_gateway_manager/src/tests.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::{ - io, - net::{IpAddr, Ipv4Addr}, - sync::{Arc, Mutex}, -}; - -use ipnetwork::IpNetwork; -use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::{ - net::UnixListener, - sync::{broadcast, mpsc::unbounded_channel}, -}; -use tokio_stream::wrappers::{UnboundedReceiverStream, UnixListenerStream}; -use tonic::{Request, Response, Status, Streaming, transport::Server}; - -use defguard_common::db::{ - models::{ - gateway::Gateway, - wireguard::{LocationMfaMode, ServiceLocationMode, WireguardNetwork}, - }, - setup_pool, -}; -use defguard_mail::Mail; -use defguard_proto::gateway::{CoreRequest, CoreResponse, gateway_server}; - -use super::{TONIC_SOCKET, handler::GatewayHandler}; -use crate::grpc::{ClientMap, GrpcEvent, gateway::events::GatewayEvent}; - -// TODO: move to "gateway" repo. -struct FakeGateway; - -#[tonic::async_trait] -impl gateway_server::Gateway for FakeGateway { - type BidiStream = UnboundedReceiverStream>; - - async fn bidi( - &self, - request: Request>, - ) -> Result, Status> { - let (_tx, rx) = unbounded_channel(); - let mut stream = request.into_inner(); - tokio::spawn(async move { - loop { - match stream.message().await { - Ok(Some(_response)) => (), - Ok(None) => (), - Err(_err) => (), - } - } - }); - - Ok(Response::new(UnboundedReceiverStream::new(rx))) - } -} - -async fn fake_gateway() -> Result<(), io::Error> { - let gateway = FakeGateway {}; - - let uds = UnixListener::bind(TONIC_SOCKET)?; - let uds_stream = UnixListenerStream::new(uds); - - Server::builder() - .add_service(gateway_server::GatewayServer::new(gateway)) - .serve_with_incoming(uds_stream) - .await - .unwrap(); - - Ok(()) -} - -#[sqlx::test] -async fn test_gateway(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let network = WireguardNetwork::new( - "TestNet".to_string(), - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], - 50051, - "0.0.0.0".to_string(), - None, - 1420, - 0, - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], - false, - 25, - 300, - false, - false, - LocationMfaMode::default(), - ServiceLocationMode::default(), - ) - .save(&pool) - .await - .unwrap(); - let gateway = Gateway::new(network.id, "http://[::]:50051") - .save(&pool) - .await - .unwrap(); - let client_state = Arc::new(Mutex::new(ClientMap::new())); - let (events_tx, _events_rx) = broadcast::channel::(16); - let (grpc_event_tx, _grpc_event_rx) = unbounded_channel::(); - - let mut gateway_handler = - GatewayHandler::new(gateway, None, pool, client_state, events_tx, grpc_event_tx).unwrap(); - let handle = tokio::spawn(async move { - gateway_handler.handle_connection().await; - }); - handle.abort(); -} From d1d29327c4009f346779fd61e746ce4280801b7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 17:17:26 +0100 Subject: [PATCH 05/36] put tests in a separate integration tests module --- crates/defguard_gateway_manager/Cargo.toml | 4 +- .../defguard_gateway_manager/src/handler.rs | 90 ++++++++++--------- crates/defguard_gateway_manager/src/lib.rs | 4 +- .../tests/support.rs => tests/common/mod.rs} | 73 ++++++++------- .../gateway_manager}/handler.rs | 4 +- .../tests => tests/gateway_manager}/mod.rs | 1 - crates/defguard_gateway_manager/tests/mod.rs | 2 + 7 files changed, 90 insertions(+), 88 deletions(-) rename crates/defguard_gateway_manager/{src/tests/support.rs => tests/common/mod.rs} (86%) rename crates/defguard_gateway_manager/{src/tests => tests/gateway_manager}/handler.rs (99%) rename crates/defguard_gateway_manager/{src/tests => tests/gateway_manager}/mod.rs (50%) create mode 100644 crates/defguard_gateway_manager/tests/mod.rs diff --git a/crates/defguard_gateway_manager/Cargo.toml b/crates/defguard_gateway_manager/Cargo.toml index 9afb53828e..fe0d595112 100644 --- a/crates/defguard_gateway_manager/Cargo.toml +++ b/crates/defguard_gateway_manager/Cargo.toml @@ -17,6 +17,7 @@ defguard_version.workspace = true anyhow.workspace = true chrono.workspace = true +hyper-util = "0.1" hyper-rustls.workspace = true reqwest.workspace = true semver.workspace = true @@ -28,6 +29,3 @@ tokio-stream.workspace = true tonic.workspace = true tower.workspace = true tracing.workspace = true - -[dev-dependencies] -hyper-util = "0.1" diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index da921230b7..99ac9366e5 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, net::IpAddr, + path::PathBuf, str::FromStr, sync::{ Arc, Mutex, @@ -8,11 +9,7 @@ use std::{ }, }; -#[cfg(test)] -use std::path::{Path, PathBuf}; - use chrono::DateTime; -#[cfg(not(test))] use defguard_common::db::models::Settings; use defguard_common::{ VERSION, @@ -27,7 +24,6 @@ use defguard_core::{ handlers::mail::send_gateway_disconnected_email, location_management::allowed_peers::get_location_allowed_peers, }; -#[cfg(not(test))] use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::{ enterprise::firewall::FirewallConfig, @@ -37,7 +33,6 @@ use defguard_proto::{ }, }; use defguard_version::client::ClientVersionInterceptor; -#[cfg(not(test))] use hyper_rustls::HttpsConnectorBuilder; use reqwest::Url; use semver::Version; @@ -55,13 +50,11 @@ use tonic::{Code, Status, transport::Endpoint}; use crate::{Client, TEN_SECS, error::GatewayError}; -#[cfg(test)] #[derive(Debug, Default)] struct GatewayTestTransport { socket_path: Option, } -#[cfg(test)] impl GatewayTestTransport { fn with_socket_path(socket_path: PathBuf) -> Self { Self { @@ -69,12 +62,8 @@ impl GatewayTestTransport { } } - fn socket_path(&self) -> Result<&Path, GatewayError> { - self.socket_path.as_deref().ok_or_else(|| { - GatewayError::EndpointError( - "Missing test gateway transport socket path for GatewayHandler".to_string(), - ) - }) + fn socket_path(&self) -> Option<&PathBuf> { + self.socket_path.as_ref() } } @@ -88,7 +77,6 @@ pub(super) struct GatewayHandler { events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, - #[cfg(test)] test_transport: GatewayTestTransport, } @@ -115,13 +103,11 @@ impl GatewayHandler { events_tx, peer_stats_tx, certs_rx, - #[cfg(test)] test_transport: GatewayTestTransport::default(), }) } - #[cfg(test)] - pub(super) fn new_with_test_socket( + fn new_with_test_socket( gateway: Gateway, pool: PgPool, events_tx: Sender, @@ -277,13 +263,21 @@ impl GatewayHandler { clients: Arc>>, retry_on_connect_failure: bool, ) -> Result<(), GatewayError> { - #[cfg(test)] - let _ = &self.certs_rx; let endpoint = self.endpoint()?; let uri = endpoint.uri().to_string(); - #[cfg(not(test))] - let channel = { + let channel = if let Some(socket_path) = self.test_transport.socket_path().cloned() { + endpoint.connect_with_connector_lazy(tower::service_fn( + move |_: tonic::transport::Uri| { + let socket_path = socket_path.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(socket_path).await?, + )) + } + }, + )) + } else { let settings = Settings::get_current_settings(); let Some(ca_cert_der) = settings.ca_cert_der else { return Err(GatewayError::EndpointError( @@ -301,20 +295,6 @@ impl GatewayHandler { let connector = HttpsSchemeConnector::new(connector); endpoint.connect_with_connector_lazy(connector) }; - #[cfg(test)] - let channel = { - let socket_path = self.test_transport.socket_path()?.to_path_buf(); - endpoint.connect_with_connector_lazy(tower::service_fn( - move |_: tonic::transport::Uri| { - let socket_path = socket_path.clone(); - async move { - Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(socket_path).await?, - )) - } - }, - )) - }; debug!("Connecting to Gateway {uri}"); let interceptor = ClientVersionInterceptor::new( @@ -452,13 +432,39 @@ impl GatewayHandler { .await?; } } +} - #[cfg(test)] - pub(super) async fn handle_connection_once( - &mut self, - clients: Arc>>, - ) -> Result<(), GatewayError> { - self.handle_connection_iteration(clients, false).await +#[doc(hidden)] +pub struct TestGatewayHandler { + inner: GatewayHandler, +} + +impl TestGatewayHandler { + pub fn new( + gateway: Gateway, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: watch::Receiver>>, + socket_path: PathBuf, + ) -> anyhow::Result { + let inner = GatewayHandler::new_with_test_socket( + gateway, + pool, + events_tx, + peer_stats_tx, + certs_rx, + socket_path, + )?; + Ok(Self { inner }) + } + + pub async fn handle_connection_once(&mut self) -> anyhow::Result<()> { + let clients = Arc::>>::default(); + self.inner + .handle_connection_iteration(clients, false) + .await + .map_err(anyhow::Error::from) } } diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 27ed4a61a8..19ed75c12c 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -27,8 +27,8 @@ mod certs; mod error; mod handler; -#[cfg(test)] -mod tests; +#[doc(hidden)] +pub use handler::TestGatewayHandler; const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); diff --git a/crates/defguard_gateway_manager/src/tests/support.rs b/crates/defguard_gateway_manager/tests/common/mod.rs similarity index 86% rename from crates/defguard_gateway_manager/src/tests/support.rs rename to crates/defguard_gateway_manager/tests/common/mod.rs index 8ca5513a33..01bdbe3b6c 100644 --- a/crates/defguard_gateway_manager/src/tests/support.rs +++ b/crates/defguard_gateway_manager/tests/common/mod.rs @@ -18,6 +18,7 @@ use defguard_common::{ messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::grpc::GatewayEvent; +use defguard_gateway_manager::TestGatewayHandler; use defguard_proto::gateway::{ ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, core_request, gateway_server, }; @@ -35,8 +36,6 @@ use tokio::{ use tokio_stream::{once, wrappers::UnboundedReceiverStream}; use tonic::{Request, Response, Status, Streaming, transport::Server}; -use crate::{Client, error::GatewayError, handler::GatewayHandler}; - const TEST_TIMEOUT: Duration = Duration::from_secs(2); macro_rules! assert_some { @@ -127,7 +126,7 @@ impl gateway_server::Gateway for MockGatewayService { } } -pub(super) struct MockGatewayHarness { +pub(crate) struct MockGatewayHarness { socket_path: PathBuf, inbound_tx: Option>>, outbound_rx: UnboundedReceiver, @@ -137,7 +136,7 @@ pub(super) struct MockGatewayHarness { } impl MockGatewayHarness { - pub(super) async fn start() -> Self { + pub(crate) async fn start() -> Self { let socket_path = unique_socket_path(); let _ = std::fs::remove_file(&socket_path); @@ -173,18 +172,18 @@ impl MockGatewayHarness { } } - pub(super) fn socket_path(&self) -> PathBuf { + pub(crate) fn socket_path(&self) -> PathBuf { self.socket_path.clone() } - pub(super) async fn wait_connected(&mut self) { + pub(crate) async fn wait_connected(&mut self) { timeout(TEST_TIMEOUT, &mut self.connected_rx) .await .expect("timed out waiting for mock gateway connection") .expect("mock gateway connection notifier dropped"); } - pub(super) fn send_config_request(&self) { + pub(crate) fn send_config_request(&self) { let request = ConfigurationRequest { hostname: "mock-gateway".to_string(), ..Default::default() @@ -195,14 +194,14 @@ impl MockGatewayHarness { }); } - pub(super) fn send_peer_stats(&self, peer_stats: PeerStats) { + pub(crate) fn send_peer_stats(&self, peer_stats: PeerStats) { self.send_request(CoreRequest { id: self.next_message_id.fetch_add(1, Ordering::Relaxed), payload: Some(core_request::Payload::PeerStats(peer_stats)), }); } - pub(super) fn send_stream_error(&self, status: Status) { + pub(crate) fn send_stream_error(&self, status: Status) { self.inbound_tx .as_ref() .expect("mock gateway inbound channel already closed") @@ -218,18 +217,18 @@ impl MockGatewayHarness { .expect("failed to inject mock gateway request"); } - pub(super) fn close_stream(&mut self) { + pub(crate) fn close_stream(&mut self) { self.inbound_tx.take(); } - pub(super) async fn recv_outbound(&mut self) -> CoreResponse { + pub(crate) async fn recv_outbound(&mut self) -> CoreResponse { timeout(TEST_TIMEOUT, self.outbound_rx.recv()) .await .expect("timed out waiting for outbound response") .expect("mock gateway outbound response channel closed unexpectedly") } - pub(super) async fn expect_no_outbound(&mut self) { + pub(crate) async fn expect_no_outbound(&mut self) { if let Ok(Some(_message)) = timeout(Duration::from_millis(200), self.outbound_rx.recv()).await { @@ -237,7 +236,7 @@ impl MockGatewayHarness { } } - pub(super) async fn expect_server_finished(mut self) { + pub(crate) async fn expect_server_finished(mut self) { let server_task = assert_some!( self.server_task.take(), "mock gateway server task already taken" @@ -259,18 +258,18 @@ impl Drop for MockGatewayHarness { } } -pub(super) struct HandlerTestContext { - pub(super) pool: PgPool, - pub(super) network: WireguardNetwork, - pub(super) gateway: Gateway, - pub(super) peer_stats_rx: UnboundedReceiver, +pub(crate) struct HandlerTestContext { + pub(crate) pool: PgPool, + pub(crate) network: WireguardNetwork, + pub(crate) gateway: Gateway, + pub(crate) peer_stats_rx: UnboundedReceiver, events_tx: Option>, - pub(super) mock_gateway: Option, - handler_task: Option>>, + pub(crate) mock_gateway: Option, + handler_task: Option>>, } impl HandlerTestContext { - pub(super) async fn new(options: PgConnectOptions) -> Self { + pub(crate) async fn new(options: PgConnectOptions) -> Self { let pool = setup_pool(options).await; let network = create_network(&pool).await; let gateway = create_gateway(&pool, network.id).await; @@ -278,7 +277,7 @@ impl HandlerTestContext { let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); let mut mock_gateway = MockGatewayHarness::start().await; - let mut handler = GatewayHandler::new_with_test_socket( + let mut handler = TestGatewayHandler::new( gateway.clone(), pool.clone(), events_tx.clone(), @@ -287,9 +286,7 @@ impl HandlerTestContext { mock_gateway.socket_path(), ) .expect("failed to create gateway handler"); - let clients = Arc::>>::default(); - let handler_task = - tokio::spawn(async move { handler.handle_connection_once(clients).await }); + let handler_task = tokio::spawn(async move { handler.handle_connection_once().await }); mock_gateway.wait_connected().await; @@ -304,36 +301,36 @@ impl HandlerTestContext { } } - pub(super) fn events_tx(&self) -> &broadcast::Sender { + pub(crate) fn events_tx(&self) -> &broadcast::Sender { self.events_tx .as_ref() .expect("events sender already taken from context") } - pub(super) fn mock_gateway(&self) -> &MockGatewayHarness { + pub(crate) fn mock_gateway(&self) -> &MockGatewayHarness { self.mock_gateway .as_ref() .expect("mock gateway already taken from context") } - pub(super) fn mock_gateway_mut(&mut self) -> &mut MockGatewayHarness { + pub(crate) fn mock_gateway_mut(&mut self) -> &mut MockGatewayHarness { self.mock_gateway .as_mut() .expect("mock gateway already taken from context") } - pub(super) async fn reload_gateway(&self) -> Gateway { + pub(crate) async fn reload_gateway(&self) -> Gateway { Gateway::find_by_id(&self.pool, self.gateway.id) .await .expect("failed to query gateway from database") .expect("expected gateway in database") } - pub(super) async fn create_other_network(&self) -> WireguardNetwork { + pub(crate) async fn create_other_network(&self) -> WireguardNetwork { create_network(&self.pool).await } - pub(super) async fn expect_no_peer_stats(&mut self) { + pub(crate) async fn expect_no_peer_stats(&mut self) { if let Ok(Some(message)) = timeout(Duration::from_millis(200), self.peer_stats_rx.recv()).await { @@ -341,14 +338,14 @@ impl HandlerTestContext { } } - pub(super) async fn recv_peer_stats(&mut self) -> PeerStatsUpdate { + pub(crate) async fn recv_peer_stats(&mut self) -> PeerStatsUpdate { timeout(TEST_TIMEOUT, self.peer_stats_rx.recv()) .await .expect("timed out waiting for peer stats update") .expect("peer stats channel unexpectedly closed") } - pub(super) async fn complete_config_handshake(&mut self) -> Gateway { + pub(crate) async fn complete_config_handshake(&mut self) -> Gateway { self.mock_gateway().send_config_request(); let _ = self.mock_gateway_mut().recv_outbound().await; let connected_gateway = @@ -364,7 +361,7 @@ impl HandlerTestContext { connected_gateway } - pub(super) async fn finish(mut self) -> MockGatewayHarness { + pub(crate) async fn finish(mut self) -> MockGatewayHarness { let mut mock_gateway = assert_some!( self.mock_gateway.take(), "mock gateway already taken from context" @@ -384,7 +381,7 @@ impl HandlerTestContext { mock_gateway } - pub(super) async fn finish_after_error(mut self) -> MockGatewayHarness { + pub(crate) async fn finish_after_error(mut self) -> MockGatewayHarness { let mock_gateway = assert_some!( self.mock_gateway.take(), "mock gateway already taken from context" @@ -412,14 +409,14 @@ impl Drop for HandlerTestContext { } } -pub(super) async fn reload_gateway(pool: &PgPool, gateway_id: Id) -> Gateway { +pub(crate) async fn reload_gateway(pool: &PgPool, gateway_id: Id) -> Gateway { Gateway::find_by_id(pool, gateway_id) .await .expect("failed to query gateway from database") .expect("expected gateway in database") } -pub(super) async fn wait_for_gateway_connection_state( +pub(crate) async fn wait_for_gateway_connection_state( pool: &PgPool, gateway_id: Id, expected_connected: bool, @@ -438,7 +435,7 @@ pub(super) async fn wait_for_gateway_connection_state( .expect("timed out waiting for gateway connection state change") } -pub(super) fn build_peer_stats(endpoint: &str) -> PeerStats { +pub(crate) fn build_peer_stats(endpoint: &str) -> PeerStats { PeerStats { public_key: "peer-public-key".to_string(), endpoint: endpoint.to_string(), diff --git a/crates/defguard_gateway_manager/src/tests/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs similarity index 99% rename from crates/defguard_gateway_manager/src/tests/handler.rs rename to crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index 35c4cd546a..e8ad9e1d9f 100644 --- a/crates/defguard_gateway_manager/src/tests/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -1,3 +1,4 @@ +use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::{ CoreResponse, Update, UpdateType, core_response, update::{self}, @@ -5,8 +6,7 @@ use defguard_proto::gateway::{ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; -use super::support::{HandlerTestContext, build_peer_stats, reload_gateway}; -use defguard_core::grpc::GatewayEvent; +use crate::common::{HandlerTestContext, build_peer_stats, reload_gateway}; macro_rules! assert_send_ok { ($result:expr, $message:literal) => { diff --git a/crates/defguard_gateway_manager/src/tests/mod.rs b/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs similarity index 50% rename from crates/defguard_gateway_manager/src/tests/mod.rs rename to crates/defguard_gateway_manager/tests/gateway_manager/mod.rs index f763168bef..427a476918 100644 --- a/crates/defguard_gateway_manager/src/tests/mod.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs @@ -1,2 +1 @@ mod handler; -mod support; diff --git a/crates/defguard_gateway_manager/tests/mod.rs b/crates/defguard_gateway_manager/tests/mod.rs new file mode 100644 index 0000000000..c49ae37809 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod common; +pub(crate) mod gateway_manager; From 306ef096fecef5f674c72a7784a74fb700a096ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 16 Mar 2026 19:28:10 +0100 Subject: [PATCH 06/36] test update message routing --- .../tests/common/mod.rs | 12 +- .../tests/gateway_manager/handler.rs | 119 ++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/crates/defguard_gateway_manager/tests/common/mod.rs b/crates/defguard_gateway_manager/tests/common/mod.rs index 01bdbe3b6c..81bdb4195d 100644 --- a/crates/defguard_gateway_manager/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/tests/common/mod.rs @@ -270,10 +270,17 @@ pub(crate) struct HandlerTestContext { impl HandlerTestContext { pub(crate) async fn new(options: PgConnectOptions) -> Self { + let (events_tx, _) = broadcast::channel(16); + Self::new_with_events_tx(options, events_tx).await + } + + pub(crate) async fn new_with_events_tx( + options: PgConnectOptions, + events_tx: broadcast::Sender, + ) -> Self { let pool = setup_pool(options).await; let network = create_network(&pool).await; let gateway = create_gateway(&pool, network.id).await; - let (events_tx, _) = broadcast::channel(16); let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); let mut mock_gateway = MockGatewayHarness::start().await; @@ -346,12 +353,13 @@ impl HandlerTestContext { } pub(crate) async fn complete_config_handshake(&mut self) -> Gateway { + let initial_event_receivers = self.events_tx().receiver_count(); self.mock_gateway().send_config_request(); let _ = self.mock_gateway_mut().recv_outbound().await; let connected_gateway = wait_for_gateway_connection_state(&self.pool, self.gateway.id, true).await; timeout(TEST_TIMEOUT, async { - while self.events_tx().receiver_count() == 0 { + while self.events_tx().receiver_count() <= initial_event_receivers { tokio::time::sleep(Duration::from_millis(20)).await; } }) diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index e8ad9e1d9f..28584140b5 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -162,6 +162,99 @@ async fn test_matching_location_network_event_produces_outbound_update( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_matching_location_network_modified_event_produces_modify_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + let mut modified_network = context.network.clone(); + modified_network.name = format!("{}-modified", context.network.name); + modified_network.address = vec!["10.20.0.1/24" + .parse() + .expect("failed to parse modified network address")]; + modified_network.port = 51821; + modified_network.mtu = 1380; + modified_network.fwmark = 42; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkModified( + context.network.id, + modified_network, + Vec::new(), + None, + )), + "failed to broadcast modified gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_modify_update( + outbound, + &format!("{}-modified", context.network.name), + "10.20.0.1/24", + 51821, + 1380, + 42, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_only_matching_handler_receives_network_modified_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let (events_tx, _) = tokio::sync::broadcast::channel(16); + let mut matching_context = HandlerTestContext::new_with_events_tx(options.clone(), events_tx.clone()).await; + let mut unrelated_context = HandlerTestContext::new_with_events_tx(options, events_tx).await; + + assert_ne!(matching_context.network.id, unrelated_context.network.id); + + let _ = matching_context.complete_config_handshake().await; + let _ = unrelated_context.complete_config_handshake().await; + + let mut modified_network = matching_context.network.clone(); + modified_network.name = format!("{}-modified", matching_context.network.name); + modified_network.address = vec!["10.30.0.1/24" + .parse() + .expect("failed to parse modified network address")]; + modified_network.port = 51831; + modified_network.mtu = 1400; + modified_network.fwmark = 7; + + assert_send_ok!( + matching_context + .events_tx() + .send(GatewayEvent::NetworkModified( + matching_context.network.id, + modified_network, + Vec::new(), + None, + )), + "failed to broadcast modified gateway event" + ); + + let outbound = matching_context.mock_gateway_mut().recv_outbound().await; + assert_network_modify_update( + outbound, + &format!("{}-modified", matching_context.network.name), + "10.30.0.1/24", + 51831, + 1400, + 7, + ); + matching_context.mock_gateway_mut().expect_no_outbound().await; + unrelated_context.mock_gateway_mut().expect_no_outbound().await; + + matching_context.finish().await.expect_server_finished().await; + unrelated_context.finish().await.expect_server_finished().await; +} + #[sqlx::test] async fn test_different_location_network_event_is_ignored( _: PgPoolOptions, @@ -268,3 +361,29 @@ fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: & _ => panic_unexpected!("expected network delete update"), } } + +fn assert_network_modify_update( + outbound: CoreResponse, + expected_network_name: &str, + expected_address: &str, + expected_port: u32, + expected_mtu: u32, + expected_fwmark: u32, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Modify as i32); + assert_eq!(network.name, expected_network_name); + assert_eq!(network.addresses, vec![expected_address.to_string()]); + assert_eq!(network.port, expected_port); + assert_eq!(network.peers, Vec::new()); + assert_eq!(network.firewall_config, None); + assert_eq!(network.mtu, expected_mtu); + assert_eq!(network.fwmark, expected_fwmark); + } + _ => panic_unexpected!("expected network modify update"), + } +} From f7031c8727371ea7e09aadac5177f430d182d6ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 09:14:28 +0100 Subject: [PATCH 07/36] add more network event tests --- .../tests/gateway_manager/handler.rs | 136 ++++++++++++++++-- 1 file changed, 123 insertions(+), 13 deletions(-) diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index 28584140b5..1aa95e287c 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -140,7 +140,7 @@ async fn test_drops_malformed_or_missing_endpoint_peer_stats( } #[sqlx::test] -async fn test_matching_location_network_event_produces_outbound_update( +async fn test_matching_location_network_deleted_event_produces_delete_update( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -173,9 +173,11 @@ async fn test_matching_location_network_modified_event_produces_modify_update( let mut modified_network = context.network.clone(); modified_network.name = format!("{}-modified", context.network.name); - modified_network.address = vec!["10.20.0.1/24" - .parse() - .expect("failed to parse modified network address")]; + modified_network.address = vec![ + "10.20.0.1/24" + .parse() + .expect("failed to parse modified network address"), + ]; modified_network.port = 51821; modified_network.mtu = 1380; modified_network.fwmark = 42; @@ -204,13 +206,56 @@ async fn test_matching_location_network_modified_event_produces_modify_update( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_matching_location_network_created_event_produces_create_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + let mut created_network = context.network.clone(); + created_network.name = format!("{}-created", context.network.name); + created_network.address = vec![ + "10.40.0.1/24" + .parse() + .expect("failed to parse created network address"), + ]; + created_network.port = 51841; + created_network.mtu = 1410; + created_network.fwmark = 17; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkCreated( + context.network.id, + created_network, + )), + "failed to broadcast created gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_create_update( + outbound, + &format!("{}-created", context.network.name), + "10.40.0.1/24", + 51841, + 1410, + 17, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + #[sqlx::test] async fn test_only_matching_handler_receives_network_modified_update( _: PgPoolOptions, options: PgConnectOptions, ) { let (events_tx, _) = tokio::sync::broadcast::channel(16); - let mut matching_context = HandlerTestContext::new_with_events_tx(options.clone(), events_tx.clone()).await; + let mut matching_context = + HandlerTestContext::new_with_events_tx(options.clone(), events_tx.clone()).await; let mut unrelated_context = HandlerTestContext::new_with_events_tx(options, events_tx).await; assert_ne!(matching_context.network.id, unrelated_context.network.id); @@ -220,9 +265,11 @@ async fn test_only_matching_handler_receives_network_modified_update( let mut modified_network = matching_context.network.clone(); modified_network.name = format!("{}-modified", matching_context.network.name); - modified_network.address = vec!["10.30.0.1/24" - .parse() - .expect("failed to parse modified network address")]; + modified_network.address = vec![ + "10.30.0.1/24" + .parse() + .expect("failed to parse modified network address"), + ]; modified_network.port = 51831; modified_network.mtu = 1400; modified_network.fwmark = 7; @@ -248,15 +295,52 @@ async fn test_only_matching_handler_receives_network_modified_update( 1400, 7, ); - matching_context.mock_gateway_mut().expect_no_outbound().await; - unrelated_context.mock_gateway_mut().expect_no_outbound().await; + matching_context + .mock_gateway_mut() + .expect_no_outbound() + .await; + unrelated_context + .mock_gateway_mut() + .expect_no_outbound() + .await; + + matching_context + .finish() + .await + .expect_server_finished() + .await; + unrelated_context + .finish() + .await + .expect_server_finished() + .await; +} + +#[sqlx::test] +async fn test_different_location_network_created_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); - matching_context.finish().await.expect_server_finished().await; - unrelated_context.finish().await.expect_server_finished().await; + let _ = context.complete_config_handshake().await; + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkCreated( + other_network.id, + other_network, + )), + "failed to broadcast unrelated created gateway event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; } #[sqlx::test] -async fn test_different_location_network_event_is_ignored( +async fn test_different_location_network_deleted_event_is_ignored( _: PgPoolOptions, options: PgConnectOptions, ) { @@ -362,6 +446,32 @@ fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: & } } +fn assert_network_create_update( + outbound: CoreResponse, + expected_network_name: &str, + expected_address: &str, + expected_port: u32, + expected_mtu: u32, + expected_fwmark: u32, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Create as i32); + assert_eq!(network.name, expected_network_name); + assert_eq!(network.addresses, vec![expected_address.to_string()]); + assert_eq!(network.port, expected_port); + assert_eq!(network.peers, Vec::new()); + assert_eq!(network.firewall_config, None); + assert_eq!(network.mtu, expected_mtu); + assert_eq!(network.fwmark, expected_fwmark); + } + _ => panic_unexpected!("expected network create update"), + } +} + fn assert_network_modify_update( outbound: CoreResponse, expected_network_name: &str, From 332ee2719bff0938eac42a706a3548ff05849da1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 09:36:38 +0100 Subject: [PATCH 08/36] add peer lifecycle tests --- .../tests/gateway_manager/handler.rs | 241 ++++++++++++++++++ 1 file changed, 241 insertions(+) diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index 1aa95e287c..f53792492b 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -1,3 +1,12 @@ +use std::net::IpAddr; + +use defguard_common::db::{ + Id, + models::{ + device::{Device, DeviceInfo, DeviceType, WireguardNetworkDevice}, + user::User, + }, +}; use defguard_core::grpc::GatewayEvent; use defguard_proto::gateway::{ CoreResponse, Update, UpdateType, core_response, @@ -139,6 +148,136 @@ async fn test_drops_malformed_or_missing_endpoint_peer_stats( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_device_created_for_network_produces_peer_create_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_current_network( + &context, + "created-peer-device", + "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", + "10.10.0.10", + Some("created-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceCreated(device_info)), + "failed to broadcast created device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", + &["10.10.0.10"], + Some("created-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_modified_for_network_produces_peer_modify_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + + let _ = context.complete_config_handshake().await; + let device = create_device_for_current_network( + &context, + "modified-peer-device", + "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", + "10.10.0.20", + Some("initial-preshared-key"), + ) + .await; + + let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) + .await + .expect("failed to load device network info") + .expect("expected device network info for modified device"); + network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; + network_device.preshared_key = Some("modified-preshared-key".to_string()); + network_device + .update(&context.pool) + .await + .expect("failed to update device network info"); + let device_info = DeviceInfo::from_device(&context.pool, device) + .await + .expect("failed to load modified device info"); + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceModified(device_info)), + "failed to broadcast modified device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Modify, + "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", + &["10.10.0.21"], + Some("modified-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_deleted_for_network_produces_peer_delete_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_current_network( + &context, + "deleted-peer-device", + "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", + "10.10.0.30", + Some("deleted-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceDeleted(device_info)), + "failed to broadcast deleted device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Delete, + "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", + &[], + None, + None, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + #[sqlx::test] async fn test_matching_location_network_deleted_event_produces_delete_update( _: PgPoolOptions, @@ -433,6 +572,108 @@ async fn test_gateway_is_marked_disconnected_when_stream_errors( mock_gateway.expect_server_finished().await; } +async fn create_device_info_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> DeviceInfo { + let device = create_device_for_current_network( + context, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + + DeviceInfo::from_device(&context.pool, device) + .await + .expect("failed to load device info") +} + +async fn create_device_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> Device { + let username = format!("{device_name}-user"); + let email = format!("{device_name}@example.com"); + let user = User::new( + username, + Some("pass123"), + "Peer".to_string(), + "Test".to_string(), + email, + None, + ) + .save(&context.pool) + .await + .expect("failed to create test user"); + let device = Device::new( + device_name.to_string(), + device_pubkey.to_string(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&context.pool) + .await + .expect("failed to create test device"); + + let mut network_device = + WireguardNetworkDevice::new(context.network.id, device.id, vec![parse_test_ip(device_ip)]); + network_device.preshared_key = preshared_key.map(str::to_owned); + network_device + .insert(&context.pool) + .await + .expect("failed to attach device to network"); + + device +} + +fn expected_keepalive_interval(context: &HandlerTestContext) -> u32 { + u32::try_from(context.network.keepalive_interval) + .expect("expected non-negative network keepalive interval") +} + +fn parse_test_ip(ip: &str) -> IpAddr { + ip.parse().expect("failed to parse test peer IP address") +} + +fn assert_peer_update( + outbound: CoreResponse, + expected_update_type: UpdateType, + expected_pubkey: &str, + expected_allowed_ips: &[&str], + expected_preshared_key: Option<&str>, + expected_keepalive_interval: Option, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Peer(peer)), + })) => { + assert_eq!(update_type, expected_update_type as i32); + assert_eq!(peer.pubkey, expected_pubkey); + assert_eq!( + peer.allowed_ips, + expected_allowed_ips + .iter() + .map(|allowed_ip| allowed_ip.to_string()) + .collect::>() + ); + assert_eq!(peer.preshared_key.as_deref(), expected_preshared_key); + assert_eq!(peer.keepalive_interval, expected_keepalive_interval); + } + _ => panic_unexpected!("expected peer update"), + } +} + fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: &str) { match outbound.payload { Some(core_response::Payload::Update(Update { From bdefa52fc315c0c97af64256ab4ddd46e97ee634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 11:52:23 +0100 Subject: [PATCH 09/36] test ignored events --- .../tests/common/mod.rs | 7 +- .../tests/gateway_manager/handler.rs | 195 +++++++++++++++++- 2 files changed, 193 insertions(+), 9 deletions(-) diff --git a/crates/defguard_gateway_manager/tests/common/mod.rs b/crates/defguard_gateway_manager/tests/common/mod.rs index 81bdb4195d..d03c618f6d 100644 --- a/crates/defguard_gateway_manager/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/tests/common/mod.rs @@ -12,7 +12,9 @@ use std::{ use defguard_common::{ db::{ Id, - models::{gateway::Gateway, wireguard::WireguardNetwork}, + models::{ + gateway::Gateway, settings::initialize_current_settings, wireguard::WireguardNetwork, + }, setup_pool, }, messages::peer_stats_update::PeerStatsUpdate, @@ -279,6 +281,9 @@ impl HandlerTestContext { events_tx: broadcast::Sender, ) -> Self { let pool = setup_pool(options).await; + initialize_current_settings(&pool) + .await + .expect("failed to initialize global settings for gateway handler tests"); let network = create_network(&pool).await; let gateway = create_gateway(&pool, network.id).await; let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index f6dc387390..0b612c0b2b 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -187,6 +187,22 @@ async fn test_device_created_for_network_produces_peer_create_update( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_device_created_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "created-before-config-device", + "tND8hJQhYnI8naBTo59He43zYldagfjlwmSxWEc01Cc=", + "10.10.0.11", + Some("created-before-config-preshared-key"), + GatewayEvent::DeviceCreated, + ) + .await; +} + #[sqlx::test] async fn test_device_modified_for_network_produces_peer_modify_update( _: PgPoolOptions, @@ -196,8 +212,9 @@ async fn test_device_modified_for_network_produces_peer_modify_update( let expected_keepalive_interval = expected_keepalive_interval(&context); let _ = context.complete_config_handshake().await; - let device = create_device_for_current_network( + let device = create_device_for_network( &context, + context.network.id, "modified-peer-device", "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", "10.10.0.20", @@ -241,6 +258,22 @@ async fn test_device_modified_for_network_produces_peer_modify_update( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_device_modified_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "modified-before-config-device", + "wyFOHCec/Fi9s+cARikVO71JhyYtYMk0FrQx3fK2PTM=", + "10.10.0.22", + Some("modified-before-config-preshared-key"), + GatewayEvent::DeviceModified, + ) + .await; +} + #[sqlx::test] async fn test_device_deleted_for_network_produces_peer_delete_update( _: PgPoolOptions, @@ -279,6 +312,70 @@ async fn test_device_deleted_for_network_produces_peer_delete_update( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_device_deleted_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "deleted-before-config-device", + "m84QJmDMkqdCj8AB2NTE8F55W7M/i3CaaD3eQbQdInY=", + "10.10.0.31", + Some("deleted-before-config-preshared-key"), + GatewayEvent::DeviceDeleted, + ) + .await; +} + +#[sqlx::test] +async fn test_device_created_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "created-other-network-device", + "W6wBmd8wgTwvCyGqDRXk6Hf4OMqDUbUn2XWKnG5wVVQ=", + "10.11.0.10", + Some("created-other-network-preshared-key"), + GatewayEvent::DeviceCreated, + ) + .await; +} + +#[sqlx::test] +async fn test_device_modified_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "modified-other-network-device", + "yjuzq0cLk3Ww5oQcqK6YkSKwXnqQ1V9OlSMFAEkr0lU=", + "10.11.0.20", + Some("modified-other-network-preshared-key"), + GatewayEvent::DeviceModified, + ) + .await; +} + +#[sqlx::test] +async fn test_device_deleted_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "deleted-other-network-device", + "Jtp+K8xnFXuF4cae+tVGZNwoSM2fXjJbRl3sI6rdcAQ=", + "10.11.0.30", + Some("deleted-other-network-preshared-key"), + GatewayEvent::DeviceDeleted, + ) + .await; +} + #[sqlx::test] async fn test_matching_location_network_deleted_event_produces_delete_update( _: PgPoolOptions, @@ -580,8 +677,28 @@ async fn create_device_info_for_current_network( device_ip: &str, preshared_key: Option<&str>, ) -> DeviceInfo { - let device = create_device_for_current_network( + create_device_info_for_network( + context, + context.network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await +} + +async fn create_device_info_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> DeviceInfo { + let device = create_device_for_network( context, + network_id, device_name, device_pubkey, device_ip, @@ -594,8 +711,9 @@ async fn create_device_info_for_current_network( .expect("failed to load device info") } -async fn create_device_for_current_network( +async fn create_device_for_network( context: &HandlerTestContext, + network_id: Id, device_name: &str, device_pubkey: &str, device_ip: &str, @@ -626,11 +744,8 @@ async fn create_device_for_current_network( .await .expect("failed to create test device"); - let mut network_device = WireguardNetworkDevice::new( - context.network.id, - device.id, - vec![parse_test_ip(device_ip)], - ); + let mut network_device = + WireguardNetworkDevice::new(network_id, device.id, vec![parse_test_ip(device_ip)]); network_device.preshared_key = preshared_key.map(str::to_owned); network_device .insert(&context.pool) @@ -640,6 +755,70 @@ async fn create_device_for_current_network( device } +async fn assert_device_event_is_ignored_before_config_handshake( + options: PgConnectOptions, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, + build_event: fn(DeviceInfo) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + assert_eq!(context.events_tx().receiver_count(), 0); + + let _broadcast_guard = context.events_tx().subscribe(); + let device_info = create_device_info_for_current_network( + &context, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + + assert_send_ok!( + context.events_tx().send(build_event(device_info)), + "failed to broadcast ignored device event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +async fn assert_device_event_for_different_network_is_ignored( + options: PgConnectOptions, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, + build_event: fn(DeviceInfo) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_network( + &context, + other_network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + + assert_send_ok!( + context.events_tx().send(build_event(device_info)), + "failed to broadcast ignored device event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + fn expected_keepalive_interval(context: &HandlerTestContext) -> u32 { u32::try_from(context.network.keepalive_interval) .expect("expected non-negative network keepalive interval") From ba0978d1115ba392d90784cffc35beb2664d7c51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 17 Mar 2026 21:41:36 +0100 Subject: [PATCH 10/36] add firewall config tests --- .../tests/gateway_manager/handler.rs | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index 0b612c0b2b..152c61a932 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -8,6 +8,10 @@ use defguard_common::db::{ }, }; use defguard_core::grpc::GatewayEvent; +use defguard_proto::enterprise::firewall::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, + SnatBinding, ip_address::Address, port::Port as PortInner, +}; use defguard_proto::gateway::{ CoreResponse, Update, UpdateType, core_response, update::{self}, @@ -485,6 +489,80 @@ async fn test_matching_location_network_created_event_produces_create_update( context.finish().await.expect_server_finished().await; } +#[sqlx::test] +async fn test_matching_location_firewall_config_changed_event_produces_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_firewall_config = build_test_firewall_config(); + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::FirewallConfigChanged( + context.network.id, + expected_firewall_config.clone(), + )), + "failed to broadcast firewall config changed event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_firewall_modify_update(outbound, &expected_firewall_config); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_firewall_disabled_event_produces_disable_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::FirewallDisabled(context.network.id)), + "failed to broadcast firewall disabled event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_firewall_disable_update(outbound); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_different_location_firewall_config_changed_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let expected_firewall_config = build_test_firewall_config(); + + assert_firewall_event_for_different_network_is_ignored(options, move |other_network_id| { + GatewayEvent::FirewallConfigChanged(other_network_id, expected_firewall_config) + }) + .await; +} + +#[sqlx::test] +async fn test_different_location_firewall_disabled_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_firewall_event_for_different_network_is_ignored(options, |other_network_id| { + GatewayEvent::FirewallDisabled(other_network_id) + }) + .await; +} + #[sqlx::test] async fn test_only_matching_handler_receives_network_modified_update( _: PgPoolOptions, @@ -819,6 +897,26 @@ async fn assert_device_event_for_different_network_is_ignored( context.finish().await.expect_server_finished().await; } +async fn assert_firewall_event_for_different_network_is_ignored( + options: PgConnectOptions, + build_event: impl FnOnce(Id) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context.events_tx().send(build_event(other_network.id)), + "failed to broadcast ignored firewall event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + fn expected_keepalive_interval(context: &HandlerTestContext) -> u32 { u32::try_from(context.network.keepalive_interval) .expect("expected non-negative network keepalive interval") @@ -921,3 +1019,114 @@ fn assert_network_modify_update( _ => panic_unexpected!("expected network modify update"), } } + +fn build_test_firewall_config() -> FirewallConfig { + FirewallConfig { + default_policy: i32::from(FirewallPolicy::Allow), + rules: vec![FirewallRule { + id: 101, + source_addrs: vec![IpAddress { + address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), + }], + destination_addrs: vec![IpAddress { + address: Some(Address::Ip("198.51.100.20".to_string())), + }], + destination_ports: vec![Port { + port: Some(PortInner::SinglePort(443)), + }], + protocols: vec![i32::from(Protocol::Tcp)], + verdict: i32::from(FirewallPolicy::Deny), + comment: Some("block test https destination".to_string()), + ip_version: i32::from(IpVersion::Ipv4), + }], + snat_bindings: vec![SnatBinding { + id: 202, + source_addrs: vec![IpAddress { + address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), + }], + public_ip: "203.0.113.44".to_string(), + comment: Some("test snat binding".to_string()), + }], + } +} + +fn assert_firewall_modify_update( + outbound: CoreResponse, + expected_firewall_config: &FirewallConfig, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::FirewallConfig(firewall_config)), + })) => { + assert_eq!(update_type, UpdateType::Modify as i32); + assert_eq!( + firewall_config.default_policy, + expected_firewall_config.default_policy + ); + assert_eq!( + firewall_config.rules.len(), + expected_firewall_config.rules.len() + ); + assert_eq!( + firewall_config.snat_bindings.len(), + expected_firewall_config.snat_bindings.len() + ); + + let firewall_rule = firewall_config + .rules + .first() + .expect("expected firewall rule in update payload"); + let expected_firewall_rule = expected_firewall_config + .rules + .first() + .expect("expected firewall rule in test config"); + assert_eq!(firewall_rule.id, expected_firewall_rule.id); + assert_eq!( + firewall_rule.source_addrs, + expected_firewall_rule.source_addrs + ); + assert_eq!( + firewall_rule.destination_addrs, + expected_firewall_rule.destination_addrs + ); + assert_eq!( + firewall_rule.destination_ports, + expected_firewall_rule.destination_ports + ); + assert_eq!(firewall_rule.protocols, expected_firewall_rule.protocols); + assert_eq!(firewall_rule.verdict, expected_firewall_rule.verdict); + assert_eq!(firewall_rule.comment, expected_firewall_rule.comment); + assert_eq!(firewall_rule.ip_version, expected_firewall_rule.ip_version); + + let snat_binding = firewall_config + .snat_bindings + .first() + .expect("expected SNAT binding in update payload"); + let expected_snat_binding = expected_firewall_config + .snat_bindings + .first() + .expect("expected SNAT binding in test config"); + assert_eq!(snat_binding.id, expected_snat_binding.id); + assert_eq!( + snat_binding.source_addrs, + expected_snat_binding.source_addrs + ); + assert_eq!(snat_binding.public_ip, expected_snat_binding.public_ip); + assert_eq!(snat_binding.comment, expected_snat_binding.comment); + } + _ => panic_unexpected!("expected firewall config update"), + } +} + +fn assert_firewall_disable_update(outbound: CoreResponse) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::DisableFirewall(())), + })) => { + assert_eq!(update_type, UpdateType::Delete as i32); + } + _ => panic_unexpected!("expected firewall disable update"), + } +} From c09a123cc16abb5cf29a4311903a36be8e56b076 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 07:23:43 +0100 Subject: [PATCH 11/36] add mfa tests --- .../tests/gateway_manager/handler.rs | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index 152c61a932..6142b686fb 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -5,6 +5,7 @@ use defguard_common::db::{ models::{ device::{Device, DeviceInfo, DeviceType, WireguardNetworkDevice}, user::User, + wireguard::{LocationMfaMode, WireguardNetwork}, }, }; use defguard_core::grpc::GatewayEvent; @@ -380,6 +381,127 @@ async fn test_device_deleted_for_different_network_is_ignored( .await; } +#[sqlx::test] +async fn test_matching_location_mfa_session_authorized_produces_peer_create( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_mfa_device_for_current_network( + &context, + "mfa-authorized-device", + "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", + "10.10.0.40", + Some("mfa-authorized-preshared-key"), + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast MFA session authorized event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", + &["10.10.0.40"], + Some("mfa-authorized-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let mut other_network = context.create_other_network().await; + enable_internal_mfa_for_network(&context.pool, &mut other_network).await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_mfa_device_for_network( + &context, + other_network.id, + "mfa-mismatched-network-device", + "Z2UuIvYJvU5fTOp8i3tHfLm4xZ0R8ExY6E3S3l+rqT8=", + "10.11.0.40", + Some("mfa-mismatched-network-preshared-key"), + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast mismatched MFA session authorized event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, _) = create_authorized_mfa_device_for_current_network( + &context, + "mfa-disconnected-device", + "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", + "10.10.0.41", + Some("mfa-disconnected-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::MfaSessionDisconnected( + context.network.id, + device, + )), + "failed to broadcast MFA session disconnected event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Delete, + "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", + &[], + None, + None, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + #[sqlx::test] async fn test_matching_location_network_deleted_event_produces_delete_update( _: PgPoolOptions, @@ -766,6 +888,55 @@ async fn create_device_info_for_current_network( .await } +async fn create_authorized_mfa_device_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> (Device, WireguardNetworkDevice) { + create_authorized_mfa_device_for_network( + context, + context.network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await +} + +async fn create_authorized_mfa_device_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> (Device, WireguardNetworkDevice) { + let device = create_device_for_network( + context, + network_id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, network_id) + .await + .expect("failed to load MFA device network info") + .expect("expected MFA device network info"); + network_device.is_authorized = true; + network_device.preshared_key = preshared_key.map(str::to_owned); + network_device + .update(&context.pool) + .await + .expect("failed to persist MFA device network info"); + + (device, network_device) +} + async fn create_device_info_for_network( context: &HandlerTestContext, network_id: Id, @@ -833,6 +1004,18 @@ async fn create_device_for_network( device } +async fn enable_internal_mfa_for_network( + pool: &sqlx::PgPool, + network: &mut WireguardNetwork, +) { + network.location_mfa_mode = LocationMfaMode::Internal; + network + .save(pool) + .await + .expect("failed to enable MFA for test network"); + assert!(network.mfa_enabled()); +} + async fn assert_device_event_is_ignored_before_config_handshake( options: PgConnectOptions, device_name: &str, From 7726a150be447cc33d08ab7042faa39a11fa9b4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 07:30:10 +0100 Subject: [PATCH 12/36] test peer stats conversion --- .../defguard_gateway_manager/src/handler.rs | 143 ++++++++++++++++++ .../tests/gateway_manager/handler.rs | 5 +- 2 files changed, 144 insertions(+), 4 deletions(-) diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 5ed6cb59e9..15ec3a512c 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -905,3 +905,146 @@ fn gen_config( fwmark: network.fwmark as u32, } } + +#[cfg(test)] +mod tests { + use chrono::{DateTime, Utc}; + use defguard_common::db::{ + Id, + models::wireguard::{LocationMfaMode, ServiceLocationMode}, + }; + + use super::{ + FirewallConfig, Peer, PeerStats, WireguardNetwork, gen_config, + try_protos_into_stats_message, + }; + + fn build_peer_stats(endpoint: &str) -> PeerStats { + PeerStats { + public_key: "peer-public-key".to_string(), + endpoint: endpoint.to_string(), + upload: 123, + download: 456, + keepalive_interval: 25, + latest_handshake: 1_700_000_000, + allowed_ips: "10.10.0.2/32".to_string(), + } + } + + fn build_network() -> WireguardNetwork { + WireguardNetwork { + id: 7, + name: "test-network".to_string(), + address: vec![ + "10.10.0.1/24".parse().expect("valid IPv4 network"), + "fd00::1/64".parse().expect("valid IPv6 network"), + ], + port: 51820, + pubkey: "network-public-key".to_string(), + prvkey: "network-private-key".to_string(), + endpoint: "198.51.100.10".to_string(), + dns: Some("1.1.1.1".to_string()), + mtu: 1420, + fwmark: 4321, + allowed_ips: vec!["0.0.0.0/0".parse().expect("valid allowed IP network")], + allow_all_groups: false, + connected_at: None, + acl_enabled: false, + acl_default_allow: false, + keepalive_interval: 25, + peer_disconnect_threshold: 180, + location_mfa_mode: LocationMfaMode::default(), + service_location_mode: ServiceLocationMode::default(), + } + } + + #[test] + fn try_protos_into_stats_message_maps_valid_peer_stats() { + let stats = try_protos_into_stats_message(build_peer_stats("203.0.113.10:51820"), 11, 22) + .expect("valid peer stats should be converted"); + + assert_eq!(stats.location_id, 11); + assert_eq!(stats.gateway_id, 22); + assert_eq!(stats.device_pubkey, "peer-public-key"); + assert_eq!(stats.endpoint.to_string(), "203.0.113.10:51820"); + assert_eq!(stats.upload, 123); + assert_eq!(stats.download, 456); + assert_eq!( + stats.latest_handshake, + DateTime::from_timestamp(1_700_000_000, 0) + .expect("valid handshake timestamp") + .naive_utc() + ); + } + + #[test] + fn try_protos_into_stats_message_rejects_invalid_endpoint() { + let stats = try_protos_into_stats_message(build_peer_stats("not-a-socket-address"), 11, 22); + + assert!(stats.is_none()); + } + + #[test] + fn try_protos_into_stats_message_falls_back_to_default_timestamp() { + let stats = try_protos_into_stats_message( + PeerStats { + latest_handshake: i64::MAX as u64, + ..build_peer_stats("203.0.113.10:51820") + }, + 11, + 22, + ) + .expect("valid endpoint should still produce stats"); + + assert_eq!(stats.latest_handshake, DateTime::::default().naive_utc()); + } + + #[test] + fn gen_config_maps_network_fields() { + let config = gen_config( + &build_network(), + vec![Peer { + pubkey: "peer-public-key".to_string(), + allowed_ips: vec!["10.10.0.2/32".to_string()], + preshared_key: Some("peer-preshared-key".to_string()), + keepalive_interval: Some(25), + }], + Some(FirewallConfig { + default_policy: 0, + rules: Vec::new(), + snat_bindings: Vec::new(), + }), + ); + + assert_eq!(config.name, "test-network"); + assert_eq!(config.port, 51820); + assert_eq!(config.prvkey, "network-private-key"); + assert_eq!(config.addresses, vec!["10.10.0.1/24", "fd00::1/64"]); + assert_eq!(config.mtu, 1420); + assert_eq!(config.fwmark, 4321); + + let peer = config + .peers + .first() + .expect("generated config should include peer"); + assert_eq!(peer.pubkey, "peer-public-key"); + assert_eq!(peer.allowed_ips, vec!["10.10.0.2/32"]); + assert_eq!(peer.preshared_key.as_deref(), Some("peer-preshared-key")); + assert_eq!(peer.keepalive_interval, Some(25)); + + let firewall_config = config + .firewall_config + .expect("generated config should include firewall config"); + assert_eq!(firewall_config.default_policy, 0); + assert!(firewall_config.rules.is_empty()); + assert!(firewall_config.snat_bindings.is_empty()); + } + + #[test] + fn gen_config_preserves_absent_firewall_config_and_empty_peers() { + let config = gen_config(&build_network(), Vec::new(), None); + + assert!(config.peers.is_empty()); + assert!(config.firewall_config.is_none()); + } +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index 6142b686fb..f10895ddb1 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -1004,10 +1004,7 @@ async fn create_device_for_network( device } -async fn enable_internal_mfa_for_network( - pool: &sqlx::PgPool, - network: &mut WireguardNetwork, -) { +async fn enable_internal_mfa_for_network(pool: &sqlx::PgPool, network: &mut WireguardNetwork) { network.location_mfa_mode = LocationMfaMode::Internal; network .save(pool) From e43b41ed21a1584be1b0077aef636d518febccad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 08:39:52 +0100 Subject: [PATCH 13/36] add cert tests --- crates/defguard_gateway_manager/src/certs.rs | 96 ++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/crates/defguard_gateway_manager/src/certs.rs b/crates/defguard_gateway_manager/src/certs.rs index a1daf0b53a..aaa1b4a090 100644 --- a/crates/defguard_gateway_manager/src/certs.rs +++ b/crates/defguard_gateway_manager/src/certs.rs @@ -31,3 +31,99 @@ pub(super) async fn refresh_certs(pool: &PgPool, tx: &watch::Sender, + ) -> Gateway { + let mut gateway = Gateway::new( + location_id, + name.to_string(), + "127.0.0.1".to_string(), + 51820, + "test-admin".to_string(), + ); + gateway.certificate = certificate.map(str::to_owned); + + gateway + .save(pool) + .await + .expect("failed to create gateway for cert refresh tests") + } +} From 853831be65f7315390beb9fe29b3cdaa620c54c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 09:40:47 +0100 Subject: [PATCH 14/36] expand test certs --- crates/defguard_gateway_manager/src/certs.rs | 53 ++++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/crates/defguard_gateway_manager/src/certs.rs b/crates/defguard_gateway_manager/src/certs.rs index aaa1b4a090..cfea2139f3 100644 --- a/crates/defguard_gateway_manager/src/certs.rs +++ b/crates/defguard_gateway_manager/src/certs.rs @@ -81,15 +81,17 @@ mod tests { .save(&pool) .await .expect("failed to create network for cert refresh tests"); - let gateway_with_cert = + let mut gateway_with_cert = create_gateway(&pool, network.id, "gateway-with-cert", Some("cert-1")).await; - let gateway_without_cert = + let mut gateway_without_cert = create_gateway(&pool, network.id, "gateway-without-cert", None).await; - let gateway_with_new_cert = + let mut gateway_with_new_cert = create_gateway(&pool, network.id, "gateway-with-new-cert", Some("cert-3")).await; - let (tx, mut rx) = + let (tx, rx) = watch::channel(Arc::new(HashMap::from([(999, "stale-cert".to_string())]))); + let mut lagging_rx = rx.clone(); + let mut rx = rx; refresh_certs(&pool, &tx).await; @@ -104,6 +106,49 @@ mod tests { assert_eq!(published.as_ref(), &expected); assert!(!published.contains_key(&gateway_without_cert.id)); assert!(!published.contains_key(&999)); + + gateway_with_cert.certificate = Some("cert-2".to_string()); + gateway_with_cert + .save(&pool) + .await + .expect("failed to update gateway certificate for cert refresh tests"); + + gateway_without_cert.certificate = Some("cert-4".to_string()); + gateway_without_cert + .save(&pool) + .await + .expect("failed to add gateway certificate for cert refresh tests"); + + gateway_with_new_cert.certificate = None; + gateway_with_new_cert + .save(&pool) + .await + .expect("failed to remove gateway certificate for cert refresh tests"); + + refresh_certs(&pool, &tx).await; + + assert!(rx.has_changed().expect("cert watch sender should still be alive")); + + let published = Arc::clone(&rx.borrow_and_update()); + let expected = HashMap::from([ + (gateway_with_cert.id, "cert-2".to_string()), + (gateway_without_cert.id, "cert-4".to_string()), + ]); + + assert_eq!(published.as_ref(), &expected); + assert!(!published.contains_key(&gateway_with_new_cert.id)); + assert!(!published.contains_key(&999)); + + assert!(lagging_rx + .has_changed() + .expect("cert watch sender should still be alive")); + let latest_only = Arc::clone(&lagging_rx.borrow_and_update()); + + assert_eq!(latest_only.as_ref(), &expected); + assert_ne!(latest_only.as_ref(), &HashMap::from([ + (gateway_with_cert.id, "cert-1".to_string()), + (gateway_with_new_cert.id, "cert-3".to_string()), + ])); } async fn create_gateway( From 0ec18a369a4dd9025d9749761f855e477b270d39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 10:01:02 +0100 Subject: [PATCH 15/36] split test module into smaller files --- crates/defguard_gateway_manager/src/certs.rs | 36 +- .../defguard_gateway_manager/src/handler.rs | 5 +- .../tests/gateway_manager/handler.rs | 1326 +---------------- .../gateway_manager/handler/device_events.rs | 227 +++ .../handler/firewall_events.rs | 73 + .../gateway_manager/handler/handshake.rs | 56 + .../gateway_manager/handler/lifecycle.rs | 59 + .../tests/gateway_manager/handler/mfa.rs | 120 ++ .../gateway_manager/handler/network_events.rs | 233 +++ .../tests/gateway_manager/handler/stats.rs | 58 + .../tests/gateway_manager/handler/support.rs | 481 ++++++ 11 files changed, 1354 insertions(+), 1320 deletions(-) create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/firewall_events.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/handshake.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/lifecycle.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/mfa.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/stats.rs create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs diff --git a/crates/defguard_gateway_manager/src/certs.rs b/crates/defguard_gateway_manager/src/certs.rs index cfea2139f3..5e2367e179 100644 --- a/crates/defguard_gateway_manager/src/certs.rs +++ b/crates/defguard_gateway_manager/src/certs.rs @@ -39,9 +39,7 @@ mod tests { use defguard_common::db::{ Id, models::{ - gateway::Gateway, - settings::initialize_current_settings, - wireguard::WireguardNetwork, + gateway::Gateway, settings::initialize_current_settings, wireguard::WireguardNetwork, }, setup_pool, }; @@ -88,14 +86,16 @@ mod tests { let mut gateway_with_new_cert = create_gateway(&pool, network.id, "gateway-with-new-cert", Some("cert-3")).await; - let (tx, rx) = - watch::channel(Arc::new(HashMap::from([(999, "stale-cert".to_string())]))); + let (tx, rx) = watch::channel(Arc::new(HashMap::from([(999, "stale-cert".to_string())]))); let mut lagging_rx = rx.clone(); let mut rx = rx; refresh_certs(&pool, &tx).await; - assert!(rx.has_changed().expect("cert watch sender should still be alive")); + assert!( + rx.has_changed() + .expect("cert watch sender should still be alive") + ); let published = Arc::clone(&rx.borrow_and_update()); let expected = HashMap::from([ @@ -127,7 +127,10 @@ mod tests { refresh_certs(&pool, &tx).await; - assert!(rx.has_changed().expect("cert watch sender should still be alive")); + assert!( + rx.has_changed() + .expect("cert watch sender should still be alive") + ); let published = Arc::clone(&rx.borrow_and_update()); let expected = HashMap::from([ @@ -139,16 +142,21 @@ mod tests { assert!(!published.contains_key(&gateway_with_new_cert.id)); assert!(!published.contains_key(&999)); - assert!(lagging_rx - .has_changed() - .expect("cert watch sender should still be alive")); + assert!( + lagging_rx + .has_changed() + .expect("cert watch sender should still be alive") + ); let latest_only = Arc::clone(&lagging_rx.borrow_and_update()); assert_eq!(latest_only.as_ref(), &expected); - assert_ne!(latest_only.as_ref(), &HashMap::from([ - (gateway_with_cert.id, "cert-1".to_string()), - (gateway_with_new_cert.id, "cert-3".to_string()), - ])); + assert_ne!( + latest_only.as_ref(), + &HashMap::from([ + (gateway_with_cert.id, "cert-1".to_string()), + (gateway_with_new_cert.id, "cert-3".to_string()), + ]) + ); } async fn create_gateway( diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 15ec3a512c..1948de549f 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -996,7 +996,10 @@ mod tests { ) .expect("valid endpoint should still produce stats"); - assert_eq!(stats.latest_handshake, DateTime::::default().naive_utc()); + assert_eq!( + stats.latest_handshake, + DateTime::::default().naive_utc() + ); } #[test] diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs index f10895ddb1..b7b12df582 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs @@ -1,1312 +1,28 @@ -use std::net::IpAddr; +#[path = "handler/support.rs"] +mod support; -use defguard_common::db::{ - Id, - models::{ - device::{Device, DeviceInfo, DeviceType, WireguardNetworkDevice}, - user::User, - wireguard::{LocationMfaMode, WireguardNetwork}, - }, -}; +use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; use defguard_core::grpc::GatewayEvent; -use defguard_proto::enterprise::firewall::{ - FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, - SnatBinding, ip_address::Address, port::Port as PortInner, -}; -use defguard_proto::gateway::{ - CoreResponse, Update, UpdateType, core_response, - update::{self}, -}; +use defguard_proto::gateway::{UpdateType, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; +use self::support::{ + assert_device_event_for_different_network_is_ignored, + assert_device_event_is_ignored_before_config_handshake, assert_firewall_disable_update, + assert_firewall_event_for_different_network_is_ignored, assert_firewall_modify_update, + assert_network_create_update, assert_network_delete_update, assert_network_modify_update, + assert_peer_update, assert_send_ok, build_test_firewall_config, + create_authorized_mfa_device_for_current_network, create_authorized_mfa_device_for_network, + create_device_for_network, create_device_info_for_current_network, + enable_internal_mfa_for_network, expected_keepalive_interval, panic_unexpected, parse_test_ip, +}; use crate::common::{HandlerTestContext, build_peer_stats, reload_gateway}; -macro_rules! assert_send_ok { - ($result:expr, $message:literal) => { - match $result { - Ok(value) => value, - Err(_) => panic!($message), - } - }; -} - -macro_rules! panic_unexpected { - ($message:literal) => { - panic!($message) - }; -} - -#[sqlx::test] -async fn test_sends_configuration_on_first_config_request( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - context.mock_gateway().send_config_request(); - let outbound = context.mock_gateway_mut().recv_outbound().await; - - match outbound.payload { - Some(core_response::Payload::Config(config)) => { - assert_eq!(config.name, context.network.name); - assert_eq!(config.port, context.network.port as u32); - assert_eq!(config.peers, Vec::new()); - } - _ => panic_unexpected!("expected configuration response"), - } - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_does_not_send_configuration_before_gateway_requests_it( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let gateway_before = context.reload_gateway().await; - assert!(!gateway_before.is_connected()); - - context.mock_gateway_mut().expect_no_outbound().await; - - let gateway_after = context.reload_gateway().await; - assert!(!gateway_after.is_connected()); - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_ignores_repeated_config_request(_: PgPoolOptions, options: PgConnectOptions) { - let mut context = HandlerTestContext::new(options).await; - - context.mock_gateway().send_config_request(); - let first_outbound = context.mock_gateway_mut().recv_outbound().await; - assert!(matches!( - first_outbound.payload, - Some(core_response::Payload::Config(_)) - )); - - context.mock_gateway().send_config_request(); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_ignores_peer_stats_before_config_handshake( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - context - .mock_gateway() - .send_peer_stats(build_peer_stats("203.0.113.10:51820")); - - context.expect_no_peer_stats().await; - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_forwards_valid_peer_stats_after_config(_: PgPoolOptions, options: PgConnectOptions) { - let mut context = HandlerTestContext::new(options).await; - - context.mock_gateway().send_config_request(); - let _ = context.mock_gateway_mut().recv_outbound().await; - context - .mock_gateway() - .send_peer_stats(build_peer_stats("203.0.113.10:51820")); - - let forwarded = context.recv_peer_stats().await; - assert_eq!(forwarded.location_id, context.network.id); - assert_eq!(forwarded.gateway_id, context.gateway.id); - assert_eq!(forwarded.device_pubkey, "peer-public-key"); - assert_eq!(forwarded.endpoint.to_string(), "203.0.113.10:51820"); - assert_eq!(forwarded.upload, 123); - assert_eq!(forwarded.download, 456); - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_drops_malformed_or_missing_endpoint_peer_stats( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - context.mock_gateway().send_config_request(); - let _ = context.mock_gateway_mut().recv_outbound().await; - - context.mock_gateway().send_peer_stats(build_peer_stats("")); - context.expect_no_peer_stats().await; - - context - .mock_gateway() - .send_peer_stats(build_peer_stats("not-a-socket-address")); - context.expect_no_peer_stats().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_device_created_for_network_produces_peer_create_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - let expected_keepalive_interval = expected_keepalive_interval(&context); - - let _ = context.complete_config_handshake().await; - let device_info = create_device_info_for_current_network( - &context, - "created-peer-device", - "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", - "10.10.0.10", - Some("created-preshared-key"), - ) - .await; - - assert_send_ok!( - context - .events_tx() - .send(GatewayEvent::DeviceCreated(device_info)), - "failed to broadcast created device event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_peer_update( - outbound, - UpdateType::Create, - "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", - &["10.10.0.10"], - Some("created-preshared-key"), - Some(expected_keepalive_interval), - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_device_created_before_config_handshake_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_device_event_is_ignored_before_config_handshake( - options, - "created-before-config-device", - "tND8hJQhYnI8naBTo59He43zYldagfjlwmSxWEc01Cc=", - "10.10.0.11", - Some("created-before-config-preshared-key"), - GatewayEvent::DeviceCreated, - ) - .await; -} - -#[sqlx::test] -async fn test_device_modified_for_network_produces_peer_modify_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - let expected_keepalive_interval = expected_keepalive_interval(&context); - - let _ = context.complete_config_handshake().await; - let device = create_device_for_network( - &context, - context.network.id, - "modified-peer-device", - "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", - "10.10.0.20", - Some("initial-preshared-key"), - ) - .await; - - let mut network_device = - WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) - .await - .expect("failed to load device network info") - .expect("expected device network info for modified device"); - network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; - network_device.preshared_key = Some("modified-preshared-key".to_string()); - network_device - .update(&context.pool) - .await - .expect("failed to update device network info"); - let device_info = DeviceInfo::from_device(&context.pool, device) - .await - .expect("failed to load modified device info"); - - assert_send_ok!( - context - .events_tx() - .send(GatewayEvent::DeviceModified(device_info)), - "failed to broadcast modified device event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_peer_update( - outbound, - UpdateType::Modify, - "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", - &["10.10.0.21"], - Some("modified-preshared-key"), - Some(expected_keepalive_interval), - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_device_modified_before_config_handshake_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_device_event_is_ignored_before_config_handshake( - options, - "modified-before-config-device", - "wyFOHCec/Fi9s+cARikVO71JhyYtYMk0FrQx3fK2PTM=", - "10.10.0.22", - Some("modified-before-config-preshared-key"), - GatewayEvent::DeviceModified, - ) - .await; -} - -#[sqlx::test] -async fn test_device_deleted_for_network_produces_peer_delete_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let _ = context.complete_config_handshake().await; - let device_info = create_device_info_for_current_network( - &context, - "deleted-peer-device", - "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", - "10.10.0.30", - Some("deleted-preshared-key"), - ) - .await; - - assert_send_ok!( - context - .events_tx() - .send(GatewayEvent::DeviceDeleted(device_info)), - "failed to broadcast deleted device event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_peer_update( - outbound, - UpdateType::Delete, - "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", - &[], - None, - None, - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_device_deleted_before_config_handshake_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_device_event_is_ignored_before_config_handshake( - options, - "deleted-before-config-device", - "m84QJmDMkqdCj8AB2NTE8F55W7M/i3CaaD3eQbQdInY=", - "10.10.0.31", - Some("deleted-before-config-preshared-key"), - GatewayEvent::DeviceDeleted, - ) - .await; -} - -#[sqlx::test] -async fn test_device_created_for_different_network_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_device_event_for_different_network_is_ignored( - options, - "created-other-network-device", - "W6wBmd8wgTwvCyGqDRXk6Hf4OMqDUbUn2XWKnG5wVVQ=", - "10.11.0.10", - Some("created-other-network-preshared-key"), - GatewayEvent::DeviceCreated, - ) - .await; -} - -#[sqlx::test] -async fn test_device_modified_for_different_network_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_device_event_for_different_network_is_ignored( - options, - "modified-other-network-device", - "yjuzq0cLk3Ww5oQcqK6YkSKwXnqQ1V9OlSMFAEkr0lU=", - "10.11.0.20", - Some("modified-other-network-preshared-key"), - GatewayEvent::DeviceModified, - ) - .await; -} - -#[sqlx::test] -async fn test_device_deleted_for_different_network_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_device_event_for_different_network_is_ignored( - options, - "deleted-other-network-device", - "Jtp+K8xnFXuF4cae+tVGZNwoSM2fXjJbRl3sI6rdcAQ=", - "10.11.0.30", - Some("deleted-other-network-preshared-key"), - GatewayEvent::DeviceDeleted, - ) - .await; -} - -#[sqlx::test] -async fn test_matching_location_mfa_session_authorized_produces_peer_create( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - let expected_keepalive_interval = expected_keepalive_interval(&context); - enable_internal_mfa_for_network(&context.pool, &mut context.network).await; - - let _ = context.complete_config_handshake().await; - let (device, network_device) = create_authorized_mfa_device_for_current_network( - &context, - "mfa-authorized-device", - "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", - "10.10.0.40", - Some("mfa-authorized-preshared-key"), - ) - .await; - - assert_send_ok!( - context.events_tx().send(GatewayEvent::MfaSessionAuthorized( - context.network.id, - device, - network_device, - )), - "failed to broadcast MFA session authorized event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_peer_update( - outbound, - UpdateType::Create, - "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", - &["10.10.0.40"], - Some("mfa-authorized-preshared-key"), - Some(expected_keepalive_interval), - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - enable_internal_mfa_for_network(&context.pool, &mut context.network).await; - - let mut other_network = context.create_other_network().await; - enable_internal_mfa_for_network(&context.pool, &mut other_network).await; - assert_ne!(other_network.id, context.network.id); - - let _ = context.complete_config_handshake().await; - let (device, network_device) = create_authorized_mfa_device_for_network( - &context, - other_network.id, - "mfa-mismatched-network-device", - "Z2UuIvYJvU5fTOp8i3tHfLm4xZ0R8ExY6E3S3l+rqT8=", - "10.11.0.40", - Some("mfa-mismatched-network-preshared-key"), - ) - .await; - - assert_send_ok!( - context.events_tx().send(GatewayEvent::MfaSessionAuthorized( - context.network.id, - device, - network_device, - )), - "failed to broadcast mismatched MFA session authorized event" - ); - - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - enable_internal_mfa_for_network(&context.pool, &mut context.network).await; - - let _ = context.complete_config_handshake().await; - let (device, _) = create_authorized_mfa_device_for_current_network( - &context, - "mfa-disconnected-device", - "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", - "10.10.0.41", - Some("mfa-disconnected-preshared-key"), - ) - .await; - - assert_send_ok!( - context - .events_tx() - .send(GatewayEvent::MfaSessionDisconnected( - context.network.id, - device, - )), - "failed to broadcast MFA session disconnected event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_peer_update( - outbound, - UpdateType::Delete, - "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", - &[], - None, - None, - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_matching_location_network_deleted_event_produces_delete_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let _ = context.complete_config_handshake().await; - - assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( - context.network.id, - context.network.name.clone(), - )), - "failed to broadcast gateway event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_network_delete_update(outbound, &context.network.name); - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_matching_location_network_modified_event_produces_modify_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let _ = context.complete_config_handshake().await; - - let mut modified_network = context.network.clone(); - modified_network.name = format!("{}-modified", context.network.name); - modified_network.address = vec![ - "10.20.0.1/24" - .parse() - .expect("failed to parse modified network address"), - ]; - modified_network.port = 51821; - modified_network.mtu = 1380; - modified_network.fwmark = 42; - - assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkModified( - context.network.id, - modified_network, - Vec::new(), - None, - )), - "failed to broadcast modified gateway event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_network_modify_update( - outbound, - &format!("{}-modified", context.network.name), - "10.20.0.1/24", - 51821, - 1380, - 42, - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_matching_location_network_created_event_produces_create_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let _ = context.complete_config_handshake().await; - - let mut created_network = context.network.clone(); - created_network.name = format!("{}-created", context.network.name); - created_network.address = vec![ - "10.40.0.1/24" - .parse() - .expect("failed to parse created network address"), - ]; - created_network.port = 51841; - created_network.mtu = 1410; - created_network.fwmark = 17; - - assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkCreated( - context.network.id, - created_network, - )), - "failed to broadcast created gateway event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_network_create_update( - outbound, - &format!("{}-created", context.network.name), - "10.40.0.1/24", - 51841, - 1410, - 17, - ); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_matching_location_firewall_config_changed_event_produces_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - let expected_firewall_config = build_test_firewall_config(); - - let _ = context.complete_config_handshake().await; - - assert_send_ok!( - context - .events_tx() - .send(GatewayEvent::FirewallConfigChanged( - context.network.id, - expected_firewall_config.clone(), - )), - "failed to broadcast firewall config changed event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_firewall_modify_update(outbound, &expected_firewall_config); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_matching_location_firewall_disabled_event_produces_disable_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let _ = context.complete_config_handshake().await; - - assert_send_ok!( - context - .events_tx() - .send(GatewayEvent::FirewallDisabled(context.network.id)), - "failed to broadcast firewall disabled event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_firewall_disable_update(outbound); - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_different_location_firewall_config_changed_event_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let expected_firewall_config = build_test_firewall_config(); - - assert_firewall_event_for_different_network_is_ignored(options, move |other_network_id| { - GatewayEvent::FirewallConfigChanged(other_network_id, expected_firewall_config) - }) - .await; -} - -#[sqlx::test] -async fn test_different_location_firewall_disabled_event_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - assert_firewall_event_for_different_network_is_ignored(options, |other_network_id| { - GatewayEvent::FirewallDisabled(other_network_id) - }) - .await; -} - -#[sqlx::test] -async fn test_only_matching_handler_receives_network_modified_update( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let (events_tx, _) = tokio::sync::broadcast::channel(16); - let mut matching_context = - HandlerTestContext::new_with_events_tx(options.clone(), events_tx.clone()).await; - let mut unrelated_context = HandlerTestContext::new_with_events_tx(options, events_tx).await; - - assert_ne!(matching_context.network.id, unrelated_context.network.id); - - let _ = matching_context.complete_config_handshake().await; - let _ = unrelated_context.complete_config_handshake().await; - - let mut modified_network = matching_context.network.clone(); - modified_network.name = format!("{}-modified", matching_context.network.name); - modified_network.address = vec![ - "10.30.0.1/24" - .parse() - .expect("failed to parse modified network address"), - ]; - modified_network.port = 51831; - modified_network.mtu = 1400; - modified_network.fwmark = 7; - - assert_send_ok!( - matching_context - .events_tx() - .send(GatewayEvent::NetworkModified( - matching_context.network.id, - modified_network, - Vec::new(), - None, - )), - "failed to broadcast modified gateway event" - ); - - let outbound = matching_context.mock_gateway_mut().recv_outbound().await; - assert_network_modify_update( - outbound, - &format!("{}-modified", matching_context.network.name), - "10.30.0.1/24", - 51831, - 1400, - 7, - ); - matching_context - .mock_gateway_mut() - .expect_no_outbound() - .await; - unrelated_context - .mock_gateway_mut() - .expect_no_outbound() - .await; - - matching_context - .finish() - .await - .expect_server_finished() - .await; - unrelated_context - .finish() - .await - .expect_server_finished() - .await; -} - -#[sqlx::test] -async fn test_different_location_network_created_event_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - let other_network = context.create_other_network().await; - assert_ne!(other_network.id, context.network.id); - - let _ = context.complete_config_handshake().await; - assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkCreated( - other_network.id, - other_network, - )), - "failed to broadcast unrelated created gateway event" - ); - - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_different_location_network_deleted_event_is_ignored( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - let other_network = context.create_other_network().await; - assert_ne!(other_network.id, context.network.id); - - let _ = context.complete_config_handshake().await; - assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( - other_network.id, - other_network.name.clone(), - )), - "failed to broadcast unrelated gateway event" - ); - - context.mock_gateway_mut().expect_no_outbound().await; - - assert_send_ok!( - context.events_tx().send(GatewayEvent::NetworkDeleted( - context.network.id, - context.network.name.clone(), - )), - "failed to broadcast owned gateway event" - ); - - let outbound = context.mock_gateway_mut().recv_outbound().await; - assert_network_delete_update(outbound, &context.network.name); - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_gateway_is_marked_connected_after_successful_config_handshake( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let gateway_before = context.reload_gateway().await; - assert!(!gateway_before.is_connected()); - - let gateway_after = context.complete_config_handshake().await; - assert!(gateway_after.is_connected()); - assert!(gateway_after.connected_at.is_some()); - - context.finish().await.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_gateway_is_marked_disconnected_when_stream_closes( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let connected_gateway = context.complete_config_handshake().await; - assert!(connected_gateway.is_connected()); - - let pool = context.pool.clone(); - let gateway_id = context.gateway.id; - let mock_gateway = context.finish().await; - let disconnected_gateway = reload_gateway(&pool, gateway_id).await; - assert!(!disconnected_gateway.is_connected()); - assert!(disconnected_gateway.disconnected_at.is_some()); - - mock_gateway.expect_server_finished().await; -} - -#[sqlx::test] -async fn test_gateway_is_marked_disconnected_when_stream_errors( - _: PgPoolOptions, - options: PgConnectOptions, -) { - let mut context = HandlerTestContext::new(options).await; - - let _ = context.complete_config_handshake().await; - - context - .mock_gateway() - .send_stream_error(Status::internal("mock gateway stream failure")); - - let pool = context.pool.clone(); - let gateway_id = context.gateway.id; - let mock_gateway = context.finish_after_error().await; - let disconnected_gateway = reload_gateway(&pool, gateway_id).await; - assert!(!disconnected_gateway.is_connected()); - assert!(disconnected_gateway.disconnected_at.is_some()); - - mock_gateway.expect_server_finished().await; -} - -async fn create_device_info_for_current_network( - context: &HandlerTestContext, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, -) -> DeviceInfo { - create_device_info_for_network( - context, - context.network.id, - device_name, - device_pubkey, - device_ip, - preshared_key, - ) - .await -} - -async fn create_authorized_mfa_device_for_current_network( - context: &HandlerTestContext, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, -) -> (Device, WireguardNetworkDevice) { - create_authorized_mfa_device_for_network( - context, - context.network.id, - device_name, - device_pubkey, - device_ip, - preshared_key, - ) - .await -} - -async fn create_authorized_mfa_device_for_network( - context: &HandlerTestContext, - network_id: Id, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, -) -> (Device, WireguardNetworkDevice) { - let device = create_device_for_network( - context, - network_id, - device_name, - device_pubkey, - device_ip, - preshared_key, - ) - .await; - let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, network_id) - .await - .expect("failed to load MFA device network info") - .expect("expected MFA device network info"); - network_device.is_authorized = true; - network_device.preshared_key = preshared_key.map(str::to_owned); - network_device - .update(&context.pool) - .await - .expect("failed to persist MFA device network info"); - - (device, network_device) -} - -async fn create_device_info_for_network( - context: &HandlerTestContext, - network_id: Id, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, -) -> DeviceInfo { - let device = create_device_for_network( - context, - network_id, - device_name, - device_pubkey, - device_ip, - preshared_key, - ) - .await; - - DeviceInfo::from_device(&context.pool, device) - .await - .expect("failed to load device info") -} - -async fn create_device_for_network( - context: &HandlerTestContext, - network_id: Id, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, -) -> Device { - let username = format!("{device_name}-user"); - let email = format!("{device_name}@example.com"); - let user = User::new( - username, - Some("pass123"), - "Peer".to_string(), - "Test".to_string(), - email, - None, - ) - .save(&context.pool) - .await - .expect("failed to create test user"); - let device = Device::new( - device_name.to_string(), - device_pubkey.to_string(), - user.id, - DeviceType::User, - None, - true, - ) - .save(&context.pool) - .await - .expect("failed to create test device"); - - let mut network_device = - WireguardNetworkDevice::new(network_id, device.id, vec![parse_test_ip(device_ip)]); - network_device.preshared_key = preshared_key.map(str::to_owned); - network_device - .insert(&context.pool) - .await - .expect("failed to attach device to network"); - - device -} - -async fn enable_internal_mfa_for_network(pool: &sqlx::PgPool, network: &mut WireguardNetwork) { - network.location_mfa_mode = LocationMfaMode::Internal; - network - .save(pool) - .await - .expect("failed to enable MFA for test network"); - assert!(network.mfa_enabled()); -} - -async fn assert_device_event_is_ignored_before_config_handshake( - options: PgConnectOptions, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, - build_event: fn(DeviceInfo) -> GatewayEvent, -) { - let mut context = HandlerTestContext::new(options).await; - assert_eq!(context.events_tx().receiver_count(), 0); - - let _broadcast_guard = context.events_tx().subscribe(); - let device_info = create_device_info_for_current_network( - &context, - device_name, - device_pubkey, - device_ip, - preshared_key, - ) - .await; - - assert_send_ok!( - context.events_tx().send(build_event(device_info)), - "failed to broadcast ignored device event" - ); - - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -async fn assert_device_event_for_different_network_is_ignored( - options: PgConnectOptions, - device_name: &str, - device_pubkey: &str, - device_ip: &str, - preshared_key: Option<&str>, - build_event: fn(DeviceInfo) -> GatewayEvent, -) { - let mut context = HandlerTestContext::new(options).await; - let other_network = context.create_other_network().await; - assert_ne!(other_network.id, context.network.id); - - let _ = context.complete_config_handshake().await; - let device_info = create_device_info_for_network( - &context, - other_network.id, - device_name, - device_pubkey, - device_ip, - preshared_key, - ) - .await; - - assert_send_ok!( - context.events_tx().send(build_event(device_info)), - "failed to broadcast ignored device event" - ); - - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -async fn assert_firewall_event_for_different_network_is_ignored( - options: PgConnectOptions, - build_event: impl FnOnce(Id) -> GatewayEvent, -) { - let mut context = HandlerTestContext::new(options).await; - let other_network = context.create_other_network().await; - assert_ne!(other_network.id, context.network.id); - - let _ = context.complete_config_handshake().await; - - assert_send_ok!( - context.events_tx().send(build_event(other_network.id)), - "failed to broadcast ignored firewall event" - ); - - context.mock_gateway_mut().expect_no_outbound().await; - - context.finish().await.expect_server_finished().await; -} - -fn expected_keepalive_interval(context: &HandlerTestContext) -> u32 { - u32::try_from(context.network.keepalive_interval) - .expect("expected non-negative network keepalive interval") -} - -fn parse_test_ip(ip: &str) -> IpAddr { - ip.parse().expect("failed to parse test peer IP address") -} - -fn assert_peer_update( - outbound: CoreResponse, - expected_update_type: UpdateType, - expected_pubkey: &str, - expected_allowed_ips: &[&str], - expected_preshared_key: Option<&str>, - expected_keepalive_interval: Option, -) { - match outbound.payload { - Some(core_response::Payload::Update(Update { - update_type, - update: Some(update::Update::Peer(peer)), - })) => { - assert_eq!(update_type, expected_update_type as i32); - assert_eq!(peer.pubkey, expected_pubkey); - assert_eq!( - peer.allowed_ips, - expected_allowed_ips - .iter() - .map(|allowed_ip| allowed_ip.to_string()) - .collect::>() - ); - assert_eq!(peer.preshared_key.as_deref(), expected_preshared_key); - assert_eq!(peer.keepalive_interval, expected_keepalive_interval); - } - _ => panic_unexpected!("expected peer update"), - } -} - -fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: &str) { - match outbound.payload { - Some(core_response::Payload::Update(Update { - update_type, - update: Some(update::Update::Network(network)), - })) => { - assert_eq!(update_type, UpdateType::Delete as i32); - assert_eq!(network.name, expected_network_name); - } - _ => panic_unexpected!("expected network delete update"), - } -} - -fn assert_network_create_update( - outbound: CoreResponse, - expected_network_name: &str, - expected_address: &str, - expected_port: u32, - expected_mtu: u32, - expected_fwmark: u32, -) { - match outbound.payload { - Some(core_response::Payload::Update(Update { - update_type, - update: Some(update::Update::Network(network)), - })) => { - assert_eq!(update_type, UpdateType::Create as i32); - assert_eq!(network.name, expected_network_name); - assert_eq!(network.addresses, vec![expected_address.to_string()]); - assert_eq!(network.port, expected_port); - assert_eq!(network.peers, Vec::new()); - assert_eq!(network.firewall_config, None); - assert_eq!(network.mtu, expected_mtu); - assert_eq!(network.fwmark, expected_fwmark); - } - _ => panic_unexpected!("expected network create update"), - } -} - -fn assert_network_modify_update( - outbound: CoreResponse, - expected_network_name: &str, - expected_address: &str, - expected_port: u32, - expected_mtu: u32, - expected_fwmark: u32, -) { - match outbound.payload { - Some(core_response::Payload::Update(Update { - update_type, - update: Some(update::Update::Network(network)), - })) => { - assert_eq!(update_type, UpdateType::Modify as i32); - assert_eq!(network.name, expected_network_name); - assert_eq!(network.addresses, vec![expected_address.to_string()]); - assert_eq!(network.port, expected_port); - assert_eq!(network.peers, Vec::new()); - assert_eq!(network.firewall_config, None); - assert_eq!(network.mtu, expected_mtu); - assert_eq!(network.fwmark, expected_fwmark); - } - _ => panic_unexpected!("expected network modify update"), - } -} - -fn build_test_firewall_config() -> FirewallConfig { - FirewallConfig { - default_policy: i32::from(FirewallPolicy::Allow), - rules: vec![FirewallRule { - id: 101, - source_addrs: vec![IpAddress { - address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), - }], - destination_addrs: vec![IpAddress { - address: Some(Address::Ip("198.51.100.20".to_string())), - }], - destination_ports: vec![Port { - port: Some(PortInner::SinglePort(443)), - }], - protocols: vec![i32::from(Protocol::Tcp)], - verdict: i32::from(FirewallPolicy::Deny), - comment: Some("block test https destination".to_string()), - ip_version: i32::from(IpVersion::Ipv4), - }], - snat_bindings: vec![SnatBinding { - id: 202, - source_addrs: vec![IpAddress { - address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), - }], - public_ip: "203.0.113.44".to_string(), - comment: Some("test snat binding".to_string()), - }], - } -} - -fn assert_firewall_modify_update( - outbound: CoreResponse, - expected_firewall_config: &FirewallConfig, -) { - match outbound.payload { - Some(core_response::Payload::Update(Update { - update_type, - update: Some(update::Update::FirewallConfig(firewall_config)), - })) => { - assert_eq!(update_type, UpdateType::Modify as i32); - assert_eq!( - firewall_config.default_policy, - expected_firewall_config.default_policy - ); - assert_eq!( - firewall_config.rules.len(), - expected_firewall_config.rules.len() - ); - assert_eq!( - firewall_config.snat_bindings.len(), - expected_firewall_config.snat_bindings.len() - ); - - let firewall_rule = firewall_config - .rules - .first() - .expect("expected firewall rule in update payload"); - let expected_firewall_rule = expected_firewall_config - .rules - .first() - .expect("expected firewall rule in test config"); - assert_eq!(firewall_rule.id, expected_firewall_rule.id); - assert_eq!( - firewall_rule.source_addrs, - expected_firewall_rule.source_addrs - ); - assert_eq!( - firewall_rule.destination_addrs, - expected_firewall_rule.destination_addrs - ); - assert_eq!( - firewall_rule.destination_ports, - expected_firewall_rule.destination_ports - ); - assert_eq!(firewall_rule.protocols, expected_firewall_rule.protocols); - assert_eq!(firewall_rule.verdict, expected_firewall_rule.verdict); - assert_eq!(firewall_rule.comment, expected_firewall_rule.comment); - assert_eq!(firewall_rule.ip_version, expected_firewall_rule.ip_version); - - let snat_binding = firewall_config - .snat_bindings - .first() - .expect("expected SNAT binding in update payload"); - let expected_snat_binding = expected_firewall_config - .snat_bindings - .first() - .expect("expected SNAT binding in test config"); - assert_eq!(snat_binding.id, expected_snat_binding.id); - assert_eq!( - snat_binding.source_addrs, - expected_snat_binding.source_addrs - ); - assert_eq!(snat_binding.public_ip, expected_snat_binding.public_ip); - assert_eq!(snat_binding.comment, expected_snat_binding.comment); - } - _ => panic_unexpected!("expected firewall config update"), - } -} - -fn assert_firewall_disable_update(outbound: CoreResponse) { - match outbound.payload { - Some(core_response::Payload::Update(Update { - update_type, - update: Some(update::Update::DisableFirewall(())), - })) => { - assert_eq!(update_type, UpdateType::Delete as i32); - } - _ => panic_unexpected!("expected firewall disable update"), - } -} +include!("handler/handshake.rs"); +include!("handler/lifecycle.rs"); +include!("handler/stats.rs"); +include!("handler/network_events.rs"); +include!("handler/firewall_events.rs"); +include!("handler/device_events.rs"); +include!("handler/mfa.rs"); diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs new file mode 100644 index 0000000000..f7c9042c39 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs @@ -0,0 +1,227 @@ +#[sqlx::test] +async fn test_device_created_for_network_produces_peer_create_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_current_network( + &context, + "created-peer-device", + "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", + "10.10.0.10", + Some("created-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceCreated(device_info)), + "failed to broadcast created device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", + &["10.10.0.10"], + Some("created-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_created_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "created-before-config-device", + "tND8hJQhYnI8naBTo59He43zYldagfjlwmSxWEc01Cc=", + "10.10.0.11", + Some("created-before-config-preshared-key"), + GatewayEvent::DeviceCreated, + ) + .await; +} + +#[sqlx::test] +async fn test_device_modified_for_network_produces_peer_modify_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + + let _ = context.complete_config_handshake().await; + let device = create_device_for_network( + &context, + context.network.id, + "modified-peer-device", + "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", + "10.10.0.20", + Some("initial-preshared-key"), + ) + .await; + + let mut network_device = + WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) + .await + .expect("failed to load device network info") + .expect("expected device network info for modified device"); + network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; + network_device.preshared_key = Some("modified-preshared-key".to_string()); + network_device + .update(&context.pool) + .await + .expect("failed to update device network info"); + let device_info = DeviceInfo::from_device(&context.pool, device) + .await + .expect("failed to load modified device info"); + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceModified(device_info)), + "failed to broadcast modified device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Modify, + "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", + &["10.10.0.21"], + Some("modified-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_modified_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "modified-before-config-device", + "wyFOHCec/Fi9s+cARikVO71JhyYtYMk0FrQx3fK2PTM=", + "10.10.0.22", + Some("modified-before-config-preshared-key"), + GatewayEvent::DeviceModified, + ) + .await; +} + +#[sqlx::test] +async fn test_device_deleted_for_network_produces_peer_delete_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_current_network( + &context, + "deleted-peer-device", + "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", + "10.10.0.30", + Some("deleted-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceDeleted(device_info)), + "failed to broadcast deleted device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Delete, + "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", + &[], + None, + None, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_deleted_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "deleted-before-config-device", + "m84QJmDMkqdCj8AB2NTE8F55W7M/i3CaaD3eQbQdInY=", + "10.10.0.31", + Some("deleted-before-config-preshared-key"), + GatewayEvent::DeviceDeleted, + ) + .await; +} + +#[sqlx::test] +async fn test_device_created_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "created-other-network-device", + "W6wBmd8wgTwvCyGqDRXk6Hf4OMqDUbUn2XWKnG5wVVQ=", + "10.11.0.10", + Some("created-other-network-preshared-key"), + GatewayEvent::DeviceCreated, + ) + .await; +} + +#[sqlx::test] +async fn test_device_modified_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "modified-other-network-device", + "yjuzq0cLk3Ww5oQcqK6YkSKwXnqQ1V9OlSMFAEkr0lU=", + "10.11.0.20", + Some("modified-other-network-preshared-key"), + GatewayEvent::DeviceModified, + ) + .await; +} + +#[sqlx::test] +async fn test_device_deleted_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "deleted-other-network-device", + "Jtp+K8xnFXuF4cae+tVGZNwoSM2fXjJbRl3sI6rdcAQ=", + "10.11.0.30", + Some("deleted-other-network-preshared-key"), + GatewayEvent::DeviceDeleted, + ) + .await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/firewall_events.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/firewall_events.rs new file mode 100644 index 0000000000..1545aab35a --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/firewall_events.rs @@ -0,0 +1,73 @@ +#[sqlx::test] +async fn test_matching_location_firewall_config_changed_event_produces_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_firewall_config = build_test_firewall_config(); + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::FirewallConfigChanged( + context.network.id, + expected_firewall_config.clone(), + )), + "failed to broadcast firewall config changed event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_firewall_modify_update(outbound, &expected_firewall_config); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_firewall_disabled_event_produces_disable_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::FirewallDisabled(context.network.id)), + "failed to broadcast firewall disabled event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_firewall_disable_update(outbound); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_different_location_firewall_config_changed_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let expected_firewall_config = build_test_firewall_config(); + + assert_firewall_event_for_different_network_is_ignored(options, move |other_network_id| { + GatewayEvent::FirewallConfigChanged(other_network_id, expected_firewall_config) + }) + .await; +} + +#[sqlx::test] +async fn test_different_location_firewall_disabled_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_firewall_event_for_different_network_is_ignored(options, |other_network_id| { + GatewayEvent::FirewallDisabled(other_network_id) + }) + .await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/handshake.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/handshake.rs new file mode 100644 index 0000000000..bd83c1e5d6 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/handshake.rs @@ -0,0 +1,56 @@ +#[sqlx::test] +async fn test_sends_configuration_on_first_config_request( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let outbound = context.mock_gateway_mut().recv_outbound().await; + + match outbound.payload { + Some(core_response::Payload::Config(config)) => { + assert_eq!(config.name, context.network.name); + assert_eq!(config.port, context.network.port as u32); + assert_eq!(config.peers, Vec::new()); + } + _ => panic_unexpected("expected configuration response"), + } + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_does_not_send_configuration_before_gateway_requests_it( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let gateway_before = context.reload_gateway().await; + assert!(!gateway_before.is_connected()); + + context.mock_gateway_mut().expect_no_outbound().await; + + let gateway_after = context.reload_gateway().await; + assert!(!gateway_after.is_connected()); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_ignores_repeated_config_request(_: PgPoolOptions, options: PgConnectOptions) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let first_outbound = context.mock_gateway_mut().recv_outbound().await; + assert!(matches!( + first_outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + context.mock_gateway().send_config_request(); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/lifecycle.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/lifecycle.rs new file mode 100644 index 0000000000..8c86a03789 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/lifecycle.rs @@ -0,0 +1,59 @@ +#[sqlx::test] +async fn test_gateway_is_marked_connected_after_successful_config_handshake( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let gateway_before = context.reload_gateway().await; + assert!(!gateway_before.is_connected()); + + let gateway_after = context.complete_config_handshake().await; + assert!(gateway_after.is_connected()); + assert!(gateway_after.connected_at.is_some()); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_disconnected_when_stream_closes( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let connected_gateway = context.complete_config_handshake().await; + assert!(connected_gateway.is_connected()); + + let pool = context.pool.clone(); + let gateway_id = context.gateway.id; + let mock_gateway = context.finish().await; + let disconnected_gateway = reload_gateway(&pool, gateway_id).await; + assert!(!disconnected_gateway.is_connected()); + assert!(disconnected_gateway.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_disconnected_when_stream_errors( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + context + .mock_gateway() + .send_stream_error(Status::internal("mock gateway stream failure")); + + let pool = context.pool.clone(); + let gateway_id = context.gateway.id; + let mock_gateway = context.finish_after_error().await; + let disconnected_gateway = reload_gateway(&pool, gateway_id).await; + assert!(!disconnected_gateway.is_connected()); + assert!(disconnected_gateway.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/mfa.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/mfa.rs new file mode 100644 index 0000000000..30b6da8895 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/mfa.rs @@ -0,0 +1,120 @@ +#[sqlx::test] +async fn test_matching_location_mfa_session_authorized_produces_peer_create( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_mfa_device_for_current_network( + &context, + "mfa-authorized-device", + "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", + "10.10.0.40", + Some("mfa-authorized-preshared-key"), + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast MFA session authorized event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", + &["10.10.0.40"], + Some("mfa-authorized-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let mut other_network = context.create_other_network().await; + enable_internal_mfa_for_network(&context.pool, &mut other_network).await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_mfa_device_for_network( + &context, + other_network.id, + "mfa-mismatched-network-device", + "Z2UuIvYJvU5fTOp8i3tHfLm4xZ0R8ExY6E3S3l+rqT8=", + "10.11.0.40", + Some("mfa-mismatched-network-preshared-key"), + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast mismatched MFA session authorized event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, _) = create_authorized_mfa_device_for_current_network( + &context, + "mfa-disconnected-device", + "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", + "10.10.0.41", + Some("mfa-disconnected-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::MfaSessionDisconnected( + context.network.id, + device, + )), + "failed to broadcast MFA session disconnected event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Delete, + "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", + &[], + None, + None, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs new file mode 100644 index 0000000000..607bfc5dc9 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs @@ -0,0 +1,233 @@ +#[sqlx::test] +async fn test_matching_location_network_deleted_event_produces_delete_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + context.network.id, + context.network.name.clone(), + )), + "failed to broadcast gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_delete_update(outbound, &context.network.name); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_network_modified_event_produces_modify_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + let mut modified_network = context.network.clone(); + modified_network.name = format!("{}-modified", context.network.name); + modified_network.address = vec![ + "10.20.0.1/24" + .parse() + .expect("failed to parse modified network address"), + ]; + modified_network.port = 51821; + modified_network.mtu = 1380; + modified_network.fwmark = 42; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkModified( + context.network.id, + modified_network, + Vec::new(), + None, + )), + "failed to broadcast modified gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_modify_update( + outbound, + &format!("{}-modified", context.network.name), + "10.20.0.1/24", + 51821, + 1380, + 42, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_network_created_event_produces_create_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + let mut created_network = context.network.clone(); + created_network.name = format!("{}-created", context.network.name); + created_network.address = vec![ + "10.40.0.1/24" + .parse() + .expect("failed to parse created network address"), + ]; + created_network.port = 51841; + created_network.mtu = 1410; + created_network.fwmark = 17; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkCreated( + context.network.id, + created_network, + )), + "failed to broadcast created gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_create_update( + outbound, + &format!("{}-created", context.network.name), + "10.40.0.1/24", + 51841, + 1410, + 17, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_only_matching_handler_receives_network_modified_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let (events_tx, _) = tokio::sync::broadcast::channel(16); + let mut matching_context = + HandlerTestContext::new_with_events_tx(options.clone(), events_tx.clone()).await; + let mut unrelated_context = HandlerTestContext::new_with_events_tx(options, events_tx).await; + + assert_ne!(matching_context.network.id, unrelated_context.network.id); + + let _ = matching_context.complete_config_handshake().await; + let _ = unrelated_context.complete_config_handshake().await; + + let mut modified_network = matching_context.network.clone(); + modified_network.name = format!("{}-modified", matching_context.network.name); + modified_network.address = vec![ + "10.30.0.1/24" + .parse() + .expect("failed to parse modified network address"), + ]; + modified_network.port = 51831; + modified_network.mtu = 1400; + modified_network.fwmark = 7; + + assert_send_ok!( + matching_context + .events_tx() + .send(GatewayEvent::NetworkModified( + matching_context.network.id, + modified_network, + Vec::new(), + None, + )), + "failed to broadcast modified gateway event" + ); + + let outbound = matching_context.mock_gateway_mut().recv_outbound().await; + assert_network_modify_update( + outbound, + &format!("{}-modified", matching_context.network.name), + "10.30.0.1/24", + 51831, + 1400, + 7, + ); + matching_context + .mock_gateway_mut() + .expect_no_outbound() + .await; + unrelated_context + .mock_gateway_mut() + .expect_no_outbound() + .await; + + matching_context + .finish() + .await + .expect_server_finished() + .await; + unrelated_context + .finish() + .await + .expect_server_finished() + .await; +} + +#[sqlx::test] +async fn test_different_location_network_created_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkCreated( + other_network.id, + other_network, + )), + "failed to broadcast unrelated created gateway event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_different_location_network_deleted_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + other_network.id, + other_network.name.clone(), + )), + "failed to broadcast unrelated gateway event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + context.network.id, + context.network.name.clone(), + )), + "failed to broadcast owned gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_delete_update(outbound, &context.network.name); + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/stats.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/stats.rs new file mode 100644 index 0000000000..285fb24fcb --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/stats.rs @@ -0,0 +1,58 @@ +#[sqlx::test] +async fn test_ignores_peer_stats_before_config_handshake( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context + .mock_gateway() + .send_peer_stats(build_peer_stats("203.0.113.10:51820")); + + context.expect_no_peer_stats().await; + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_forwards_valid_peer_stats_after_config(_: PgPoolOptions, options: PgConnectOptions) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let _ = context.mock_gateway_mut().recv_outbound().await; + context + .mock_gateway() + .send_peer_stats(build_peer_stats("203.0.113.10:51820")); + + let forwarded = context.recv_peer_stats().await; + assert_eq!(forwarded.location_id, context.network.id); + assert_eq!(forwarded.gateway_id, context.gateway.id); + assert_eq!(forwarded.device_pubkey, "peer-public-key"); + assert_eq!(forwarded.endpoint.to_string(), "203.0.113.10:51820"); + assert_eq!(forwarded.upload, 123); + assert_eq!(forwarded.download, 456); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_drops_malformed_or_missing_endpoint_peer_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let _ = context.mock_gateway_mut().recv_outbound().await; + + context.mock_gateway().send_peer_stats(build_peer_stats("")); + context.expect_no_peer_stats().await; + + context + .mock_gateway() + .send_peer_stats(build_peer_stats("not-a-socket-address")); + context.expect_no_peer_stats().await; + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs new file mode 100644 index 0000000000..5f1f80ef4c --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs @@ -0,0 +1,481 @@ +use std::net::IpAddr; + +use defguard_common::db::{ + Id, + models::{ + device::{Device, DeviceInfo, DeviceType, WireguardNetworkDevice}, + user::User, + wireguard::{LocationMfaMode, WireguardNetwork}, + }, +}; +use defguard_core::grpc::GatewayEvent; +use defguard_proto::enterprise::firewall::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, + SnatBinding, ip_address::Address, port::Port as PortInner, +}; +use defguard_proto::gateway::{ + CoreResponse, Update, UpdateType, core_response, + update::{self}, +}; +use sqlx::postgres::PgConnectOptions; + +use crate::common::HandlerTestContext; + +macro_rules! assert_send_ok { + ($result:expr, $message:literal) => { + match $result { + Ok(value) => value, + Err(_) => panic!($message), + } + }; +} + +pub(crate) use assert_send_ok; + +pub(crate) fn panic_unexpected(message: &str) -> ! { + panic!("{message}") +} + +pub(crate) async fn create_device_info_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> DeviceInfo { + create_device_info_for_network( + context, + context.network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await +} + +pub(crate) async fn create_authorized_mfa_device_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> (Device, WireguardNetworkDevice) { + create_authorized_mfa_device_for_network( + context, + context.network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await +} + +pub(crate) async fn create_authorized_mfa_device_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> (Device, WireguardNetworkDevice) { + let device = create_device_for_network( + context, + network_id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, network_id) + .await + .expect("failed to load MFA device network info") + .expect("expected MFA device network info"); + network_device.is_authorized = true; + network_device.preshared_key = preshared_key.map(str::to_owned); + network_device + .update(&context.pool) + .await + .expect("failed to persist MFA device network info"); + + (device, network_device) +} + +pub(crate) async fn create_device_info_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> DeviceInfo { + let device = create_device_for_network( + context, + network_id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + + DeviceInfo::from_device(&context.pool, device) + .await + .expect("failed to load device info") +} + +pub(crate) async fn create_device_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> Device { + let username = format!("{device_name}-user"); + let email = format!("{device_name}@example.com"); + let user = User::new( + username, + Some("pass123"), + "Peer".to_string(), + "Test".to_string(), + email, + None, + ) + .save(&context.pool) + .await + .expect("failed to create test user"); + let device = Device::new( + device_name.to_string(), + device_pubkey.to_string(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&context.pool) + .await + .expect("failed to create test device"); + + let mut network_device = + WireguardNetworkDevice::new(network_id, device.id, vec![parse_test_ip(device_ip)]); + network_device.preshared_key = preshared_key.map(str::to_owned); + network_device + .insert(&context.pool) + .await + .expect("failed to attach device to network"); + + device +} + +pub(crate) async fn enable_internal_mfa_for_network( + pool: &sqlx::PgPool, + network: &mut WireguardNetwork, +) { + network.location_mfa_mode = LocationMfaMode::Internal; + network + .save(pool) + .await + .expect("failed to enable MFA for test network"); + assert!(network.mfa_enabled()); +} + +pub(crate) async fn assert_device_event_is_ignored_before_config_handshake( + options: PgConnectOptions, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, + build_event: fn(DeviceInfo) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + assert_eq!(context.events_tx().receiver_count(), 0); + + let _broadcast_guard = context.events_tx().subscribe(); + let device_info = create_device_info_for_current_network( + &context, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + + assert_send_ok!( + context.events_tx().send(build_event(device_info)), + "failed to broadcast ignored device event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +pub(crate) async fn assert_device_event_for_different_network_is_ignored( + options: PgConnectOptions, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, + build_event: fn(DeviceInfo) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_network( + &context, + other_network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await; + + assert_send_ok!( + context.events_tx().send(build_event(device_info)), + "failed to broadcast ignored device event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +pub(crate) async fn assert_firewall_event_for_different_network_is_ignored( + options: PgConnectOptions, + build_event: impl FnOnce(Id) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context.events_tx().send(build_event(other_network.id)), + "failed to broadcast ignored firewall event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +pub(crate) fn expected_keepalive_interval(context: &HandlerTestContext) -> u32 { + u32::try_from(context.network.keepalive_interval) + .expect("expected non-negative network keepalive interval") +} + +pub(crate) fn parse_test_ip(ip: &str) -> IpAddr { + ip.parse().expect("failed to parse test peer IP address") +} + +pub(crate) fn assert_peer_update( + outbound: CoreResponse, + expected_update_type: UpdateType, + expected_pubkey: &str, + expected_allowed_ips: &[&str], + expected_preshared_key: Option<&str>, + expected_keepalive_interval: Option, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Peer(peer)), + })) => { + assert_eq!(update_type, expected_update_type as i32); + assert_eq!(peer.pubkey, expected_pubkey); + assert_eq!( + peer.allowed_ips, + expected_allowed_ips + .iter() + .map(|allowed_ip| allowed_ip.to_string()) + .collect::>() + ); + assert_eq!(peer.preshared_key.as_deref(), expected_preshared_key); + assert_eq!(peer.keepalive_interval, expected_keepalive_interval); + } + _ => panic_unexpected("expected peer update"), + } +} + +pub(crate) fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: &str) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Delete as i32); + assert_eq!(network.name, expected_network_name); + } + _ => panic_unexpected("expected network delete update"), + } +} + +pub(crate) fn assert_network_create_update( + outbound: CoreResponse, + expected_network_name: &str, + expected_address: &str, + expected_port: u32, + expected_mtu: u32, + expected_fwmark: u32, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Create as i32); + assert_eq!(network.name, expected_network_name); + assert_eq!(network.addresses, vec![expected_address.to_string()]); + assert_eq!(network.port, expected_port); + assert_eq!(network.peers, Vec::new()); + assert_eq!(network.firewall_config, None); + assert_eq!(network.mtu, expected_mtu); + assert_eq!(network.fwmark, expected_fwmark); + } + _ => panic_unexpected("expected network create update"), + } +} + +pub(crate) fn assert_network_modify_update( + outbound: CoreResponse, + expected_network_name: &str, + expected_address: &str, + expected_port: u32, + expected_mtu: u32, + expected_fwmark: u32, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Modify as i32); + assert_eq!(network.name, expected_network_name); + assert_eq!(network.addresses, vec![expected_address.to_string()]); + assert_eq!(network.port, expected_port); + assert_eq!(network.peers, Vec::new()); + assert_eq!(network.firewall_config, None); + assert_eq!(network.mtu, expected_mtu); + assert_eq!(network.fwmark, expected_fwmark); + } + _ => panic_unexpected("expected network modify update"), + } +} + +pub(crate) fn build_test_firewall_config() -> FirewallConfig { + FirewallConfig { + default_policy: i32::from(FirewallPolicy::Allow), + rules: vec![FirewallRule { + id: 101, + source_addrs: vec![IpAddress { + address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), + }], + destination_addrs: vec![IpAddress { + address: Some(Address::Ip("198.51.100.20".to_string())), + }], + destination_ports: vec![Port { + port: Some(PortInner::SinglePort(443)), + }], + protocols: vec![i32::from(Protocol::Tcp)], + verdict: i32::from(FirewallPolicy::Deny), + comment: Some("block test https destination".to_string()), + ip_version: i32::from(IpVersion::Ipv4), + }], + snat_bindings: vec![SnatBinding { + id: 202, + source_addrs: vec![IpAddress { + address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), + }], + public_ip: "203.0.113.44".to_string(), + comment: Some("test snat binding".to_string()), + }], + } +} + +pub(crate) fn assert_firewall_modify_update( + outbound: CoreResponse, + expected_firewall_config: &FirewallConfig, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::FirewallConfig(firewall_config)), + })) => { + assert_eq!(update_type, UpdateType::Modify as i32); + assert_eq!( + firewall_config.default_policy, + expected_firewall_config.default_policy + ); + assert_eq!( + firewall_config.rules.len(), + expected_firewall_config.rules.len() + ); + assert_eq!( + firewall_config.snat_bindings.len(), + expected_firewall_config.snat_bindings.len() + ); + + let firewall_rule = firewall_config + .rules + .first() + .expect("expected firewall rule in update payload"); + let expected_firewall_rule = expected_firewall_config + .rules + .first() + .expect("expected firewall rule in test config"); + assert_eq!(firewall_rule.id, expected_firewall_rule.id); + assert_eq!( + firewall_rule.source_addrs, + expected_firewall_rule.source_addrs + ); + assert_eq!( + firewall_rule.destination_addrs, + expected_firewall_rule.destination_addrs + ); + assert_eq!( + firewall_rule.destination_ports, + expected_firewall_rule.destination_ports + ); + assert_eq!(firewall_rule.protocols, expected_firewall_rule.protocols); + assert_eq!(firewall_rule.verdict, expected_firewall_rule.verdict); + assert_eq!(firewall_rule.comment, expected_firewall_rule.comment); + assert_eq!(firewall_rule.ip_version, expected_firewall_rule.ip_version); + + let snat_binding = firewall_config + .snat_bindings + .first() + .expect("expected SNAT binding in update payload"); + let expected_snat_binding = expected_firewall_config + .snat_bindings + .first() + .expect("expected SNAT binding in test config"); + assert_eq!(snat_binding.id, expected_snat_binding.id); + assert_eq!( + snat_binding.source_addrs, + expected_snat_binding.source_addrs + ); + assert_eq!(snat_binding.public_ip, expected_snat_binding.public_ip); + assert_eq!(snat_binding.comment, expected_snat_binding.comment); + } + _ => panic_unexpected("expected firewall config update"), + } +} + +pub(crate) fn assert_firewall_disable_update(outbound: CoreResponse) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::DisableFirewall(())), + })) => { + assert_eq!(update_type, UpdateType::Delete as i32); + } + _ => panic_unexpected("expected firewall disable update"), + } +} From 78075b2bea8ee5d62c121b7ed002b9692de66f0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 11:14:05 +0100 Subject: [PATCH 16/36] add manager tests --- .../defguard_gateway_manager/src/handler.rs | 2 +- crates/defguard_gateway_manager/src/lib.rs | 171 ++++++++++++++++-- .../tests/common/mod.rs | 163 +++++++++++++++-- .../tests/gateway_manager/manager.rs | 112 ++++++++++++ .../tests/gateway_manager/mod.rs | 1 + 5 files changed, 422 insertions(+), 27 deletions(-) create mode 100644 crates/defguard_gateway_manager/tests/gateway_manager/manager.rs diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 1948de549f..aec2d9f484 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -106,7 +106,7 @@ impl GatewayHandler { }) } - fn new_with_test_socket( + pub(crate) fn new_with_test_socket( gateway: Gateway, pool: PgPool, events_tx: Sender, diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 19ed75c12c..c1a7173ffc 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -1,6 +1,10 @@ use std::{ collections::HashMap, - sync::{Arc, Mutex}, + path::PathBuf, + sync::{ + Arc, Mutex, + atomic::{AtomicBool, Ordering}, + }, time::Duration, }; @@ -13,7 +17,7 @@ use defguard_proto::gateway::gateway_client::GatewayClient; use defguard_version::client::ClientVersionInterceptor; use sqlx::{PgPool, postgres::PgListener}; use tokio::{ - sync::{broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}, + sync::{Notify, broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}, task::{AbortHandle, JoinSet}, }; use tonic::{Request, service::interceptor::InterceptedService, transport::Channel}; @@ -36,10 +40,119 @@ const TEN_SECS: Duration = Duration::from_secs(10); type Client = GatewayClient>; +struct AbortTaskOnDrop { + handle: Option>, +} + +impl AbortTaskOnDrop { + fn new(handle: tokio::task::JoinHandle) -> Self { + Self { + handle: Some(handle), + } + } +} + +impl Drop for AbortTaskOnDrop { + fn drop(&mut self) { + if let Some(handle) = self.handle.take() { + handle.abort(); + } + } +} + +#[derive(Clone, Default)] +struct GatewayManagerTestSupport { + socket_paths_by_url: Arc>>, + handler_spawn_attempts_by_gateway: Arc>>, + listener_ready: Arc, + listener_ready_notify: Arc, +} + +impl GatewayManagerTestSupport { + fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { + self.socket_paths_by_url + .lock() + .expect("Failed to lock GatewayManager test socket registry") + .insert(gateway_url, socket_path); + } + + fn socket_path_for(&self, gateway: &Gateway) -> Option { + self.socket_paths_by_url + .lock() + .expect("Failed to lock GatewayManager test socket registry") + .get(&gateway.url()) + .cloned() + } + + fn note_handler_spawn_attempt(&self, gateway_id: Id) { + let mut handler_spawn_attempts = self + .handler_spawn_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler spawn attempts registry"); + *handler_spawn_attempts.entry(gateway_id).or_default() += 1; + } + + fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + self.handler_spawn_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler spawn attempts registry") + .get(&gateway_id) + .copied() + .unwrap_or_default() + } + + fn mark_listener_ready(&self) { + self.listener_ready.store(true, Ordering::Release); + self.listener_ready_notify.notify_waiters(); + } + + async fn wait_until_listener_ready(&self) { + loop { + if self.listener_ready.load(Ordering::Acquire) { + return; + } + + let notified = self.listener_ready_notify.notified(); + if self.listener_ready.load(Ordering::Acquire) { + return; + } + + notified.await; + } + } +} + +#[doc(hidden)] +#[derive(Clone, Default)] +pub struct TestGatewayManagerControl { + inner: GatewayManagerTestSupport, +} + +impl TestGatewayManagerControl { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { + self.inner.register_gateway_url(gateway_url, socket_path); + } + + #[doc(hidden)] + pub fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + self.inner.handler_spawn_attempt_count(gateway_id) + } + + pub async fn wait_until_listener_ready(&self) { + self.inner.wait_until_listener_ready().await; + } +} + pub struct GatewayManager { clients: Arc>>, pool: PgPool, handlers: JoinSet>, + test_support: GatewayManagerTestSupport, tx: GatewayTxSet, } @@ -50,6 +163,23 @@ impl GatewayManager { clients: Arc::default(), handlers: JoinSet::new(), pool, + test_support: GatewayManagerTestSupport::default(), + tx, + } + } + + #[doc(hidden)] + #[must_use] + pub fn new_for_test( + pool: PgPool, + tx: GatewayTxSet, + control: TestGatewayManagerControl, + ) -> Self { + Self { + clients: Arc::default(), + handlers: JoinSet::new(), + pool, + test_support: control.inner, tx, } } @@ -59,14 +189,19 @@ impl GatewayManager { let (certs_tx, certs_rx) = tokio::sync::watch::channel(Arc::new(HashMap::new())); certs::refresh_certs(&self.pool, &certs_tx).await; let refresh_pool = self.pool.clone(); - tokio::spawn(async move { + let _refresh_certs_task = AbortTaskOnDrop::new(tokio::spawn(async move { loop { certs::refresh_certs(&refresh_pool, &certs_tx).await; tokio::time::sleep(TEN_SECS).await; } - }); + })); let mut abort_handles = HashMap::new(); for gateway in Gateway::all(&self.pool).await? { + if !gateway.enabled { + debug!("Existing Gateway is disabled, so it won't be handled"); + continue; + } + let id = gateway.id; let abort_handle = self.run_handler(gateway, Arc::clone(&self.clients), certs_rx.clone())?; @@ -76,6 +211,7 @@ impl GatewayManager { // Observe gateway URL changes. let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen(GATEWAY_TABLE_TRIGGER).await?; + self.test_support.mark_listener_ready(); while let Ok(notification) = listener.recv().await { let payload = notification.payload(); match serde_json::from_str::>>(payload) { @@ -181,13 +317,26 @@ impl GatewayManager { clients: Arc>>, certs_rx: Receiver>>, ) -> Result { - let mut gateway_handler = GatewayHandler::new( - gateway, - self.pool.clone(), - self.tx.events.clone(), - self.tx.peer_stats.clone(), - certs_rx.clone(), - )?; + self.test_support.note_handler_spawn_attempt(gateway.id); + let mut gateway_handler = + if let Some(socket_path) = self.test_support.socket_path_for(&gateway) { + GatewayHandler::new_with_test_socket( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx.clone(), + socket_path, + )? + } else { + GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx.clone(), + )? + }; let abort_handle = self.handlers.spawn(async move { loop { if let Err(err) = gateway_handler diff --git a/crates/defguard_gateway_manager/tests/common/mod.rs b/crates/defguard_gateway_manager/tests/common/mod.rs index d03c618f6d..a365baba5e 100644 --- a/crates/defguard_gateway_manager/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/tests/common/mod.rs @@ -20,7 +20,9 @@ use defguard_common::{ messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::grpc::GatewayEvent; -use defguard_gateway_manager::TestGatewayHandler; +use defguard_gateway_manager::{ + GatewayManager, GatewayTxSet, TestGatewayHandler, TestGatewayManagerControl, +}; use defguard_proto::gateway::{ ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, core_request, gateway_server, }; @@ -28,7 +30,7 @@ use sqlx::{PgPool, postgres::PgConnectOptions}; use tokio::{ net::UnixListener, sync::{ - broadcast, + Notify, broadcast, mpsc::{self, UnboundedReceiver, UnboundedSender}, oneshot, watch, }, @@ -76,6 +78,8 @@ struct MockGatewayState { outbound_tx: UnboundedSender, inbound_rx: Mutex>>>, connected_tx: Mutex>>, + purge_count: AtomicU64, + purge_notify: Notify, } impl MockGatewayState { @@ -97,6 +101,11 @@ impl MockGatewayState { .take() .ok_or_else(|| Status::failed_precondition("mock gateway already connected")) } + + fn note_purge(&self) { + self.purge_count.fetch_add(1, Ordering::Relaxed); + self.purge_notify.notify_waiters(); + } } #[tonic::async_trait] @@ -124,11 +133,13 @@ impl gateway_server::Gateway for MockGatewayService { } async fn purge(&self, _request: Request<()>) -> Result, Status> { + self.state.note_purge(); Ok(Response::new(())) } } pub(crate) struct MockGatewayHarness { + state: Arc, socket_path: PathBuf, inbound_tx: Option>>, outbound_rx: UnboundedReceiver, @@ -147,12 +158,15 @@ impl MockGatewayHarness { let (outbound_tx, outbound_rx) = mpsc::unbounded_channel(); let (inbound_tx, inbound_rx) = mpsc::unbounded_channel(); let (connected_tx, connected_rx) = oneshot::channel(); + let state = Arc::new(MockGatewayState { + outbound_tx, + inbound_rx: Mutex::new(Some(inbound_rx)), + connected_tx: Mutex::new(Some(connected_tx)), + purge_count: AtomicU64::new(0), + purge_notify: Notify::new(), + }); let service = MockGatewayService { - state: Arc::new(MockGatewayState { - outbound_tx, - inbound_rx: Mutex::new(Some(inbound_rx)), - connected_tx: Mutex::new(Some(connected_tx)), - }), + state: Arc::clone(&state), }; let server_task = tokio::spawn(async move { @@ -165,6 +179,7 @@ impl MockGatewayHarness { }); Self { + state, socket_path, inbound_tx: Some(inbound_tx), outbound_rx, @@ -185,6 +200,20 @@ impl MockGatewayHarness { .expect("mock gateway connection notifier dropped"); } + pub(crate) async fn wait_purged(&self) { + timeout(TEST_TIMEOUT, async { + loop { + if self.state.purge_count.load(Ordering::Relaxed) > 0 { + return; + } + + self.state.purge_notify.notified().await; + } + }) + .await + .expect("timed out waiting for purge request"); + } + pub(crate) fn send_config_request(&self) { let request = ConfigurationRequest { hostname: "mock-gateway".to_string(), @@ -260,6 +289,99 @@ impl Drop for MockGatewayHarness { } } +pub(crate) struct ManagerTestContext { + pub(crate) pool: PgPool, + control: TestGatewayManagerControl, + manager_task: Option>>, +} + +impl ManagerTestContext { + pub(crate) async fn new(options: PgConnectOptions) -> Self { + let pool = setup_pool(options).await; + initialize_current_settings(&pool) + .await + .expect("failed to initialize global settings for gateway manager tests"); + + Self { + pool, + control: TestGatewayManagerControl::new(), + manager_task: None, + } + } + + pub(crate) fn register_gateway_mock( + &self, + gateway: &Gateway, + mock_gateway: &MockGatewayHarness, + ) { + self.control + .register_gateway_url(gateway.url(), mock_gateway.socket_path()); + } + + pub(crate) fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + self.control.handler_spawn_attempt_count(gateway_id) + } + + pub(crate) async fn wait_for_handler_spawn_attempt_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + timeout(TEST_TIMEOUT, async { + loop { + if self.handler_spawn_attempt_count(gateway_id) >= expected_count { + return; + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("timed out waiting for gateway manager handler spawn attempt"); + } + + pub(crate) async fn start(&mut self) { + assert!( + self.manager_task.is_none(), + "gateway manager already started" + ); + + let (events_tx, _) = broadcast::channel(16); + let (peer_stats_tx, _peer_stats_rx) = mpsc::unbounded_channel(); + let tx = GatewayTxSet::new(events_tx, peer_stats_tx); + let mut manager = GatewayManager::new_for_test(self.pool.clone(), tx, self.control.clone()); + let manager_task = tokio::spawn(async move { manager.run().await }); + + timeout(TEST_TIMEOUT, self.control.wait_until_listener_ready()) + .await + .expect("timed out waiting for gateway manager listener to become ready"); + self.manager_task = Some(manager_task); + } + + pub(crate) async fn finish(mut self) { + if let Some(manager_task) = self.manager_task.take() { + manager_task.abort(); + + match manager_task.await { + Err(err) if err.is_cancelled() => {} + Err(err) => panic!("gateway manager task panicked: {err}"), + Ok(Ok(())) => {} + Ok(Err(err)) => panic!("gateway manager exited with error: {err}"), + } + } + + self.pool.close().await; + } +} + +impl Drop for ManagerTestContext { + fn drop(&mut self) { + if let Some(manager_task) = self.manager_task.take() { + manager_task.abort(); + } + } +} + pub(crate) struct HandlerTestContext { pub(crate) pool: PgPool, pub(crate) network: WireguardNetwork, @@ -460,7 +582,7 @@ pub(crate) fn build_peer_stats(endpoint: &str) -> PeerStats { } } -async fn create_network(pool: &PgPool) -> WireguardNetwork { +pub(crate) async fn create_network(pool: &PgPool) -> WireguardNetwork { let mut network = WireguardNetwork { name: unique_name("network"), endpoint: "198.51.100.10".to_string(), @@ -476,15 +598,26 @@ async fn create_network(pool: &PgPool) -> WireguardNetwork { .expect("failed to create test network") } -async fn create_gateway(pool: &PgPool, location_id: Id) -> Gateway { - Gateway::new( +pub(crate) async fn create_gateway(pool: &PgPool, location_id: Id) -> Gateway { + create_gateway_with_enabled(pool, location_id, true).await +} + +pub(crate) async fn create_gateway_with_enabled( + pool: &PgPool, + location_id: Id, + enabled: bool, +) -> Gateway { + let port = 20_000 + i32::try_from(next_test_id() % 40_000).expect("port offset fits in i32"); + let mut gateway = Gateway::new( location_id, unique_name("gateway"), "127.0.0.1".to_string(), - 51820, + port, "test-admin".to_string(), - ) - .save(pool) - .await - .expect("failed to create test gateway") + ); + gateway.enabled = enabled; + gateway + .save(pool) + .await + .expect("failed to create test gateway") } diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs b/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs new file mode 100644 index 0000000000..f404442315 --- /dev/null +++ b/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs @@ -0,0 +1,112 @@ +use defguard_common::db::models::gateway::Gateway; +use defguard_proto::gateway::core_response; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + +use crate::common::{ + ManagerTestContext, MockGatewayHarness, create_gateway, create_gateway_with_enabled, + create_network, reload_gateway, wait_for_gateway_connection_state, +}; + +#[sqlx::test] +async fn test_starts_existing_enabled_gateway_on_startup( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + mock_gateway.wait_connected().await; + + mock_gateway.send_config_request(); + let outbound = mock_gateway.recv_outbound().await; + assert!(matches!( + outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; + assert!(gateway_after.is_connected()); + + context.finish().await; +} + +#[sqlx::test] +async fn test_starts_gateway_after_enabled_update_notification( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway_with_enabled(&context.pool, network.id, false).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + 0, + "disabled gateway handler should not start during manager startup" + ); + + let gateway_before = reload_gateway(&context.pool, gateway.id).await; + assert!(!gateway_before.is_connected()); + + gateway.enabled = true; + gateway + .save(&context.pool) + .await + .expect("failed to enable test gateway"); + + context + .wait_for_handler_spawn_attempt_count(gateway.id, 1) + .await; + mock_gateway.wait_connected().await; + mock_gateway.send_config_request(); + let outbound = mock_gateway.recv_outbound().await; + assert!(matches!( + outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; + assert!(gateway_after.is_connected()); + + context.finish().await; +} + +#[sqlx::test] +async fn test_delete_notification_purges_and_aborts_gateway_connection( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + mock_gateway.wait_connected().await; + + mock_gateway.send_config_request(); + let outbound = mock_gateway.recv_outbound().await; + assert!(matches!( + outbound.payload, + Some(core_response::Payload::Config(_)) + )); + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; + assert!(gateway_after.is_connected()); + + Gateway::delete_by_id(&context.pool, gateway.id) + .await + .expect("failed to delete test gateway"); + + mock_gateway.wait_purged().await; + mock_gateway.expect_server_finished().await; + + context.finish().await; +} diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs b/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs index 427a476918..caf495ef2e 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs @@ -1 +1,2 @@ mod handler; +mod manager; From f5ecf0da8061637fb26d44795239ff2b2503d62e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 12:07:01 +0100 Subject: [PATCH 17/36] remove duplicate method --- crates/defguard_common/src/db/models/gateway.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/defguard_common/src/db/models/gateway.rs b/crates/defguard_common/src/db/models/gateway.rs index e76b1f443b..a7e7184146 100644 --- a/crates/defguard_common/src/db/models/gateway.rs +++ b/crates/defguard_common/src/db/models/gateway.rs @@ -34,6 +34,12 @@ impl Gateway { self.connected_at.is_some() } } + + /// Return address and port as URL with HTTP scheme. + #[must_use] + pub fn url(&self) -> String { + format!("http://{}:{}", self.address, self.port) + } } impl Gateway { @@ -171,12 +177,6 @@ impl Gateway { Ok(record) } - /// Return address and port as URL with HTTP scheme. - #[must_use] - pub fn url(&self) -> String { - format!("http://{}:{}", self.address, self.port) - } - /// Disable all Gateways except one. Used for expired licence. pub async fn leave_one_enabled<'e, E>(executor: E) -> sqlx::Result<()> where From 6edb043ae9f84b120163032485bdcb13710b5f32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 12:08:03 +0100 Subject: [PATCH 18/36] cleanup --- crates/defguard_gateway_manager/src/lib.rs | 223 ++++++++++++------ .../tests/common/mod.rs | 88 +++++-- .../tests/gateway_manager/manager.rs | 175 +++++++++++--- 3 files changed, 373 insertions(+), 113 deletions(-) diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index c1a7173ffc..eeea23bf69 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -64,6 +64,9 @@ impl Drop for AbortTaskOnDrop { struct GatewayManagerTestSupport { socket_paths_by_url: Arc>>, handler_spawn_attempts_by_gateway: Arc>>, + handler_spawn_attempt_notify: Arc, + gateway_notifications_by_gateway: Arc>>, + gateway_notification_notify: Arc, listener_ready: Arc, listener_ready_notify: Arc, } @@ -90,6 +93,7 @@ impl GatewayManagerTestSupport { .lock() .expect("Failed to lock GatewayManager handler spawn attempts registry"); *handler_spawn_attempts.entry(gateway_id).or_default() += 1; + self.handler_spawn_attempt_notify.notify_waiters(); } fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { @@ -101,6 +105,54 @@ impl GatewayManagerTestSupport { .unwrap_or_default() } + async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { + loop { + if self.handler_spawn_attempt_count(gateway_id) >= expected_count { + return; + } + + let notified = self.handler_spawn_attempt_notify.notified(); + if self.handler_spawn_attempt_count(gateway_id) >= expected_count { + return; + } + + notified.await; + } + } + + fn note_gateway_notification(&self, gateway_id: Id) { + let mut gateway_notifications = self + .gateway_notifications_by_gateway + .lock() + .expect("Failed to lock GatewayManager gateway notification registry"); + *gateway_notifications.entry(gateway_id).or_default() += 1; + self.gateway_notification_notify.notify_waiters(); + } + + fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + self.gateway_notifications_by_gateway + .lock() + .expect("Failed to lock GatewayManager gateway notification registry") + .get(&gateway_id) + .copied() + .unwrap_or_default() + } + + async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { + loop { + if self.gateway_notification_count(gateway_id) >= expected_count { + return; + } + + let notified = self.gateway_notification_notify.notified(); + if self.gateway_notification_count(gateway_id) >= expected_count { + return; + } + + notified.await; + } + } + fn mark_listener_ready(&self) { self.listener_ready.store(true, Ordering::Release); self.listener_ready_notify.notify_waiters(); @@ -143,6 +195,25 @@ impl TestGatewayManagerControl { self.inner.handler_spawn_attempt_count(gateway_id) } + #[doc(hidden)] + pub async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { + self.inner + .wait_for_handler_spawn_attempt_count(gateway_id, expected_count) + .await; + } + + #[doc(hidden)] + pub fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + self.inner.gateway_notification_count(gateway_id) + } + + #[doc(hidden)] + pub async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { + self.inner + .wait_for_gateway_notification_count(gateway_id, expected_count) + .await; + } + pub async fn wait_until_listener_ready(&self) { self.inner.wait_until_listener_ready().await; } @@ -215,9 +286,13 @@ impl GatewayManager { while let Ok(notification) = listener.recv().await { let payload = notification.payload(); match serde_json::from_str::>>(payload) { - Ok(gateway_notification) => match gateway_notification.operation { - TriggerOperation::Insert => { - if let Some(new) = gateway_notification.new { + Ok(gateway_notification) => { + let maybe_gateway_id = match gateway_notification.operation { + TriggerOperation::Insert => { + let Some(new) = gateway_notification.new else { + continue; + }; + let id = new.id; if new.enabled { let abort_handle = self.run_handler( @@ -229,77 +304,93 @@ impl GatewayManager { } else { debug!("New Gateway is disabled, so it won't be handled"); } + + Some(id) } - } - TriggerOperation::Update => { - let (Some(old), Some(new)) = - (gateway_notification.old, gateway_notification.new) - else { - continue; - }; - if old.address == new.address - && old.port == new.port - && old.enabled == new.enabled - { - debug!("Gateway address/port/state didn't change"); - continue; - } - if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!( - "Aborting connection to Gateway {old}, it has changed in the \ - database" - ); - abort_handle.abort(); - } else if old.enabled { - warn!("Cannot find Gateway {old} on the list of connected gateways"); - } - if new.enabled { + TriggerOperation::Update => { + let (Some(old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + else { + continue; + }; + let id = new.id; - let abort_handle = - self.run_handler(new, Arc::clone(&self.clients), certs_rx.clone())?; - abort_handles.insert(id, abort_handle); - } else { - debug!("Updated Gateway is disabled, so it won't be handled"); - } - } - TriggerOperation::Delete => { - let Some(old) = gateway_notification.old else { - continue; - }; - - // Send purge request to Gateway. - let maybe_client = { - self.clients - .lock() - .expect("Failed to lock GatewayManager::clients") - .remove(&old.id) - }; - - if let Some(mut client) = maybe_client { - debug!("Sending purge request to Gateway {old}"); - if let Err(err) = client.purge(Request::new(())).await { - error!("Error sending purge request to Gateway {old}: {err}"); + if old.address == new.address + && old.port == new.port + && old.enabled == new.enabled + { + debug!("Gateway address/port/state didn't change"); } else { - info!("Sent purge request to Gateway {old}"); + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to Gateway {old}, it has changed in the \ + database" + ); + abort_handle.abort(); + } else if old.enabled { + warn!( + "Cannot find Gateway {old} on the list of connected gateways" + ); + } + if new.enabled { + let abort_handle = self.run_handler( + new, + Arc::clone(&self.clients), + certs_rx.clone(), + )?; + abort_handles.insert(id, abort_handle); + } else { + debug!("Updated Gateway is disabled, so it won't be handled"); + } } - } else { - warn!( - "Cannot find gRPC client for Gateway {old}; skipping purge request" - ); + + Some(id) } + TriggerOperation::Delete => { + let Some(old) = gateway_notification.old else { + continue; + }; + + // Send purge request to Gateway. + let maybe_client = { + self.clients + .lock() + .expect("Failed to lock GatewayManager::clients") + .remove(&old.id) + }; + + if let Some(mut client) = maybe_client { + debug!("Sending purge request to Gateway {old}"); + if let Err(err) = client.purge(Request::new(())).await { + error!("Error sending purge request to Gateway {old}: {err}"); + } else { + info!("Sent purge request to Gateway {old}"); + } + } else { + warn!( + "Cannot find gRPC client for Gateway {old}; skipping purge request" + ); + } + + // Kill the `GatewayHandler` and the connection. + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to Gateway {old}, it has disappeard from the \ + database" + ); + abort_handle.abort(); + } else if old.enabled { + warn!("Cannot find Gateway {old} on the list of connected gateways"); + } - // Kill the `GatewayHandler` and the connection. - if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!( - "Aborting connection to Gateway {old}, it has disappeard from the \ - database" - ); - abort_handle.abort(); - } else if old.enabled { - warn!("Cannot find Gateway {old} on the list of connected gateways"); + Some(old.id) } + }; + + if let Some(gateway_id) = maybe_gateway_id { + self.test_support.note_gateway_notification(gateway_id); } - }, + } Err(err) => error!("Failed to de-serialize database notification object: {err}"), } } diff --git a/crates/defguard_gateway_manager/tests/common/mod.rs b/crates/defguard_gateway_manager/tests/common/mod.rs index a365baba5e..c9949857a3 100644 --- a/crates/defguard_gateway_manager/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/tests/common/mod.rs @@ -11,7 +11,7 @@ use std::{ use defguard_common::{ db::{ - Id, + Id, NoId, models::{ gateway::Gateway, settings::initialize_current_settings, wireguard::WireguardNetwork, }, @@ -78,12 +78,17 @@ struct MockGatewayState { outbound_tx: UnboundedSender, inbound_rx: Mutex>>>, connected_tx: Mutex>>, + connection_count: AtomicU64, + connection_notify: Notify, purge_count: AtomicU64, purge_notify: Notify, } impl MockGatewayState { fn notify_connected(&self) { + self.connection_count.fetch_add(1, Ordering::Relaxed); + self.connection_notify.notify_waiters(); + if let Some(tx) = self .connected_tx .lock() @@ -162,6 +167,8 @@ impl MockGatewayHarness { outbound_tx, inbound_rx: Mutex::new(Some(inbound_rx)), connected_tx: Mutex::new(Some(connected_tx)), + connection_count: AtomicU64::new(0), + connection_notify: Notify::new(), purge_count: AtomicU64::new(0), purge_notify: Notify::new(), }); @@ -200,6 +207,29 @@ impl MockGatewayHarness { .expect("mock gateway connection notifier dropped"); } + pub(crate) fn connection_count(&self) -> u64 { + self.state.connection_count.load(Ordering::Relaxed) + } + + pub(crate) async fn wait_for_connection_count(&self, expected_count: u64) { + timeout(TEST_TIMEOUT, async { + loop { + if self.connection_count() >= expected_count { + return; + } + + let notified = self.state.connection_notify.notified(); + if self.connection_count() >= expected_count { + return; + } + + notified.await; + } + }) + .await + .expect("timed out waiting for mock gateway connection count"); + } + pub(crate) async fn wait_purged(&self) { timeout(TEST_TIMEOUT, async { loop { @@ -207,7 +237,12 @@ impl MockGatewayHarness { return; } - self.state.purge_notify.notified().await; + let notified = self.state.purge_notify.notified(); + if self.state.purge_count.load(Ordering::Relaxed) > 0 { + return; + } + + notified.await; } }) .await @@ -314,32 +349,50 @@ impl ManagerTestContext { gateway: &Gateway, mock_gateway: &MockGatewayHarness, ) { + self.register_gateway_url(gateway.url(), mock_gateway); + } + + pub(crate) fn register_gateway_url(&self, gateway_url: String, mock_gateway: &MockGatewayHarness) { self.control - .register_gateway_url(gateway.url(), mock_gateway.socket_path()); + .register_gateway_url(gateway_url, mock_gateway.socket_path()); } pub(crate) fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { self.control.handler_spawn_attempt_count(gateway_id) } + pub(crate) fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + self.control.gateway_notification_count(gateway_id) + } + pub(crate) async fn wait_for_handler_spawn_attempt_count( &self, gateway_id: Id, expected_count: u64, ) { - timeout(TEST_TIMEOUT, async { - loop { - if self.handler_spawn_attempt_count(gateway_id) >= expected_count { - return; - } - - tokio::time::sleep(Duration::from_millis(20)).await; - } - }) + timeout( + TEST_TIMEOUT, + self.control + .wait_for_handler_spawn_attempt_count(gateway_id, expected_count), + ) .await .expect("timed out waiting for gateway manager handler spawn attempt"); } + pub(crate) async fn wait_for_gateway_notification_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + timeout( + TEST_TIMEOUT, + self.control + .wait_for_gateway_notification_count(gateway_id, expected_count), + ) + .await + .expect("timed out waiting for gateway manager database notification"); + } + pub(crate) async fn start(&mut self) { assert!( self.manager_task.is_none(), @@ -607,6 +660,14 @@ pub(crate) async fn create_gateway_with_enabled( location_id: Id, enabled: bool, ) -> Gateway { + let gateway = build_gateway_with_enabled(location_id, enabled); + gateway + .save(pool) + .await + .expect("failed to create test gateway") +} + +pub(crate) fn build_gateway_with_enabled(location_id: Id, enabled: bool) -> Gateway { let port = 20_000 + i32::try_from(next_test_id() % 40_000).expect("port offset fits in i32"); let mut gateway = Gateway::new( location_id, @@ -617,7 +678,4 @@ pub(crate) async fn create_gateway_with_enabled( ); gateway.enabled = enabled; gateway - .save(pool) - .await - .expect("failed to create test gateway") } diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs b/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs index f404442315..5c7f5219c8 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs @@ -1,12 +1,30 @@ -use defguard_common::db::models::gateway::Gateway; +use defguard_common::db::{Id, models::gateway::Gateway}; use defguard_proto::gateway::core_response; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use crate::common::{ - ManagerTestContext, MockGatewayHarness, create_gateway, create_gateway_with_enabled, - create_network, reload_gateway, wait_for_gateway_connection_state, + ManagerTestContext, MockGatewayHarness, build_gateway_with_enabled, create_gateway, + create_gateway_with_enabled, create_network, reload_gateway, + wait_for_gateway_connection_state, }; +async fn complete_manager_handshake( + context: &ManagerTestContext, + gateway: &Gateway, + mock_gateway: &mut MockGatewayHarness, +) { + mock_gateway.wait_connected().await; + mock_gateway.send_config_request(); + let outbound = mock_gateway.recv_outbound().await; + assert!(matches!( + outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; + assert!(gateway_after.is_connected()); +} + #[sqlx::test] async fn test_starts_existing_enabled_gateway_on_startup( _: PgPoolOptions, @@ -19,17 +37,7 @@ async fn test_starts_existing_enabled_gateway_on_startup( context.register_gateway_mock(&gateway, &mock_gateway); context.start().await; - mock_gateway.wait_connected().await; - - mock_gateway.send_config_request(); - let outbound = mock_gateway.recv_outbound().await; - assert!(matches!( - outbound.payload, - Some(core_response::Payload::Config(_)) - )); - - let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; - assert!(gateway_after.is_connected()); + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; context.finish().await; } @@ -61,19 +69,131 @@ async fn test_starts_gateway_after_enabled_update_notification( .await .expect("failed to enable test gateway"); + context.wait_for_gateway_notification_count(gateway.id, 1).await; context .wait_for_handler_spawn_attempt_count(gateway.id, 1) .await; - mock_gateway.wait_connected().await; - mock_gateway.send_config_request(); - let outbound = mock_gateway.recv_outbound().await; - assert!(matches!( - outbound.payload, - Some(core_response::Payload::Config(_)) - )); + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; - let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; + context.finish().await; +} + +#[sqlx::test] +async fn test_noop_gateway_update_does_not_restart_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + gateway = reload_gateway(&context.pool, gateway.id).await; + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_notification_count = context.gateway_notification_count(gateway.id); + let initial_connection_count = mock_gateway.connection_count(); + + gateway.modified_by = "manager-noop-update".to_string(); + gateway + .save(&context.pool) + .await + .expect("failed to save no-op gateway update"); + + context + .wait_for_gateway_notification_count(gateway.id, initial_notification_count + 1) + .await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "no-op gateway update should not restart the handler" + ); + assert_eq!( + mock_gateway.connection_count(), + initial_connection_count, + "no-op gateway update should not reconnect the handler" + ); + + let gateway_after = reload_gateway(&context.pool, gateway.id).await; + assert!(gateway_after.is_connected()); + + context.finish().await; +} + +#[sqlx::test] +async fn test_gateway_address_change_restarts_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway(&context.pool, network.id).await; + let mut original_mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &original_mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut original_mock_gateway).await; + + gateway = reload_gateway(&context.pool, gateway.id).await; + + let replacement_mock_url = { + gateway.address = "127.0.0.2".to_string(); + gateway.modified_by = "manager-address-update".to_string(); + gateway.url() + }; + let mut replacement_mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_url(replacement_mock_url, &replacement_mock_gateway); + + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_notification_count = context.gateway_notification_count(gateway.id); + + gateway + .save(&context.pool) + .await + .expect("failed to save gateway address update"); + + context + .wait_for_gateway_notification_count(gateway.id, initial_notification_count + 1) + .await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, initial_spawn_attempts + 1) + .await; + replacement_mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut replacement_mock_gateway).await; + + let gateway_after = reload_gateway(&context.pool, gateway.id).await; + assert_eq!(gateway_after.address, "127.0.0.2"); assert!(gateway_after.is_connected()); + original_mock_gateway.expect_server_finished().await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_insert_notification_starts_handler_for_enabled_gateway( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let gateway = build_gateway_with_enabled(network.id, true); + let gateway_url = gateway.url(); + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_url(gateway_url, &mock_gateway); + + context.start().await; + + let gateway = gateway + .save(&context.pool) + .await + .expect("failed to insert enabled test gateway"); + + context.wait_for_gateway_notification_count(gateway.id, 1).await; + context.wait_for_handler_spawn_attempt_count(gateway.id, 1).await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; context.finish().await; } @@ -90,16 +210,7 @@ async fn test_delete_notification_purges_and_aborts_gateway_connection( context.register_gateway_mock(&gateway, &mock_gateway); context.start().await; - mock_gateway.wait_connected().await; - - mock_gateway.send_config_request(); - let outbound = mock_gateway.recv_outbound().await; - assert!(matches!( - outbound.payload, - Some(core_response::Payload::Config(_)) - )); - let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; - assert!(gateway_after.is_connected()); + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; Gateway::delete_by_id(&context.pool, gateway.id) .await From 4d8986b4bd01a717a7c653fb03df16be1025168b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 15:57:30 +0100 Subject: [PATCH 19/36] cleanup --- .../defguard_gateway_manager/src/handler.rs | 28 +- crates/defguard_gateway_manager/src/lib.rs | 320 ++++++++++++++---- .../tests/common/mod.rs | 55 ++- .../tests/gateway_manager/manager.rs | 218 +++++++++++- 4 files changed, 539 insertions(+), 82 deletions(-) diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index d7bd0c314f..d0887a0b8a 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -47,7 +47,7 @@ use tokio::{ use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Code, Status, transport::Endpoint}; -use crate::{Client, TEN_SECS, error::GatewayError}; +use crate::{Client, GatewayManagerTestSupport, TEN_SECS, error::GatewayError}; #[derive(Debug, Default)] struct GatewayTestTransport { @@ -77,6 +77,7 @@ pub(super) struct GatewayHandler { peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, test_transport: GatewayTestTransport, + test_support: Option, } impl GatewayHandler { @@ -103,6 +104,7 @@ impl GatewayHandler { peer_stats_tx, certs_rx, test_transport: GatewayTestTransport::default(), + test_support: None, }) } @@ -119,6 +121,10 @@ impl GatewayHandler { Ok(handler) } + pub(super) fn attach_test_support(&mut self, test_support: GatewayManagerTestSupport) { + self.test_support = Some(test_support); + } + fn endpoint(&self) -> Result { let mut url = self.url.clone(); @@ -305,17 +311,24 @@ impl GatewayHandler { Version::parse(VERSION).expect("failed to parse self version"), ); let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); + if let Some(test_support) = &self.test_support { + test_support.note_handler_connection_attempt(self.gateway.id); + } clients .lock() .expect("GatewayHandler failed to lock clients") .insert(self.gateway.id, client.clone()); let (tx, rx) = mpsc::unbounded_channel(); + let retry_delay = self + .test_support + .as_ref() + .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay); let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { Ok(response) => response, Err(err) => { error!("Failed to connect to Gateway {uri}, retrying: {err}"); if retry_on_connect_failure { - sleep(TEN_SECS).await; + sleep(retry_delay).await; return Ok(()); } @@ -417,8 +430,8 @@ impl GatewayHandler { error!("Disconnected from Gateway at {uri}, error: {err}"); self.handle_disconnection_error().await; if retry_on_connect_failure { - debug!("Waiting 10s to re-establish the connection"); - sleep(TEN_SECS).await; + debug!("Waiting {retry_delay:?} to re-establish the connection"); + sleep(retry_delay).await; } return Ok(()); } @@ -438,13 +451,12 @@ impl GatewayHandler { } } -#[doc(hidden)] -pub struct TestGatewayHandler { +pub(crate) struct TestGatewayHandler { inner: GatewayHandler, } impl TestGatewayHandler { - pub fn new( + pub(crate) fn new( gateway: Gateway, pool: PgPool, events_tx: Sender, @@ -463,7 +475,7 @@ impl TestGatewayHandler { Ok(Self { inner }) } - pub async fn handle_connection_once(&mut self) -> anyhow::Result<()> { + pub(crate) async fn handle_connection_once(&mut self) -> anyhow::Result<()> { let clients = Arc::>>::default(); self.inner .handle_connection_iteration(clients, false) diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index eeea23bf69..75ea5ff662 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -31,9 +31,6 @@ mod certs; mod error; mod handler; -#[doc(hidden)] -pub use handler::TestGatewayHandler; - const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); const TEN_SECS: Duration = Duration::from_secs(10); @@ -65,10 +62,13 @@ struct GatewayManagerTestSupport { socket_paths_by_url: Arc>>, handler_spawn_attempts_by_gateway: Arc>>, handler_spawn_attempt_notify: Arc, + handler_connection_attempts_by_gateway: Arc>>, + handler_connection_attempt_notify: Arc, gateway_notifications_by_gateway: Arc>>, gateway_notification_notify: Arc, listener_ready: Arc, listener_ready_notify: Arc, + retry_delay_override: Arc>>, } impl GatewayManagerTestSupport { @@ -120,6 +120,39 @@ impl GatewayManagerTestSupport { } } + fn note_handler_connection_attempt(&self, gateway_id: Id) { + let mut handler_connection_attempts = self + .handler_connection_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler connection attempts registry"); + *handler_connection_attempts.entry(gateway_id).or_default() += 1; + self.handler_connection_attempt_notify.notify_waiters(); + } + + fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { + self.handler_connection_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler connection attempts registry") + .get(&gateway_id) + .copied() + .unwrap_or_default() + } + + async fn wait_for_handler_connection_attempt_count(&self, gateway_id: Id, expected_count: u64) { + loop { + if self.handler_connection_attempt_count(gateway_id) >= expected_count { + return; + } + + let notified = self.handler_connection_attempt_notify.notified(); + if self.handler_connection_attempt_count(gateway_id) >= expected_count { + return; + } + + notified.await; + } + } + fn note_gateway_notification(&self, gateway_id: Id) { let mut gateway_notifications = self .gateway_notifications_by_gateway @@ -172,58 +205,207 @@ impl GatewayManagerTestSupport { notified.await; } } + + fn set_retry_delay(&self, retry_delay: Duration) { + *self + .retry_delay_override + .lock() + .expect("Failed to lock GatewayManager retry delay override") = Some(retry_delay); + } + + fn manager_reconnect_delay(&self) -> Duration { + self.retry_delay_override + .lock() + .expect("Failed to lock GatewayManager retry delay override") + .unwrap_or(GATEWAY_RECONNECT_DELAY) + } + + fn handler_reconnect_delay(&self) -> Duration { + self.retry_delay_override + .lock() + .expect("Failed to lock GatewayManager retry delay override") + .unwrap_or(TEN_SECS) + } } -#[doc(hidden)] #[derive(Clone, Default)] -pub struct TestGatewayManagerControl { +struct TestGatewayManagerControl { inner: GatewayManagerTestSupport, } impl TestGatewayManagerControl { #[must_use] - pub fn new() -> Self { + fn new() -> Self { Self::default() } - pub fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { + fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { self.inner.register_gateway_url(gateway_url, socket_path); } - #[doc(hidden)] - pub fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { self.inner.handler_spawn_attempt_count(gateway_id) } - #[doc(hidden)] - pub async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { + async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { self.inner .wait_for_handler_spawn_attempt_count(gateway_id, expected_count) .await; } - #[doc(hidden)] - pub fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { + self.inner.handler_connection_attempt_count(gateway_id) + } + + async fn wait_for_handler_connection_attempt_count(&self, gateway_id: Id, expected_count: u64) { + self.inner + .wait_for_handler_connection_attempt_count(gateway_id, expected_count) + .await; + } + + fn gateway_notification_count(&self, gateway_id: Id) -> u64 { self.inner.gateway_notification_count(gateway_id) } - #[doc(hidden)] - pub async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { + async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { self.inner .wait_for_gateway_notification_count(gateway_id, expected_count) .await; } - pub async fn wait_until_listener_ready(&self) { + async fn wait_until_listener_ready(&self) { self.inner.wait_until_listener_ready().await; } + + fn set_retry_delay(&self, retry_delay: Duration) { + self.inner.set_retry_delay(retry_delay); + } +} + +#[doc(hidden)] +pub mod test_support { + use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; + + use defguard_common::{ + db::{Id, models::gateway::Gateway}, + messages::peer_stats_update::PeerStatsUpdate, + }; + use defguard_core::grpc::GatewayEvent; + use sqlx::PgPool; + use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}; + + use crate::{ + GatewayManager, GatewayTxSet, TestGatewayManagerControl, handler::TestGatewayHandler, + }; + + #[derive(Clone, Default)] + pub struct GatewayManagerControl { + inner: TestGatewayManagerControl, + } + + impl GatewayManagerControl { + #[must_use] + pub fn new() -> Self { + Self { + inner: TestGatewayManagerControl::new(), + } + } + + pub fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { + self.inner.register_gateway_url(gateway_url, socket_path); + } + + #[must_use] + pub fn new_manager(&self, pool: PgPool, tx: GatewayTxSet) -> GatewayManager { + GatewayManager::new_for_test(pool, tx, self.inner.clone()) + } + + pub fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + self.inner.handler_spawn_attempt_count(gateway_id) + } + + pub async fn wait_for_handler_spawn_attempt_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + self.inner + .wait_for_handler_spawn_attempt_count(gateway_id, expected_count) + .await; + } + + pub fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { + self.inner.handler_connection_attempt_count(gateway_id) + } + + pub async fn wait_for_handler_connection_attempt_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + self.inner + .wait_for_handler_connection_attempt_count(gateway_id, expected_count) + .await; + } + + pub fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + self.inner.gateway_notification_count(gateway_id) + } + + pub async fn wait_for_gateway_notification_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + self.inner + .wait_for_gateway_notification_count(gateway_id, expected_count) + .await; + } + + pub async fn wait_until_listener_ready(&self) { + self.inner.wait_until_listener_ready().await; + } + + pub fn set_retry_delay(&self, retry_delay: Duration) { + self.inner.set_retry_delay(retry_delay); + } + } + + pub struct GatewayHandler { + inner: TestGatewayHandler, + } + + impl GatewayHandler { + pub fn new( + gateway: Gateway, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: Receiver>>, + socket_path: PathBuf, + ) -> anyhow::Result { + let inner = TestGatewayHandler::new( + gateway, + pool, + events_tx, + peer_stats_tx, + certs_rx, + socket_path, + )?; + Ok(Self { inner }) + } + + pub async fn handle_connection_once(&mut self) -> anyhow::Result<()> { + self.inner.handle_connection_once().await + } + } } pub struct GatewayManager { clients: Arc>>, pool: PgPool, handlers: JoinSet>, - test_support: GatewayManagerTestSupport, + test_support: Option, tx: GatewayTxSet, } @@ -234,23 +416,18 @@ impl GatewayManager { clients: Arc::default(), handlers: JoinSet::new(), pool, - test_support: GatewayManagerTestSupport::default(), + test_support: None, tx, } } - #[doc(hidden)] #[must_use] - pub fn new_for_test( - pool: PgPool, - tx: GatewayTxSet, - control: TestGatewayManagerControl, - ) -> Self { + fn new_for_test(pool: PgPool, tx: GatewayTxSet, control: TestGatewayManagerControl) -> Self { Self { clients: Arc::default(), handlers: JoinSet::new(), pool, - test_support: control.inner, + test_support: Some(control.inner), tx, } } @@ -282,7 +459,9 @@ impl GatewayManager { // Observe gateway URL changes. let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen(GATEWAY_TABLE_TRIGGER).await?; - self.test_support.mark_listener_ready(); + if let Some(test_support) = &self.test_support { + test_support.mark_listener_ready(); + } while let Ok(notification) = listener.recv().await { let payload = notification.payload(); match serde_json::from_str::>>(payload) { @@ -313,6 +492,7 @@ impl GatewayManager { else { continue; }; + let mut old = old; let id = new.id; if old.address == new.address @@ -321,7 +501,9 @@ impl GatewayManager { { debug!("Gateway address/port/state didn't change"); } else { + self.remove_client(old.id); if let Some(abort_handle) = abort_handles.remove(&old.id) { + old.touch_disconnected(&self.pool).await?; info!( "Aborting connection to Gateway {old}, it has changed in the \ database" @@ -352,12 +534,7 @@ impl GatewayManager { }; // Send purge request to Gateway. - let maybe_client = { - self.clients - .lock() - .expect("Failed to lock GatewayManager::clients") - .remove(&old.id) - }; + let maybe_client = self.remove_client(old.id); if let Some(mut client) = maybe_client { debug!("Sending purge request to Gateway {old}"); @@ -380,15 +557,19 @@ impl GatewayManager { ); abort_handle.abort(); } else if old.enabled { - warn!("Cannot find Gateway {old} on the list of connected gateways"); + warn!( + "Cannot find Gateway {old} on the list of connected gateways" + ); } Some(old.id) } }; - if let Some(gateway_id) = maybe_gateway_id { - self.test_support.note_gateway_notification(gateway_id); + if let (Some(gateway_id), Some(test_support)) = + (maybe_gateway_id, self.test_support.as_ref()) + { + test_support.note_gateway_notification(gateway_id); } } Err(err) => error!("Failed to de-serialize database notification object: {err}"), @@ -408,39 +589,64 @@ impl GatewayManager { clients: Arc>>, certs_rx: Receiver>>, ) -> Result { - self.test_support.note_handler_spawn_attempt(gateway.id); - let mut gateway_handler = - if let Some(socket_path) = self.test_support.socket_path_for(&gateway) { - GatewayHandler::new_with_test_socket( - gateway, - self.pool.clone(), - self.tx.events.clone(), - self.tx.peer_stats.clone(), - certs_rx.clone(), - socket_path, - )? - } else { - GatewayHandler::new( - gateway, - self.pool.clone(), - self.tx.events.clone(), - self.tx.peer_stats.clone(), - certs_rx.clone(), - )? - }; + let maybe_test_support = self.test_support.clone(); + + if let Some(test_support) = &maybe_test_support { + test_support.note_handler_spawn_attempt(gateway.id); + } + + let mut gateway_handler = if let Some(socket_path) = maybe_test_support + .as_ref() + .and_then(|test_support| test_support.socket_path_for(&gateway)) + { + GatewayHandler::new_with_test_socket( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx.clone(), + socket_path, + )? + } else { + GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx.clone(), + )? + }; + + if let Some(test_support) = maybe_test_support { + gateway_handler.attach_test_support(test_support); + } + + let manager_reconnect_delay = self.test_support.as_ref().map_or( + GATEWAY_RECONNECT_DELAY, + GatewayManagerTestSupport::manager_reconnect_delay, + ); let abort_handle = self.handlers.spawn(async move { loop { if let Err(err) = gateway_handler .handle_connection(Arc::clone(&clients)) .await { - error!("Gateway connection error: {err}, retrying in 5 seconds..."); - tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await; + error!( + "Gateway connection error: {err}, retrying in {manager_reconnect_delay:?}..." + ); + tokio::time::sleep(manager_reconnect_delay).await; } } }); Ok(abort_handle) } + + fn remove_client(&self, gateway_id: Id) -> Option { + self.clients + .lock() + .expect("Failed to lock GatewayManager::clients") + .remove(&gateway_id) + } } /// Shared set of outbound channels that gateway instances use to forward diff --git a/crates/defguard_gateway_manager/tests/common/mod.rs b/crates/defguard_gateway_manager/tests/common/mod.rs index bd89dfc097..f860f232f8 100644 --- a/crates/defguard_gateway_manager/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/tests/common/mod.rs @@ -21,7 +21,8 @@ use defguard_common::{ }; use defguard_core::grpc::GatewayEvent; use defguard_gateway_manager::{ - GatewayManager, GatewayTxSet, TestGatewayHandler, TestGatewayManagerControl, + GatewayTxSet, + test_support::{GatewayHandler, GatewayManagerControl}, }; use defguard_proto::gateway::{ ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, core_request, gateway_server, @@ -69,6 +70,10 @@ fn unique_socket_path() -> PathBuf { )) } +pub(crate) fn unique_mock_gateway_socket_path() -> PathBuf { + unique_socket_path() +} + #[derive(Clone)] struct MockGatewayService { state: Arc, @@ -155,7 +160,10 @@ pub(crate) struct MockGatewayHarness { impl MockGatewayHarness { pub(crate) async fn start() -> Self { - let socket_path = unique_socket_path(); + Self::start_at(unique_socket_path()).await + } + + pub(crate) async fn start_at(socket_path: PathBuf) -> Self { let _ = std::fs::remove_file(&socket_path); let listener = @@ -326,7 +334,7 @@ impl Drop for MockGatewayHarness { pub(crate) struct ManagerTestContext { pub(crate) pool: PgPool, - control: TestGatewayManagerControl, + control: GatewayManagerControl, manager_task: Option>>, } @@ -339,7 +347,7 @@ impl ManagerTestContext { Self { pool, - control: TestGatewayManagerControl::new(), + control: GatewayManagerControl::new(), manager_task: None, } } @@ -352,15 +360,26 @@ impl ManagerTestContext { self.register_gateway_url(gateway.url(), mock_gateway); } - pub(crate) fn register_gateway_url(&self, gateway_url: String, mock_gateway: &MockGatewayHarness) { - self.control - .register_gateway_url(gateway_url, mock_gateway.socket_path()); + pub(crate) fn register_gateway_url( + &self, + gateway_url: String, + mock_gateway: &MockGatewayHarness, + ) { + self.register_gateway_socket_path(gateway_url, mock_gateway.socket_path()); + } + + pub(crate) fn register_gateway_socket_path(&self, gateway_url: String, socket_path: PathBuf) { + self.control.register_gateway_url(gateway_url, socket_path); } pub(crate) fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { self.control.handler_spawn_attempt_count(gateway_id) } + pub(crate) fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { + self.control.handler_connection_attempt_count(gateway_id) + } + pub(crate) fn gateway_notification_count(&self, gateway_id: Id) -> u64 { self.control.gateway_notification_count(gateway_id) } @@ -379,6 +398,20 @@ impl ManagerTestContext { .expect("timed out waiting for gateway manager handler spawn attempt"); } + pub(crate) async fn wait_for_handler_connection_attempt_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + timeout( + TEST_TIMEOUT, + self.control + .wait_for_handler_connection_attempt_count(gateway_id, expected_count), + ) + .await + .expect("timed out waiting for gateway manager handler connection attempt"); + } + pub(crate) async fn wait_for_gateway_notification_count( &self, gateway_id: Id, @@ -402,7 +435,7 @@ impl ManagerTestContext { let (events_tx, _) = broadcast::channel(16); let (peer_stats_tx, _peer_stats_rx) = mpsc::unbounded_channel(); let tx = GatewayTxSet::new(events_tx, peer_stats_tx); - let mut manager = GatewayManager::new_for_test(self.pool.clone(), tx, self.control.clone()); + let mut manager = self.control.new_manager(self.pool.clone(), tx); let manager_task = tokio::spawn(async move { manager.run().await }); timeout(TEST_TIMEOUT, self.control.wait_until_listener_ready()) @@ -411,6 +444,10 @@ impl ManagerTestContext { self.manager_task = Some(manager_task); } + pub(crate) fn set_retry_delay(&self, retry_delay: Duration) { + self.control.set_retry_delay(retry_delay); + } + pub(crate) async fn finish(mut self) { if let Some(manager_task) = self.manager_task.take() { manager_task.abort(); @@ -464,7 +501,7 @@ impl HandlerTestContext { let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); let mut mock_gateway = MockGatewayHarness::start().await; - let mut handler = TestGatewayHandler::new( + let mut handler = GatewayHandler::new( gateway.clone(), pool.clone(), events_tx.clone(), diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs b/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs index 5c7f5219c8..9855aeabc7 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs +++ b/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs @@ -1,13 +1,18 @@ +use std::time::Duration; + use defguard_common::db::{Id, models::gateway::Gateway}; use defguard_proto::gateway::core_response; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tonic::Status; use crate::common::{ ManagerTestContext, MockGatewayHarness, build_gateway_with_enabled, create_gateway, - create_gateway_with_enabled, create_network, reload_gateway, + create_gateway_with_enabled, create_network, reload_gateway, unique_mock_gateway_socket_path, wait_for_gateway_connection_state, }; +const FAST_RETRY_DELAY: Duration = Duration::from_millis(20); + async fn complete_manager_handshake( context: &ManagerTestContext, gateway: &Gateway, @@ -69,7 +74,9 @@ async fn test_starts_gateway_after_enabled_update_notification( .await .expect("failed to enable test gateway"); - context.wait_for_gateway_notification_count(gateway.id, 1).await; + context + .wait_for_gateway_notification_count(gateway.id, 1) + .await; context .wait_for_handler_spawn_attempt_count(gateway.id, 1) .await; @@ -124,10 +131,7 @@ async fn test_noop_gateway_update_does_not_restart_handler( } #[sqlx::test] -async fn test_gateway_address_change_restarts_handler( - _: PgPoolOptions, - options: PgConnectOptions, -) { +async fn test_gateway_address_change_restarts_handler(_: PgPoolOptions, options: PgConnectOptions) { let mut context = ManagerTestContext::new(options).await; let network = create_network(&context.pool).await; let mut gateway = create_gateway(&context.pool, network.id).await; @@ -172,6 +176,53 @@ async fn test_gateway_address_change_restarts_handler( context.finish().await; } +#[sqlx::test] +async fn test_enabled_gateway_update_to_disabled_stops_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + gateway = reload_gateway(&context.pool, gateway.id).await; + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_notification_count = context.gateway_notification_count(gateway.id); + let initial_connection_count = mock_gateway.connection_count(); + + gateway.enabled = false; + gateway.modified_by = "manager-disable-update".to_string(); + gateway + .save(&context.pool) + .await + .expect("failed to save gateway disable update"); + + context + .wait_for_gateway_notification_count(gateway.id, initial_notification_count + 1) + .await; + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, false).await; + assert!(!gateway_after.is_connected()); + assert!(gateway_after.disconnected_at.is_some()); + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "disabling the gateway should stop the existing handler without spawning a replacement" + ); + assert_eq!( + mock_gateway.connection_count(), + initial_connection_count, + "disabling the gateway should not create a new gateway connection" + ); + mock_gateway.expect_server_finished().await; + + context.finish().await; +} + #[sqlx::test] async fn test_insert_notification_starts_handler_for_enabled_gateway( _: PgPoolOptions, @@ -191,8 +242,12 @@ async fn test_insert_notification_starts_handler_for_enabled_gateway( .await .expect("failed to insert enabled test gateway"); - context.wait_for_gateway_notification_count(gateway.id, 1).await; - context.wait_for_handler_spawn_attempt_count(gateway.id, 1).await; + context + .wait_for_gateway_notification_count(gateway.id, 1) + .await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, 1) + .await; complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; context.finish().await; @@ -221,3 +276,150 @@ async fn test_delete_notification_purges_and_aborts_gateway_connection( context.finish().await; } + +#[sqlx::test] +async fn test_retries_failed_connection_without_notification_or_duplicate_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + context.set_retry_delay(FAST_RETRY_DELAY); + + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let socket_path = unique_mock_gateway_socket_path(); + context.register_gateway_socket_path(gateway.url(), socket_path.clone()); + + context.start().await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, 1) + .await; + context + .wait_for_handler_connection_attempt_count(gateway.id, 2) + .await; + + assert_eq!( + context.gateway_notification_count(gateway.id), + 0, + "manager reconnect retries should not depend on gateway table notifications" + ); + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + 1, + "manager reconnect retries should reuse the existing handler task" + ); + + let mut mock_gateway = MockGatewayHarness::start_at(socket_path).await; + mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + 1, + "reconnect success should not create a second concurrent handler" + ); + + context.finish().await; +} + +#[sqlx::test] +async fn test_retries_after_stream_close_with_single_handler_supervisor( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + context.set_retry_delay(FAST_RETRY_DELAY); + + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_connection_attempts = context.handler_connection_attempt_count(gateway.id); + let reconnect_socket_path = mock_gateway.socket_path(); + + mock_gateway.close_stream(); + + let gateway_after_disconnect = + wait_for_gateway_connection_state(&context.pool, gateway.id, false).await; + assert!(!gateway_after_disconnect.is_connected()); + assert!(gateway_after_disconnect.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; + + context + .wait_for_handler_connection_attempt_count(gateway.id, initial_connection_attempts + 1) + .await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "stream closure retries should keep a single handler supervisor" + ); + + let mut replacement_mock_gateway = MockGatewayHarness::start_at(reconnect_socket_path).await; + replacement_mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut replacement_mock_gateway).await; + + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "successful reconnect after stream closure should not create a duplicate handler" + ); + + context.finish().await; +} + +#[sqlx::test] +async fn test_retries_after_stream_error_with_single_handler_supervisor( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + context.set_retry_delay(FAST_RETRY_DELAY); + + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_connection_attempts = context.handler_connection_attempt_count(gateway.id); + let reconnect_socket_path = mock_gateway.socket_path(); + + mock_gateway.send_stream_error(Status::internal("mock gateway stream failure")); + + let gateway_after_disconnect = + wait_for_gateway_connection_state(&context.pool, gateway.id, false).await; + assert!(!gateway_after_disconnect.is_connected()); + assert!(gateway_after_disconnect.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; + + context + .wait_for_handler_connection_attempt_count(gateway.id, initial_connection_attempts + 1) + .await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "stream failure retries should keep a single handler supervisor" + ); + + let mut replacement_mock_gateway = MockGatewayHarness::start_at(reconnect_socket_path).await; + replacement_mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut replacement_mock_gateway).await; + + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "successful reconnect after stream failure should not create a duplicate handler" + ); + + context.finish().await; +} From a00b07d25bd5bfa138c03e0787ca57ddf5d9e726 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 18 Mar 2026 19:43:05 +0100 Subject: [PATCH 20/36] review fixes --- crates/defguard_common/src/auth/claims.rs | 74 +++++++++++++++++-- crates/defguard_core/src/grpc/mod.rs | 23 +++++- .../tests/integration/grpc/common/mod.rs | 23 ++---- .../defguard_gateway_manager/src/handler.rs | 54 +++++++++++--- crates/defguard_gateway_manager/src/lib.rs | 8 +- 5 files changed, 146 insertions(+), 36 deletions(-) diff --git a/crates/defguard_common/src/auth/claims.rs b/crates/defguard_common/src/auth/claims.rs index bca84e18e7..583add3eea 100644 --- a/crates/defguard_common/src/auth/claims.rs +++ b/crates/defguard_common/src/auth/claims.rs @@ -1,5 +1,6 @@ use std::{ env, + sync::OnceLock, time::{Duration, SystemTime}, }; @@ -13,6 +14,8 @@ pub static AUTH_SECRET_ENV: &str = "DEFGUARD_AUTH_SECRET"; pub static GATEWAY_SECRET_ENV: &str = "DEFGUARD_GATEWAY_SECRET"; pub static YUBIBRIDGE_SECRET_ENV: &str = "DEFGUARD_YUBIBRIDGE_SECRET"; +static JWT_SECRET_OVERRIDES: OnceLock = OnceLock::new(); + #[derive(Clone, Copy, Default)] pub enum ClaimsType { #[default] @@ -39,6 +42,31 @@ pub struct Claims { pub nbf: u64, } +#[derive(Clone, Debug, PartialEq, Eq)] +struct JwtSecretOverrides { + auth: String, + gateway: String, + yubibridge: String, +} + +impl JwtSecretOverrides { + fn new(auth: String, gateway: String, yubibridge: String) -> Self { + Self { + auth, + gateway, + yubibridge, + } + } + + fn secret_for(&self, claims_type: ClaimsType) -> &str { + match claims_type { + ClaimsType::Auth | ClaimsType::DesktopClient => &self.auth, + ClaimsType::Gateway => &self.gateway, + ClaimsType::YubiBridge => &self.yubibridge, + } + } +} + impl Claims { #[must_use] pub fn new(claims_type: ClaimsType, sub: String, client_id: String, duration: u64) -> Self { @@ -64,12 +92,11 @@ impl Claims { } fn get_secret(claims_type: ClaimsType) -> String { - let env_var = match claims_type { - ClaimsType::Auth | ClaimsType::DesktopClient => AUTH_SECRET_ENV, - ClaimsType::Gateway => GATEWAY_SECRET_ENV, - ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV, - }; - env::var(env_var).unwrap_or_default() + if let Some(secret_overrides) = JWT_SECRET_OVERRIDES.get() { + return secret_overrides.secret_for(claims_type).to_string(); + } + + env::var(secret_env(claims_type)).unwrap_or_default() } /// Convert claims to JWT. @@ -96,3 +123,38 @@ impl Claims { .map(|data| data.claims) } } + +fn secret_env(claims_type: ClaimsType) -> &'static str { + match claims_type { + ClaimsType::Auth | ClaimsType::DesktopClient => AUTH_SECRET_ENV, + ClaimsType::Gateway => GATEWAY_SECRET_ENV, + ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV, + } +} + +#[doc(hidden)] +pub mod test_support { + use super::{JWT_SECRET_OVERRIDES, JwtSecretOverrides}; + + pub fn initialize_jwt_secret_overrides( + auth_secret: impl Into, + gateway_secret: impl Into, + yubibridge_secret: impl Into, + ) { + let secret_overrides = JwtSecretOverrides::new( + auth_secret.into(), + gateway_secret.into(), + yubibridge_secret.into(), + ); + + if let Err(secret_overrides) = JWT_SECRET_OVERRIDES.set(secret_overrides) { + let existing_overrides = JWT_SECRET_OVERRIDES + .get() + .expect("JWT secret overrides should be initialized"); + assert_eq!( + existing_overrides, &secret_overrides, + "JWT secret overrides already initialized with different values" + ); + } + } +} diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index a4ec3e97aa..cd7a95ff94 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -96,7 +96,7 @@ pub async fn run_grpc_server( Ok(()) } -pub async fn build_grpc_service_router( +pub(crate) async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, @@ -128,6 +128,27 @@ pub async fn build_grpc_service_router( Ok(router) } +#[doc(hidden)] +pub mod test_support { + use std::sync::{Arc, Mutex}; + + use sqlx::PgPool; + use tonic::transport::{Server, server::Router}; + + use crate::auth::failed_login::FailedLoginMap; + + use super::WorkerState; + + pub async fn build_grpc_service_router( + server: Server, + pool: PgPool, + worker_state: Arc>, + failed_logins: Arc>, + ) -> Result { + super::build_grpc_service_router(server, pool, worker_state, failed_logins).await + } +} + pub struct Job { id: u32, first_name: String, diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 3d477cfdd1..68e9e4baaf 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -1,18 +1,13 @@ -use std::{ - env, - sync::{Arc, Mutex, Once}, -}; +use std::sync::{Arc, Mutex}; use defguard_common::{ - auth::claims::{ - AUTH_SECRET_ENV, Claims, ClaimsType, GATEWAY_SECRET_ENV, YUBIBRIDGE_SECRET_ENV, - }, + auth::claims::{Claims, ClaimsType, test_support::initialize_jwt_secret_overrides}, db::setup_pool, }; use defguard_core::{ auth::failed_login::FailedLoginMap, db::AppEvent, - grpc::{AUTHORIZATION_HEADER, WorkerState, build_grpc_service_router}, + grpc::{AUTHORIZATION_HEADER, WorkerState, test_support::build_grpc_service_router}, }; use hyper_util::rt::TokioIo; use sqlx::{ @@ -32,8 +27,6 @@ use tower::service_fn; use crate::common::initialize_users; -static JWT_SECRETS: Once = Once::new(); - pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, pub worker_state: Arc>, @@ -169,9 +162,9 @@ pub(crate) fn worker_request(message: T, username: &str) -> Request { } fn initialize_jwt_secrets() { - JWT_SECRETS.call_once(|| unsafe { - env::set_var(AUTH_SECRET_ENV, "defguard-test-auth-secret"); - env::set_var(GATEWAY_SECRET_ENV, "defguard-test-gateway-secret"); - env::set_var(YUBIBRIDGE_SECRET_ENV, "defguard-test-yubibridge-secret"); - }); + initialize_jwt_secret_overrides( + "defguard-test-auth-secret", + "defguard-test-gateway-secret", + "defguard-test-yubibridge-secret", + ); } diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index d0887a0b8a..57dd9abbd2 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -268,6 +268,30 @@ impl GatewayHandler { self.mark_disconnected().await; } + fn remove_client(&self, clients: &Arc>>) { + clients + .lock() + .expect("GatewayHandler failed to lock clients") + .remove(&self.gateway.id); + } + + async fn handle_stream_disconnection( + &mut self, + clients: &Arc>>, + retry_on_connect_failure: bool, + retry_delay: std::time::Duration, + ) { + self.remove_client(clients); + self.handle_disconnection_error().await; + + if !retry_on_connect_failure { + return; + } + + debug!("Waiting {retry_delay:?} to re-establish the connection"); + sleep(retry_delay).await; + } + async fn handle_connection_iteration( &mut self, clients: Arc>>, @@ -314,10 +338,6 @@ impl GatewayHandler { if let Some(test_support) = &self.test_support { test_support.note_handler_connection_attempt(self.gateway.id); } - clients - .lock() - .expect("GatewayHandler failed to lock clients") - .insert(self.gateway.id, client.clone()); let (tx, rx) = mpsc::unbounded_channel(); let retry_delay = self .test_support @@ -335,8 +355,6 @@ impl GatewayHandler { return Err(err.into()); } }; - info!("Connected to Defguard Gateway {uri}"); - let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); @@ -345,6 +363,12 @@ impl GatewayHandler { gateway.save(&self.pool).await?; } + clients + .lock() + .expect("GatewayHandler failed to lock clients") + .insert(self.gateway.id, client.clone()); + info!("Connected to Defguard Gateway {uri}"); + let mut resp_stream = response.into_inner(); let mut config_sent = false; @@ -352,7 +376,12 @@ impl GatewayHandler { match resp_stream.message().await { Ok(None) => { info!("Stream was closed by the sender."); - self.mark_disconnected().await; + self.handle_stream_disconnection( + &clients, + retry_on_connect_failure, + retry_delay, + ) + .await; return Ok(()); } Ok(Some(received)) => { @@ -428,11 +457,12 @@ impl GatewayHandler { } Err(err) => { error!("Disconnected from Gateway at {uri}, error: {err}"); - self.handle_disconnection_error().await; - if retry_on_connect_failure { - debug!("Waiting {retry_delay:?} to re-establish the connection"); - sleep(retry_delay).await; - } + self.handle_stream_disconnection( + &clients, + retry_on_connect_failure, + retry_delay, + ) + .await; return Ok(()); } } diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 75ea5ff662..c6de81161e 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -503,7 +503,11 @@ impl GatewayManager { } else { self.remove_client(old.id); if let Some(abort_handle) = abort_handles.remove(&old.id) { - old.touch_disconnected(&self.pool).await?; + if let Err(err) = old.touch_disconnected(&self.pool).await { + error!( + "Failed to update disconnection time for Gateway {old} after database change: {err}" + ); + } info!( "Aborting connection to Gateway {old}, it has changed in the \ database" @@ -552,7 +556,7 @@ impl GatewayManager { // Kill the `GatewayHandler` and the connection. if let Some(abort_handle) = abort_handles.remove(&old.id) { info!( - "Aborting connection to Gateway {old}, it has disappeard from the \ + "Aborting connection to Gateway {old}, it has disappeared from the \ database" ); abort_handle.abort(); From 2098fc72ceb75d93679a24d0221af6533b543b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 19 Mar 2026 06:57:57 +0100 Subject: [PATCH 21/36] move tests back into the crate to simplify test setup --- .../defguard_gateway_manager/src/handler.rs | 163 +++++---- crates/defguard_gateway_manager/src/lib.rs | 338 ++++++------------ .../{ => src}/tests/common/mod.rs | 14 +- .../tests/gateway_manager/handler.rs | 2 +- .../gateway_manager/handler/device_events.rs | 9 +- .../handler/firewall_events.rs | 0 .../gateway_manager/handler/handshake.rs | 0 .../gateway_manager/handler/lifecycle.rs | 0 .../tests/gateway_manager/handler/mfa.rs | 0 .../gateway_manager/handler/network_events.rs | 10 +- .../tests/gateway_manager/handler/stats.rs | 0 .../tests/gateway_manager/handler/support.rs | 2 +- .../tests/gateway_manager/manager.rs | 2 +- .../{ => src}/tests/gateway_manager/mod.rs | 0 .../defguard_gateway_manager/src/tests/mod.rs | 2 + crates/defguard_gateway_manager/tests/mod.rs | 2 - 16 files changed, 226 insertions(+), 318 deletions(-) rename crates/defguard_gateway_manager/{ => src}/tests/common/mod.rs (98%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler.rs (93%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/device_events.rs (96%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/firewall_events.rs (100%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/handshake.rs (100%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/lifecycle.rs (100%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/mfa.rs (100%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/network_events.rs (97%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/stats.rs (100%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/handler/support.rs (99%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/manager.rs (99%) rename crates/defguard_gateway_manager/{ => src}/tests/gateway_manager/mod.rs (100%) create mode 100644 crates/defguard_gateway_manager/src/tests/mod.rs delete mode 100644 crates/defguard_gateway_manager/tests/mod.rs diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 57dd9abbd2..b1308bf6e8 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -1,7 +1,6 @@ use std::{ collections::HashMap, net::IpAddr, - path::PathBuf, str::FromStr, sync::{ Arc, Mutex, @@ -9,6 +8,9 @@ use std::{ }, }; +#[cfg(test)] +use std::path::PathBuf; + use chrono::DateTime; use defguard_common::{ VERSION, @@ -47,13 +49,18 @@ use tokio::{ use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{Code, Status, transport::Endpoint}; -use crate::{Client, GatewayManagerTestSupport, TEN_SECS, error::GatewayError}; +use crate::{Client, TEN_SECS, error::GatewayError}; +#[cfg(test)] +use crate::GatewayManagerTestSupport; + +#[cfg(test)] #[derive(Debug, Default)] struct GatewayTestTransport { socket_path: Option, } +#[cfg(test)] impl GatewayTestTransport { fn with_socket_path(socket_path: PathBuf) -> Self { Self { @@ -67,7 +74,7 @@ impl GatewayTestTransport { } /// One instance per connected Gateway. -pub(super) struct GatewayHandler { +pub(crate) struct GatewayHandler { // Gateway server endpoint URL. url: Url, gateway: Gateway, @@ -76,7 +83,9 @@ pub(super) struct GatewayHandler { events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, + #[cfg(test)] test_transport: GatewayTestTransport, + #[cfg(test)] test_support: Option, } @@ -103,11 +112,14 @@ impl GatewayHandler { events_tx, peer_stats_tx, certs_rx, + #[cfg(test)] test_transport: GatewayTestTransport::default(), + #[cfg(test)] test_support: None, }) } + #[cfg(test)] pub(crate) fn new_with_test_socket( gateway: Gateway, pool: PgPool, @@ -121,10 +133,85 @@ impl GatewayHandler { Ok(handler) } - pub(super) fn attach_test_support(&mut self, test_support: GatewayManagerTestSupport) { + #[cfg(test)] + pub(crate) fn attach_test_support(&mut self, test_support: GatewayManagerTestSupport) { self.test_support = Some(test_support); } + #[cfg(test)] + fn note_handler_connection_attempt_for_tests(&self) { + if let Some(test_support) = &self.test_support { + test_support.note_handler_connection_attempt(self.gateway.id); + } + } + + #[cfg(not(test))] + fn note_handler_connection_attempt_for_tests(&self) {} + + #[cfg(test)] + fn handler_retry_delay(&self) -> std::time::Duration { + self.test_support + .as_ref() + .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay) + } + + #[cfg(not(test))] + fn handler_retry_delay(&self) -> std::time::Duration { + TEN_SECS + } + + #[cfg(test)] + fn connect_channel( + &self, + endpoint: Endpoint, + ) -> Result { + if let Some(socket_path) = self.test_transport.socket_path().cloned() { + return Ok(endpoint.connect_with_connector_lazy(tower::service_fn( + move |_: tonic::transport::Uri| { + let socket_path = socket_path.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(socket_path).await?, + )) + } + }, + ))); + } + + self.connect_tls_channel(endpoint) + } + + #[cfg(not(test))] + fn connect_channel( + &self, + endpoint: Endpoint, + ) -> Result { + self.connect_tls_channel(endpoint) + } + + fn connect_tls_channel( + &self, + endpoint: Endpoint, + ) -> Result { + let settings = Settings::get_current_settings(); + let Some(ca_cert_der) = settings.ca_cert_der else { + return Err(GatewayError::EndpointError( + "Core CA is not setup, can't create a Gateway endpoint.".to_string(), + )); + }; + let tls_config = + tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) + .map_err(|err| GatewayError::EndpointError(err.to_string()))?; + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http2() + .build(); + let connector = HttpsSchemeConnector::new(connector); + + Ok(endpoint.connect_with_connector_lazy(connector)) + } + fn endpoint(&self) -> Result { let mut url = self.url.clone(); @@ -300,49 +387,16 @@ impl GatewayHandler { let endpoint = self.endpoint()?; let uri = endpoint.uri().to_string(); - let channel = if let Some(socket_path) = self.test_transport.socket_path().cloned() { - endpoint.connect_with_connector_lazy(tower::service_fn( - move |_: tonic::transport::Uri| { - let socket_path = socket_path.clone(); - async move { - Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(socket_path).await?, - )) - } - }, - )) - } else { - let settings = Settings::get_current_settings(); - let Some(ca_cert_der) = settings.ca_cert_der else { - return Err(GatewayError::EndpointError( - "Core CA is not setup, can't create a Gateway endpoint.".to_string(), - )); - }; - let tls_config = - tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) - .map_err(|err| GatewayError::EndpointError(err.to_string()))?; - let connector = HttpsConnectorBuilder::new() - .with_tls_config(tls_config) - .https_only() - .enable_http2() - .build(); - let connector = HttpsSchemeConnector::new(connector); - endpoint.connect_with_connector_lazy(connector) - }; + let channel = self.connect_channel(endpoint)?; debug!("Connecting to Gateway {uri}"); let interceptor = ClientVersionInterceptor::new( Version::parse(VERSION).expect("failed to parse self version"), ); let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); - if let Some(test_support) = &self.test_support { - test_support.note_handler_connection_attempt(self.gateway.id); - } + self.note_handler_connection_attempt_for_tests(); let (tx, rx) = mpsc::unbounded_channel(); - let retry_delay = self - .test_support - .as_ref() - .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay); + let retry_delay = self.handler_retry_delay(); let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { Ok(response) => response, Err(err) => { @@ -479,36 +533,11 @@ impl GatewayHandler { .await?; } } -} - -pub(crate) struct TestGatewayHandler { - inner: GatewayHandler, -} - -impl TestGatewayHandler { - pub(crate) fn new( - gateway: Gateway, - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, - certs_rx: watch::Receiver>>, - socket_path: PathBuf, - ) -> anyhow::Result { - let inner = GatewayHandler::new_with_test_socket( - gateway, - pool, - events_tx, - peer_stats_tx, - certs_rx, - socket_path, - )?; - Ok(Self { inner }) - } + #[cfg(test)] pub(crate) async fn handle_connection_once(&mut self) -> anyhow::Result<()> { let clients = Arc::>>::default(); - self.inner - .handle_connection_iteration(clients, false) + self.handle_connection_iteration(clients, false) .await .map_err(anyhow::Error::from) } diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index c6de81161e..06ee518a79 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -1,13 +1,15 @@ use std::{ collections::HashMap, - path::PathBuf, - sync::{ - Arc, Mutex, - atomic::{AtomicBool, Ordering}, - }, + sync::{Arc, Mutex}, time::Duration, }; +#[cfg(test)] +use std::{ + path::PathBuf, + sync::atomic::{AtomicBool, Ordering}, +}; + use defguard_common::{ db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, messages::peer_stats_update::PeerStatsUpdate, @@ -17,11 +19,14 @@ use defguard_proto::gateway::gateway_client::GatewayClient; use defguard_version::client::ClientVersionInterceptor; use sqlx::{PgPool, postgres::PgListener}; use tokio::{ - sync::{Notify, broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}, + sync::{broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}, task::{AbortHandle, JoinSet}, }; use tonic::{Request, service::interceptor::InterceptedService, transport::Channel}; +#[cfg(test)] +use tokio::sync::Notify; + use crate::{error::GatewayError, handler::GatewayHandler}; #[macro_use] @@ -31,6 +36,9 @@ mod certs; mod error; mod handler; +#[cfg(test)] +mod tests; + const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); const TEN_SECS: Duration = Duration::from_secs(10); @@ -57,6 +65,7 @@ impl Drop for AbortTaskOnDrop { } } +#[cfg(test)] #[derive(Clone, Default)] struct GatewayManagerTestSupport { socket_paths_by_url: Arc>>, @@ -71,6 +80,7 @@ struct GatewayManagerTestSupport { retry_delay_override: Arc>>, } +#[cfg(test)] impl GatewayManagerTestSupport { fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { self.socket_paths_by_url @@ -96,6 +106,7 @@ impl GatewayManagerTestSupport { self.handler_spawn_attempt_notify.notify_waiters(); } + #[cfg(test)] fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { self.handler_spawn_attempts_by_gateway .lock() @@ -105,6 +116,7 @@ impl GatewayManagerTestSupport { .unwrap_or_default() } + #[cfg(test)] async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { loop { if self.handler_spawn_attempt_count(gateway_id) >= expected_count { @@ -129,6 +141,7 @@ impl GatewayManagerTestSupport { self.handler_connection_attempt_notify.notify_waiters(); } + #[cfg(test)] fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { self.handler_connection_attempts_by_gateway .lock() @@ -138,6 +151,7 @@ impl GatewayManagerTestSupport { .unwrap_or_default() } + #[cfg(test)] async fn wait_for_handler_connection_attempt_count(&self, gateway_id: Id, expected_count: u64) { loop { if self.handler_connection_attempt_count(gateway_id) >= expected_count { @@ -162,6 +176,7 @@ impl GatewayManagerTestSupport { self.gateway_notification_notify.notify_waiters(); } + #[cfg(test)] fn gateway_notification_count(&self, gateway_id: Id) -> u64 { self.gateway_notifications_by_gateway .lock() @@ -171,6 +186,7 @@ impl GatewayManagerTestSupport { .unwrap_or_default() } + #[cfg(test)] async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { loop { if self.gateway_notification_count(gateway_id) >= expected_count { @@ -191,6 +207,7 @@ impl GatewayManagerTestSupport { self.listener_ready_notify.notify_waiters(); } + #[cfg(test)] async fn wait_until_listener_ready(&self) { loop { if self.listener_ready.load(Ordering::Acquire) { @@ -206,6 +223,7 @@ impl GatewayManagerTestSupport { } } + #[cfg(test)] fn set_retry_delay(&self, retry_delay: Duration) { *self .retry_delay_override @@ -228,183 +246,11 @@ impl GatewayManagerTestSupport { } } -#[derive(Clone, Default)] -struct TestGatewayManagerControl { - inner: GatewayManagerTestSupport, -} - -impl TestGatewayManagerControl { - #[must_use] - fn new() -> Self { - Self::default() - } - - fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { - self.inner.register_gateway_url(gateway_url, socket_path); - } - - fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { - self.inner.handler_spawn_attempt_count(gateway_id) - } - - async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { - self.inner - .wait_for_handler_spawn_attempt_count(gateway_id, expected_count) - .await; - } - - fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { - self.inner.handler_connection_attempt_count(gateway_id) - } - - async fn wait_for_handler_connection_attempt_count(&self, gateway_id: Id, expected_count: u64) { - self.inner - .wait_for_handler_connection_attempt_count(gateway_id, expected_count) - .await; - } - - fn gateway_notification_count(&self, gateway_id: Id) -> u64 { - self.inner.gateway_notification_count(gateway_id) - } - - async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { - self.inner - .wait_for_gateway_notification_count(gateway_id, expected_count) - .await; - } - - async fn wait_until_listener_ready(&self) { - self.inner.wait_until_listener_ready().await; - } - - fn set_retry_delay(&self, retry_delay: Duration) { - self.inner.set_retry_delay(retry_delay); - } -} - -#[doc(hidden)] -pub mod test_support { - use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; - - use defguard_common::{ - db::{Id, models::gateway::Gateway}, - messages::peer_stats_update::PeerStatsUpdate, - }; - use defguard_core::grpc::GatewayEvent; - use sqlx::PgPool; - use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender, watch::Receiver}; - - use crate::{ - GatewayManager, GatewayTxSet, TestGatewayManagerControl, handler::TestGatewayHandler, - }; - - #[derive(Clone, Default)] - pub struct GatewayManagerControl { - inner: TestGatewayManagerControl, - } - - impl GatewayManagerControl { - #[must_use] - pub fn new() -> Self { - Self { - inner: TestGatewayManagerControl::new(), - } - } - - pub fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { - self.inner.register_gateway_url(gateway_url, socket_path); - } - - #[must_use] - pub fn new_manager(&self, pool: PgPool, tx: GatewayTxSet) -> GatewayManager { - GatewayManager::new_for_test(pool, tx, self.inner.clone()) - } - - pub fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { - self.inner.handler_spawn_attempt_count(gateway_id) - } - - pub async fn wait_for_handler_spawn_attempt_count( - &self, - gateway_id: Id, - expected_count: u64, - ) { - self.inner - .wait_for_handler_spawn_attempt_count(gateway_id, expected_count) - .await; - } - - pub fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { - self.inner.handler_connection_attempt_count(gateway_id) - } - - pub async fn wait_for_handler_connection_attempt_count( - &self, - gateway_id: Id, - expected_count: u64, - ) { - self.inner - .wait_for_handler_connection_attempt_count(gateway_id, expected_count) - .await; - } - - pub fn gateway_notification_count(&self, gateway_id: Id) -> u64 { - self.inner.gateway_notification_count(gateway_id) - } - - pub async fn wait_for_gateway_notification_count( - &self, - gateway_id: Id, - expected_count: u64, - ) { - self.inner - .wait_for_gateway_notification_count(gateway_id, expected_count) - .await; - } - - pub async fn wait_until_listener_ready(&self) { - self.inner.wait_until_listener_ready().await; - } - - pub fn set_retry_delay(&self, retry_delay: Duration) { - self.inner.set_retry_delay(retry_delay); - } - } - - pub struct GatewayHandler { - inner: TestGatewayHandler, - } - - impl GatewayHandler { - pub fn new( - gateway: Gateway, - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, - certs_rx: Receiver>>, - socket_path: PathBuf, - ) -> anyhow::Result { - let inner = TestGatewayHandler::new( - gateway, - pool, - events_tx, - peer_stats_tx, - certs_rx, - socket_path, - )?; - Ok(Self { inner }) - } - - pub async fn handle_connection_once(&mut self) -> anyhow::Result<()> { - self.inner.handle_connection_once().await - } - } -} - pub struct GatewayManager { clients: Arc>>, pool: PgPool, handlers: JoinSet>, + #[cfg(test)] test_support: Option, tx: GatewayTxSet, } @@ -416,22 +262,104 @@ impl GatewayManager { clients: Arc::default(), handlers: JoinSet::new(), pool, + #[cfg(test)] test_support: None, tx, } } + #[cfg(test)] #[must_use] - fn new_for_test(pool: PgPool, tx: GatewayTxSet, control: TestGatewayManagerControl) -> Self { + fn new_for_test( + pool: PgPool, + tx: GatewayTxSet, + test_support: GatewayManagerTestSupport, + ) -> Self { Self { clients: Arc::default(), handlers: JoinSet::new(), pool, - test_support: Some(control.inner), + test_support: Some(test_support), tx, } } + fn mark_listener_ready_for_tests(&self) { + #[cfg(test)] + if let Some(test_support) = &self.test_support { + test_support.mark_listener_ready(); + } + } + + fn note_gateway_notification_for_tests(&self, maybe_gateway_id: Option) { + #[cfg(test)] + if let (Some(gateway_id), Some(test_support)) = + (maybe_gateway_id, self.test_support.as_ref()) + { + test_support.note_gateway_notification(gateway_id); + } + + #[cfg(not(test))] + let _ = maybe_gateway_id; + } + + fn manager_reconnect_delay(&self) -> Duration { + #[cfg(test)] + { + return self.test_support.as_ref().map_or( + GATEWAY_RECONNECT_DELAY, + GatewayManagerTestSupport::manager_reconnect_delay, + ); + } + + #[cfg(not(test))] + { + GATEWAY_RECONNECT_DELAY + } + } + + fn build_handler( + &self, + gateway: Gateway, + certs_rx: Receiver>>, + ) -> Result { + #[cfg(test)] + if let Some(test_support) = self.test_support.clone() { + test_support.note_handler_spawn_attempt(gateway.id); + + let mut gateway_handler = + if let Some(socket_path) = test_support.socket_path_for(&gateway) { + GatewayHandler::new_with_test_socket( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx, + socket_path, + )? + } else { + GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx, + )? + }; + gateway_handler.attach_test_support(test_support); + + return Ok(gateway_handler); + } + + GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx, + ) + } + /// Bi-directional gRPC stream for communication with Defguard Gateway. pub async fn run(&mut self) -> Result<(), anyhow::Error> { let (certs_tx, certs_rx) = tokio::sync::watch::channel(Arc::new(HashMap::new())); @@ -459,9 +387,7 @@ impl GatewayManager { // Observe gateway URL changes. let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen(GATEWAY_TABLE_TRIGGER).await?; - if let Some(test_support) = &self.test_support { - test_support.mark_listener_ready(); - } + self.mark_listener_ready_for_tests(); while let Ok(notification) = listener.recv().await { let payload = notification.payload(); match serde_json::from_str::>>(payload) { @@ -570,11 +496,7 @@ impl GatewayManager { } }; - if let (Some(gateway_id), Some(test_support)) = - (maybe_gateway_id, self.test_support.as_ref()) - { - test_support.note_gateway_notification(gateway_id); - } + self.note_gateway_notification_for_tests(maybe_gateway_id); } Err(err) => error!("Failed to de-serialize database notification object: {err}"), } @@ -593,42 +515,8 @@ impl GatewayManager { clients: Arc>>, certs_rx: Receiver>>, ) -> Result { - let maybe_test_support = self.test_support.clone(); - - if let Some(test_support) = &maybe_test_support { - test_support.note_handler_spawn_attempt(gateway.id); - } - - let mut gateway_handler = if let Some(socket_path) = maybe_test_support - .as_ref() - .and_then(|test_support| test_support.socket_path_for(&gateway)) - { - GatewayHandler::new_with_test_socket( - gateway, - self.pool.clone(), - self.tx.events.clone(), - self.tx.peer_stats.clone(), - certs_rx.clone(), - socket_path, - )? - } else { - GatewayHandler::new( - gateway, - self.pool.clone(), - self.tx.events.clone(), - self.tx.peer_stats.clone(), - certs_rx.clone(), - )? - }; - - if let Some(test_support) = maybe_test_support { - gateway_handler.attach_test_support(test_support); - } - - let manager_reconnect_delay = self.test_support.as_ref().map_or( - GATEWAY_RECONNECT_DELAY, - GatewayManagerTestSupport::manager_reconnect_delay, - ); + let mut gateway_handler = self.build_handler(gateway, certs_rx)?; + let manager_reconnect_delay = self.manager_reconnect_delay(); let abort_handle = self.handlers.spawn(async move { loop { if let Err(err) = gateway_handler diff --git a/crates/defguard_gateway_manager/tests/common/mod.rs b/crates/defguard_gateway_manager/src/tests/common/mod.rs similarity index 98% rename from crates/defguard_gateway_manager/tests/common/mod.rs rename to crates/defguard_gateway_manager/src/tests/common/mod.rs index f860f232f8..16ba51f528 100644 --- a/crates/defguard_gateway_manager/tests/common/mod.rs +++ b/crates/defguard_gateway_manager/src/tests/common/mod.rs @@ -20,10 +20,6 @@ use defguard_common::{ messages::peer_stats_update::PeerStatsUpdate, }; use defguard_core::grpc::GatewayEvent; -use defguard_gateway_manager::{ - GatewayTxSet, - test_support::{GatewayHandler, GatewayManagerControl}, -}; use defguard_proto::gateway::{ ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, core_request, gateway_server, }; @@ -41,6 +37,8 @@ use tokio::{ use tokio_stream::{once, wrappers::UnboundedReceiverStream}; use tonic::{Request, Response, Status, Streaming, transport::Server}; +use crate::{GatewayManager, GatewayManagerTestSupport, GatewayTxSet, handler::GatewayHandler}; + const TEST_TIMEOUT: Duration = Duration::from_secs(2); macro_rules! assert_some { @@ -334,7 +332,7 @@ impl Drop for MockGatewayHarness { pub(crate) struct ManagerTestContext { pub(crate) pool: PgPool, - control: GatewayManagerControl, + control: GatewayManagerTestSupport, manager_task: Option>>, } @@ -347,7 +345,7 @@ impl ManagerTestContext { Self { pool, - control: GatewayManagerControl::new(), + control: GatewayManagerTestSupport::default(), manager_task: None, } } @@ -435,7 +433,7 @@ impl ManagerTestContext { let (events_tx, _) = broadcast::channel(16); let (peer_stats_tx, _peer_stats_rx) = mpsc::unbounded_channel(); let tx = GatewayTxSet::new(events_tx, peer_stats_tx); - let mut manager = self.control.new_manager(self.pool.clone(), tx); + let mut manager = GatewayManager::new_for_test(self.pool.clone(), tx, self.control.clone()); let manager_task = tokio::spawn(async move { manager.run().await }); timeout(TEST_TIMEOUT, self.control.wait_until_listener_ready()) @@ -501,7 +499,7 @@ impl HandlerTestContext { let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); let mut mock_gateway = MockGatewayHarness::start().await; - let mut handler = GatewayHandler::new( + let mut handler = GatewayHandler::new_with_test_socket( gateway.clone(), pool.clone(), events_tx.clone(), diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs similarity index 93% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs index b7b12df582..35896035f6 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -17,7 +17,7 @@ use self::support::{ create_device_for_network, create_device_info_for_current_network, enable_internal_mfa_for_network, expected_keepalive_interval, panic_unexpected, parse_test_ip, }; -use crate::common::{HandlerTestContext, build_peer_stats, reload_gateway}; +use crate::tests::common::{HandlerTestContext, build_peer_stats, reload_gateway}; include!("handler/handshake.rs"); include!("handler/lifecycle.rs"); diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs similarity index 96% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs index f7c9042c39..67cee98e10 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler/device_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs @@ -72,11 +72,10 @@ async fn test_device_modified_for_network_produces_peer_modify_update( ) .await; - let mut network_device = - WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) - .await - .expect("failed to load device network info") - .expect("expected device network info for modified device"); + let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) + .await + .expect("failed to load device network info") + .expect("expected device network info for modified device"); network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; network_device.preshared_key = Some("modified-preshared-key".to_string()); network_device diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/firewall_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs similarity index 100% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/firewall_events.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/handshake.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/handshake.rs similarity index 100% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/handshake.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/handshake.rs diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/lifecycle.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/lifecycle.rs similarity index 100% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/lifecycle.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/lifecycle.rs diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/mfa.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs similarity index 100% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/mfa.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs similarity index 97% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs index dabe813c7a..e7aeddcdbe 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler/network_events.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs @@ -163,14 +163,8 @@ async fn test_only_matching_handler_receives_network_modified_update( 1400, 7, ); - matching_context - .mock_gateway_mut() - .expect_no_outbound() - .await; - unrelated_context - .mock_gateway_mut() - .expect_no_outbound() - .await; + matching_context.mock_gateway_mut().expect_no_outbound().await; + unrelated_context.mock_gateway_mut().expect_no_outbound().await; matching_context .finish() diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/stats.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/stats.rs similarity index 100% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/stats.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/stats.rs diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs similarity index 99% rename from crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs index 5f1f80ef4c..9316d29b0f 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/handler/support.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -19,7 +19,7 @@ use defguard_proto::gateway::{ }; use sqlx::postgres::PgConnectOptions; -use crate::common::HandlerTestContext; +use crate::tests::common::HandlerTestContext; macro_rules! assert_send_ok { ($result:expr, $message:literal) => { diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/manager.rs similarity index 99% rename from crates/defguard_gateway_manager/tests/gateway_manager/manager.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/manager.rs index 9855aeabc7..89481edfa0 100644 --- a/crates/defguard_gateway_manager/tests/gateway_manager/manager.rs +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/manager.rs @@ -5,7 +5,7 @@ use defguard_proto::gateway::core_response; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tonic::Status; -use crate::common::{ +use crate::tests::common::{ ManagerTestContext, MockGatewayHarness, build_gateway_with_enabled, create_gateway, create_gateway_with_enabled, create_network, reload_gateway, unique_mock_gateway_socket_path, wait_for_gateway_connection_state, diff --git a/crates/defguard_gateway_manager/tests/gateway_manager/mod.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/mod.rs similarity index 100% rename from crates/defguard_gateway_manager/tests/gateway_manager/mod.rs rename to crates/defguard_gateway_manager/src/tests/gateway_manager/mod.rs diff --git a/crates/defguard_gateway_manager/src/tests/mod.rs b/crates/defguard_gateway_manager/src/tests/mod.rs new file mode 100644 index 0000000000..ef5f24f2c4 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/mod.rs @@ -0,0 +1,2 @@ +mod common; +mod gateway_manager; diff --git a/crates/defguard_gateway_manager/tests/mod.rs b/crates/defguard_gateway_manager/tests/mod.rs deleted file mode 100644 index c49ae37809..0000000000 --- a/crates/defguard_gateway_manager/tests/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod common; -pub(crate) mod gateway_manager; From 2824eb5a13de4c224143b49254dbbc7460fdc62f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 19 Mar 2026 10:05:02 +0100 Subject: [PATCH 22/36] update query data --- ...d0b53ca805cff303dfe2a67880adeaca1e10e50bea9b9fc53e08845.json | 2 +- ...10a53745afd80c607fc52cb537dfded663e6bbee73d5f908f0ca89e.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.sqlx/query-47c406366d0b53ca805cff303dfe2a67880adeaca1e10e50bea9b9fc53e08845.json b/.sqlx/query-47c406366d0b53ca805cff303dfe2a67880adeaca1e10e50bea9b9fc53e08845.json index 0e36c94cdb..8a6776d365 100644 --- a/.sqlx/query-47c406366d0b53ca805cff303dfe2a67880adeaca1e10e50bea9b9fc53e08845.json +++ b/.sqlx/query-47c406366d0b53ca805cff303dfe2a67880adeaca1e10e50bea9b9fc53e08845.json @@ -82,7 +82,7 @@ false, false, true, - false, + true, false, false, false, diff --git a/.sqlx/query-59d048f8110a53745afd80c607fc52cb537dfded663e6bbee73d5f908f0ca89e.json b/.sqlx/query-59d048f8110a53745afd80c607fc52cb537dfded663e6bbee73d5f908f0ca89e.json index d62e957d33..548a89c193 100644 --- a/.sqlx/query-59d048f8110a53745afd80c607fc52cb537dfded663e6bbee73d5f908f0ca89e.json +++ b/.sqlx/query-59d048f8110a53745afd80c607fc52cb537dfded663e6bbee73d5f908f0ca89e.json @@ -80,7 +80,7 @@ false, false, true, - false, + true, false, false, false, From 09f5958ee9d9eaaabf45cfeaa04dddd806bce40f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 19 Mar 2026 13:06:15 +0100 Subject: [PATCH 23/36] remove test lock workaround --- .cargo/config.toml | 3 --- crates/defguard_core/src/enterprise/ldap/tests.rs | 13 +++++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 4ad6414d7e..df4976b850 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,6 +1,3 @@ # Without running core on windows will result in stack overflow [target.x86_64-pc-windows-msvc] rustflags = ["-C", "link-args=/STACK:8388608"] - -[env] -RUST_TEST_THREADS = { value = "1", force = true } diff --git a/crates/defguard_core/src/enterprise/ldap/tests.rs b/crates/defguard_core/src/enterprise/ldap/tests.rs index 9c1eb613b3..43023432c2 100644 --- a/crates/defguard_core/src/enterprise/ldap/tests.rs +++ b/crates/defguard_core/src/enterprise/ldap/tests.rs @@ -262,6 +262,7 @@ fn test_using_username_as_rdn() { #[sqlx::test] async fn test_update_users_state(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let mut ldap_conn = LDAPConnection::create().await.unwrap(); let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -1582,6 +1583,7 @@ fn test_extract_intersecting_users_no_matches(_: PgPoolOptions, options: PgConne #[sqlx::test] async fn test_fix_missing_user_path(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -1990,6 +1992,7 @@ async fn test_sync_users_with_empty_paths_and_nested_ous( #[sqlx::test] async fn test_sync_simple_nested_ou_changes(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -2158,6 +2161,7 @@ async fn test_sync_defguard_authority_with_complex_nested_ous( _: PgPoolOptions, options: PgConnectOptions, ) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -2377,6 +2381,7 @@ async fn test_sync_group_membership_with_intersecting_users( _: PgPoolOptions, options: PgConnectOptions, ) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -2576,6 +2581,7 @@ async fn test_ldap_login_does_not_create_user_when_user_license_limit_is_reached #[sqlx::test] async fn test_get_empty_user_path(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; let user = make_test_user("testuser", None, None); @@ -3237,6 +3243,7 @@ async fn test_ldap_sync_allowed_with_empty_sync_groups( #[sqlx::test] async fn test_ldap_sync_allowed_with_inactive_user(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3251,6 +3258,7 @@ async fn test_ldap_sync_allowed_with_inactive_user(_: PgPoolOptions, options: Pg #[sqlx::test] async fn test_ldap_sync_allowed_with_unenrolled_user(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3270,6 +3278,7 @@ async fn test_ldap_sync_allowed_with_sync_groups_user_in_group( _: PgPoolOptions, options: PgConnectOptions, ) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3294,6 +3303,7 @@ async fn test_ldap_sync_allowed_with_sync_groups_user_not_in_group( _: PgPoolOptions, options: PgConnectOptions, ) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3319,6 +3329,7 @@ async fn test_ldap_sync_allowed_with_multiple_sync_groups( _: PgPoolOptions, options: PgConnectOptions, ) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3365,6 +3376,7 @@ async fn test_ldap_sync_allowed_enrolled_via_openid(_: PgPoolOptions, options: P #[sqlx::test] async fn test_ldap_sync_allowed_enrolled_via_ldap(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3381,6 +3393,7 @@ async fn test_ldap_sync_allowed_enrolled_via_ldap(_: PgPoolOptions, options: PgC #[sqlx::test] async fn test_ldap_sync_allowed_all_conditions_false(_: PgPoolOptions, options: PgConnectOptions) { + let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; From 92f008e9361fe1422e7f0e32866d9c1974d7aba7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 19 Mar 2026 16:22:15 +0100 Subject: [PATCH 24/36] review cleanup --- crates/defguard_common/Cargo.toml | 3 ++ crates/defguard_common/src/auth/claims.rs | 23 ++++------- crates/defguard_common/src/config.rs | 39 ++++++++++--------- crates/defguard_core/Cargo.toml | 1 + .../src/enterprise/ldap/model.rs | 34 ++++++++-------- crates/defguard_core/src/handlers/user.rs | 12 ++++-- .../tests/integration/grpc/common/mod.rs | 5 ++- crates/defguard_gateway_manager/Cargo.toml | 4 +- 8 files changed, 66 insertions(+), 55 deletions(-) diff --git a/crates/defguard_common/Cargo.toml b/crates/defguard_common/Cargo.toml index 12f0bbe1df..7d8e244724 100644 --- a/crates/defguard_common/Cargo.toml +++ b/crates/defguard_common/Cargo.toml @@ -7,6 +7,9 @@ homepage.workspace = true repository.workspace = true rust-version.workspace = true +[features] +test-support = [] + [dependencies] model_derive.workspace = true diff --git a/crates/defguard_common/src/auth/claims.rs b/crates/defguard_common/src/auth/claims.rs index 583add3eea..2ca0853abc 100644 --- a/crates/defguard_common/src/auth/claims.rs +++ b/crates/defguard_common/src/auth/claims.rs @@ -5,7 +5,7 @@ use std::{ }; use jsonwebtoken::{ - DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::Error as JWTError, + decode, encode, errors::Error as JWTError, DecodingKey, EncodingKey, Header, Validation, }; use serde::{Deserialize, Serialize}; @@ -50,14 +50,6 @@ struct JwtSecretOverrides { } impl JwtSecretOverrides { - fn new(auth: String, gateway: String, yubibridge: String) -> Self { - Self { - auth, - gateway, - yubibridge, - } - } - fn secret_for(&self, claims_type: ClaimsType) -> &str { match claims_type { ClaimsType::Auth | ClaimsType::DesktopClient => &self.auth, @@ -132,20 +124,21 @@ fn secret_env(claims_type: ClaimsType) -> &'static str { } } +#[cfg(any(test, feature = "test-support"))] #[doc(hidden)] pub mod test_support { - use super::{JWT_SECRET_OVERRIDES, JwtSecretOverrides}; + use super::{JwtSecretOverrides, JWT_SECRET_OVERRIDES}; pub fn initialize_jwt_secret_overrides( auth_secret: impl Into, gateway_secret: impl Into, yubibridge_secret: impl Into, ) { - let secret_overrides = JwtSecretOverrides::new( - auth_secret.into(), - gateway_secret.into(), - yubibridge_secret.into(), - ); + let secret_overrides = JwtSecretOverrides { + auth: auth_secret.into(), + gateway: gateway_secret.into(), + yubibridge: yubibridge_secret.into(), + }; if let Err(secret_overrides) = JWT_SECRET_OVERRIDES.set(secret_overrides) { let existing_overrides = JWT_SECRET_OVERRIDES diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index 386c37ad36..fe04dda86e 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -3,18 +3,18 @@ use std::{net::IpAddr, sync::OnceLock}; use clap::{Args, Parser, Subcommand}; use humantime::Duration; use ipnetwork::IpNetwork; -use openidconnect::{JsonWebKeyId, core::CoreRsaPrivateSigningKey}; +use openidconnect::{core::CoreRsaPrivateSigningKey, JsonWebKeyId}; use reqwest::Url; use rsa::{ - RsaPrivateKey, pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey}, pkcs8::{DecodePrivateKey, LineEnding}, traits::PublicKeyParts, + RsaPrivateKey, }; use secrecy::{ExposeSecret, SecretString}; use serde::Serialize; -use crate::{VERSION, db::models::Settings}; +use crate::{db::models::Settings, VERSION}; pub static SERVER_CONFIG: OnceLock = OnceLock::new(); @@ -228,10 +228,13 @@ impl DefGuardConfig { } // this is an ugly workaround to avoid `cargo test` args being captured by `clap` - #[allow(deprecated)] #[must_use] pub fn new_test_config() -> Self { - Self { + #[expect( + deprecated, + reason = "Test config still initializes compatibility-only deprecated fields" + )] + let config = Self { log_level: "info".to_string(), log_file: None, secret_key: None, @@ -259,14 +262,16 @@ impl DefGuardConfig { cookie_domain: None, cookie_insecure: false, cmd: None, - check_period: Duration::from(std::time::Duration::from_secs(12 * 3600)), - check_period_no_license: Duration::from(std::time::Duration::from_secs(24 * 3600)), - check_period_renewal_window: Duration::from(std::time::Duration::from_secs(3600)), + check_period: std::time::Duration::from_secs(12 * 3600).into(), + check_period_no_license: std::time::Duration::from_secs(24 * 3600).into(), + check_period_renewal_window: std::time::Duration::from_secs(3600).into(), http_bind_address: None, grpc_bind_address: None, adopt_gateway: None, adopt_edge: None, - } + }; + + config } /// Validate that the auto-adoption flags are consistent. @@ -364,15 +369,11 @@ mod tests { ); // only one flag at a time: must be an error - assert!( - make_config(Some("edge.example.com:8080"), None) - .validate_adopt_flags() - .is_err() - ); - assert!( - make_config(None, Some("gw.example.com:8080")) - .validate_adopt_flags() - .is_err() - ); + assert!(make_config(Some("edge.example.com:8080"), None) + .validate_adopt_flags() + .is_err()); + assert!(make_config(None, Some("gw.example.com:8080")) + .validate_adopt_flags() + .is_err()); } } diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 119d4a8c7c..9861923d7d 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -87,6 +87,7 @@ async-stream = "0.3" [dev-dependencies] claims.workspace = true +defguard_common = { workspace = true, features = ["test-support"] } hyper-util = "0.1" matches.workspace = true reqwest = { version = "0.12", features = [ diff --git a/crates/defguard_core/src/enterprise/ldap/model.rs b/crates/defguard_core/src/enterprise/ldap/model.rs index 8291aad215..f5b1e799d3 100644 --- a/crates/defguard_core/src/enterprise/ldap/model.rs +++ b/crates/defguard_core/src/enterprise/ldap/model.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, future::Future}; use defguard_common::db::{ Id, @@ -247,24 +247,26 @@ pub(crate) fn maybe_update_rdn(user: &mut User) { /// - he is in a group that is allowed to be synced or no such groups are configured /// - he is active (not disabled) /// - he is enrolled -pub(crate) async fn ldap_sync_allowed_for_user<'e, E>( - user: &User, +pub(crate) fn ldap_sync_allowed_for_user<'a, 'e, E>( + user: &'a User, executor: E, -) -> sqlx::Result +) -> impl Future> + Send + 'a where - E: Acquire<'e, Database = Postgres>, + E: Acquire<'e, Database = Postgres> + Send + 'a, { - let mut connection = executor.acquire().await?; - let sync_groups = Settings::get(&mut *connection) - .await? - .unwrap_or_default() - .ldap_sync_groups; - let my_groups = user.member_of(&mut *connection).await?; - Ok( - (sync_groups.is_empty() || my_groups.iter().any(|g| sync_groups.contains(&g.name))) - && user.is_active - && user.is_enrolled(), - ) + async move { + let mut connection = executor.acquire().await?; + let sync_groups = Settings::get(&mut *connection) + .await? + .unwrap_or_default() + .ldap_sync_groups; + let my_groups = user.member_of(&mut *connection).await?; + Ok( + (sync_groups.is_empty() || my_groups.iter().any(|g| sync_groups.contains(&g.name))) + && user.is_active + && user.is_enrolled(), + ) + } } pub(super) async fn get_users_without_ldap_path<'e, E>(executor: E) -> sqlx::Result>> diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index 16eb38dfcf..d65d1e0ada 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -17,7 +17,7 @@ use defguard_common::{ use defguard_mail::{Mail, templates}; use humantime::parse_duration; use serde_json::json; -use sqlx::PgPool; +use sqlx::{Acquire, PgPool}; use utoipa::ToSchema; use super::{ @@ -724,7 +724,10 @@ pub async fn modify_user( let status_changing = user_info.is_active != user.is_active; let mut transaction = appstate.pool.begin().await?; - let ldap_sync_allowed = ldap_sync_allowed_for_user(&user, &appstate.pool).await?; + let ldap_sync_allowed = { + let transaction_connection = transaction.acquire().await?; + ldap_sync_allowed_for_user(&user, transaction_connection).await? + }; // remove authorized apps if needed let request_app_ids: Vec = user_info @@ -891,7 +894,10 @@ pub async fn delete_user( session.user.username ); let mut transaction = appstate.pool.begin().await?; - let user_for_ldap = if ldap_sync_allowed_for_user(&user, &appstate.pool).await? { + let user_for_ldap = if { + let transaction_connection = transaction.acquire().await?; + ldap_sync_allowed_for_user(&user, transaction_connection).await? + } { Some(user.clone().as_noid()) } else { None diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 68e9e4baaf..b640a4de0e 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex}; use defguard_common::{ auth::claims::{Claims, ClaimsType, test_support::initialize_jwt_secret_overrides}, - db::setup_pool, + db::{models::settings::initialize_current_settings, setup_pool}, }; use defguard_core::{ auth::failed_login::FailedLoginMap, @@ -94,6 +94,9 @@ pub(crate) async fn create_client_channel(client_stream: DuplexStream) -> Channe pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { initialize_jwt_secrets(); initialize_users(pool).await; + initialize_current_settings(pool) + .await + .expect("failed to initialize current settings for gRPC tests"); let (client_stream, server_stream) = tokio::io::duplex(1024); let client_channel = create_client_channel(client_stream).await; diff --git a/crates/defguard_gateway_manager/Cargo.toml b/crates/defguard_gateway_manager/Cargo.toml index fe0d595112..9afb53828e 100644 --- a/crates/defguard_gateway_manager/Cargo.toml +++ b/crates/defguard_gateway_manager/Cargo.toml @@ -17,7 +17,6 @@ defguard_version.workspace = true anyhow.workspace = true chrono.workspace = true -hyper-util = "0.1" hyper-rustls.workspace = true reqwest.workspace = true semver.workspace = true @@ -29,3 +28,6 @@ tokio-stream.workspace = true tonic.workspace = true tower.workspace = true tracing.workspace = true + +[dev-dependencies] +hyper-util = "0.1" From 369e35baa7d909647b083473b1eaf4b5029def13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 19 Mar 2026 16:28:39 +0100 Subject: [PATCH 25/36] formatting --- crates/defguard_common/src/auth/claims.rs | 4 ++-- crates/defguard_common/src/config.rs | 22 +++++++++++++--------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/crates/defguard_common/src/auth/claims.rs b/crates/defguard_common/src/auth/claims.rs index 2ca0853abc..031c5338ef 100644 --- a/crates/defguard_common/src/auth/claims.rs +++ b/crates/defguard_common/src/auth/claims.rs @@ -5,7 +5,7 @@ use std::{ }; use jsonwebtoken::{ - decode, encode, errors::Error as JWTError, DecodingKey, EncodingKey, Header, Validation, + DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::Error as JWTError, }; use serde::{Deserialize, Serialize}; @@ -127,7 +127,7 @@ fn secret_env(claims_type: ClaimsType) -> &'static str { #[cfg(any(test, feature = "test-support"))] #[doc(hidden)] pub mod test_support { - use super::{JwtSecretOverrides, JWT_SECRET_OVERRIDES}; + use super::{JWT_SECRET_OVERRIDES, JwtSecretOverrides}; pub fn initialize_jwt_secret_overrides( auth_secret: impl Into, diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index fe04dda86e..13a5cea975 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -3,18 +3,18 @@ use std::{net::IpAddr, sync::OnceLock}; use clap::{Args, Parser, Subcommand}; use humantime::Duration; use ipnetwork::IpNetwork; -use openidconnect::{core::CoreRsaPrivateSigningKey, JsonWebKeyId}; +use openidconnect::{JsonWebKeyId, core::CoreRsaPrivateSigningKey}; use reqwest::Url; use rsa::{ + RsaPrivateKey, pkcs1::{DecodeRsaPrivateKey, EncodeRsaPrivateKey}, pkcs8::{DecodePrivateKey, LineEnding}, traits::PublicKeyParts, - RsaPrivateKey, }; use secrecy::{ExposeSecret, SecretString}; use serde::Serialize; -use crate::{db::models::Settings, VERSION}; +use crate::{VERSION, db::models::Settings}; pub static SERVER_CONFIG: OnceLock = OnceLock::new(); @@ -369,11 +369,15 @@ mod tests { ); // only one flag at a time: must be an error - assert!(make_config(Some("edge.example.com:8080"), None) - .validate_adopt_flags() - .is_err()); - assert!(make_config(None, Some("gw.example.com:8080")) - .validate_adopt_flags() - .is_err()); + assert!( + make_config(Some("edge.example.com:8080"), None) + .validate_adopt_flags() + .is_err() + ); + assert!( + make_config(None, Some("gw.example.com:8080")) + .validate_adopt_flags() + .is_err() + ); } } From 08dad14fee41ce621edc60245b9f652d1a56ef40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 19 Mar 2026 18:03:00 +0100 Subject: [PATCH 26/36] remove unnecessary changes --- .../src/enterprise/ldap/model.rs | 32 ++++++++----------- crates/defguard_core/src/handlers/user.rs | 5 +-- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/crates/defguard_core/src/enterprise/ldap/model.rs b/crates/defguard_core/src/enterprise/ldap/model.rs index f5b1e799d3..77a9cc61e4 100644 --- a/crates/defguard_core/src/enterprise/ldap/model.rs +++ b/crates/defguard_core/src/enterprise/ldap/model.rs @@ -1,11 +1,11 @@ -use std::{collections::HashSet, future::Future}; +use std::collections::HashSet; use defguard_common::db::{ Id, models::{Settings, User}, }; use ldap3::{Mod, SearchEntry}; -use sqlx::{Acquire, PgExecutor, Postgres}; +use sqlx::PgExecutor; use super::{LDAPConfig, error::LdapError}; use crate::{handlers::user::check_username, hashset}; @@ -247,26 +247,20 @@ pub(crate) fn maybe_update_rdn(user: &mut User) { /// - he is in a group that is allowed to be synced or no such groups are configured /// - he is active (not disabled) /// - he is enrolled -pub(crate) fn ldap_sync_allowed_for_user<'a, 'e, E>( - user: &'a User, +pub(crate) async fn ldap_sync_allowed_for_user<'e, E>( + user: &User, executor: E, -) -> impl Future> + Send + 'a +) -> sqlx::Result where - E: Acquire<'e, Database = Postgres> + Send + 'a, + E: PgExecutor<'e>, { - async move { - let mut connection = executor.acquire().await?; - let sync_groups = Settings::get(&mut *connection) - .await? - .unwrap_or_default() - .ldap_sync_groups; - let my_groups = user.member_of(&mut *connection).await?; - Ok( - (sync_groups.is_empty() || my_groups.iter().any(|g| sync_groups.contains(&g.name))) - && user.is_active - && user.is_enrolled(), - ) - } + let sync_groups = Settings::get_current_settings().ldap_sync_groups; + let my_groups = user.member_of(executor).await?; + Ok( + (sync_groups.is_empty() || my_groups.iter().any(|g| sync_groups.contains(&g.name))) + && user.is_active + && user.is_enrolled(), + ) } pub(super) async fn get_users_without_ldap_path<'e, E>(executor: E) -> sqlx::Result>> diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index d65d1e0ada..425e39c61d 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -894,10 +894,7 @@ pub async fn delete_user( session.user.username ); let mut transaction = appstate.pool.begin().await?; - let user_for_ldap = if { - let transaction_connection = transaction.acquire().await?; - ldap_sync_allowed_for_user(&user, transaction_connection).await? - } { + let user_for_ldap = if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { Some(user.clone().as_noid()) } else { None From 0f0314b88af5bc93ddd8f70dcb924c310e096393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 20 Mar 2026 12:10:27 +0100 Subject: [PATCH 27/36] remove unnecessary test feature --- crates/defguard_common/Cargo.toml | 3 - crates/defguard_common/src/auth/claims.rs | 67 ++--------------------- crates/defguard_core/Cargo.toml | 1 - 3 files changed, 6 insertions(+), 65 deletions(-) diff --git a/crates/defguard_common/Cargo.toml b/crates/defguard_common/Cargo.toml index 7d8e244724..12f0bbe1df 100644 --- a/crates/defguard_common/Cargo.toml +++ b/crates/defguard_common/Cargo.toml @@ -7,9 +7,6 @@ homepage.workspace = true repository.workspace = true rust-version.workspace = true -[features] -test-support = [] - [dependencies] model_derive.workspace = true diff --git a/crates/defguard_common/src/auth/claims.rs b/crates/defguard_common/src/auth/claims.rs index 031c5338ef..bca84e18e7 100644 --- a/crates/defguard_common/src/auth/claims.rs +++ b/crates/defguard_common/src/auth/claims.rs @@ -1,6 +1,5 @@ use std::{ env, - sync::OnceLock, time::{Duration, SystemTime}, }; @@ -14,8 +13,6 @@ pub static AUTH_SECRET_ENV: &str = "DEFGUARD_AUTH_SECRET"; pub static GATEWAY_SECRET_ENV: &str = "DEFGUARD_GATEWAY_SECRET"; pub static YUBIBRIDGE_SECRET_ENV: &str = "DEFGUARD_YUBIBRIDGE_SECRET"; -static JWT_SECRET_OVERRIDES: OnceLock = OnceLock::new(); - #[derive(Clone, Copy, Default)] pub enum ClaimsType { #[default] @@ -42,23 +39,6 @@ pub struct Claims { pub nbf: u64, } -#[derive(Clone, Debug, PartialEq, Eq)] -struct JwtSecretOverrides { - auth: String, - gateway: String, - yubibridge: String, -} - -impl JwtSecretOverrides { - fn secret_for(&self, claims_type: ClaimsType) -> &str { - match claims_type { - ClaimsType::Auth | ClaimsType::DesktopClient => &self.auth, - ClaimsType::Gateway => &self.gateway, - ClaimsType::YubiBridge => &self.yubibridge, - } - } -} - impl Claims { #[must_use] pub fn new(claims_type: ClaimsType, sub: String, client_id: String, duration: u64) -> Self { @@ -84,11 +64,12 @@ impl Claims { } fn get_secret(claims_type: ClaimsType) -> String { - if let Some(secret_overrides) = JWT_SECRET_OVERRIDES.get() { - return secret_overrides.secret_for(claims_type).to_string(); - } - - env::var(secret_env(claims_type)).unwrap_or_default() + let env_var = match claims_type { + ClaimsType::Auth | ClaimsType::DesktopClient => AUTH_SECRET_ENV, + ClaimsType::Gateway => GATEWAY_SECRET_ENV, + ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV, + }; + env::var(env_var).unwrap_or_default() } /// Convert claims to JWT. @@ -115,39 +96,3 @@ impl Claims { .map(|data| data.claims) } } - -fn secret_env(claims_type: ClaimsType) -> &'static str { - match claims_type { - ClaimsType::Auth | ClaimsType::DesktopClient => AUTH_SECRET_ENV, - ClaimsType::Gateway => GATEWAY_SECRET_ENV, - ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV, - } -} - -#[cfg(any(test, feature = "test-support"))] -#[doc(hidden)] -pub mod test_support { - use super::{JWT_SECRET_OVERRIDES, JwtSecretOverrides}; - - pub fn initialize_jwt_secret_overrides( - auth_secret: impl Into, - gateway_secret: impl Into, - yubibridge_secret: impl Into, - ) { - let secret_overrides = JwtSecretOverrides { - auth: auth_secret.into(), - gateway: gateway_secret.into(), - yubibridge: yubibridge_secret.into(), - }; - - if let Err(secret_overrides) = JWT_SECRET_OVERRIDES.set(secret_overrides) { - let existing_overrides = JWT_SECRET_OVERRIDES - .get() - .expect("JWT secret overrides should be initialized"); - assert_eq!( - existing_overrides, &secret_overrides, - "JWT secret overrides already initialized with different values" - ); - } - } -} diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 9861923d7d..119d4a8c7c 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -87,7 +87,6 @@ async-stream = "0.3" [dev-dependencies] claims.workspace = true -defguard_common = { workspace = true, features = ["test-support"] } hyper-util = "0.1" matches.workspace = true reqwest = { version = "0.12", features = [ From 3726fb6fdbc87e60fb990d9a1df49640bb5ef253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 20 Mar 2026 12:43:41 +0100 Subject: [PATCH 28/36] remove unnecessary test locks --- .../src/enterprise/directory_sync/tests.rs | 17 --------------- .../firewall/tests/all_locations.rs | 7 +------ .../enterprise/firewall/tests/destination.rs | 9 +------- .../src/enterprise/firewall/tests/mod.rs | 5 ----- .../src/enterprise/ldap/tests.rs | 21 ------------------- crates/defguard_core/src/enterprise/mod.rs | 15 ------------- .../tests/integration/grpc/common/mod.rs | 13 +----------- 7 files changed, 3 insertions(+), 84 deletions(-) diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index f8d0772f1e..6bed257fd3 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -23,7 +23,6 @@ mod test { db::models::openid_provider::{DirectorySyncTarget, OpenIdProviderKind}, license::{License, LicenseTier, set_cached_license}, limits::{get_counts, update_counts}, - test_state_lock, }, grpc::proto::enterprise::license::LicenseLimits, }; @@ -143,7 +142,6 @@ mod test { // Keep both users and admins #[sqlx::test] async fn test_users_state_keep_both(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -184,7 +182,6 @@ mod test { // Delete users, keep admins #[sqlx::test] async fn test_users_state_delete_users(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -228,7 +225,6 @@ mod test { } #[sqlx::test] async fn test_users_state_delete_admins(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -284,7 +280,6 @@ mod test { #[sqlx::test] async fn test_users_state_delete_both(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -354,7 +349,6 @@ mod test { #[sqlx::test] async fn test_users_state_disable_users(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -438,7 +432,6 @@ mod test { } #[sqlx::test] async fn test_users_state_disable_admins(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -517,7 +510,6 @@ mod test { #[sqlx::test] async fn test_users_groups(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -575,7 +567,6 @@ mod test { #[sqlx::test] async fn test_sync_user_groups(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -605,7 +596,6 @@ mod test { #[sqlx::test] async fn test_sync_target_users(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -631,7 +621,6 @@ mod test { #[sqlx::test] async fn test_sync_target_all(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -683,7 +672,6 @@ mod test { #[sqlx::test] async fn test_sync_target_groups(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -712,7 +700,6 @@ mod test { #[sqlx::test] async fn test_sync_unassign_last_admin_group(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -760,7 +747,6 @@ mod test { #[sqlx::test] async fn test_sync_delete_last_admin_user(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -801,7 +787,6 @@ mod test { #[sqlx::test] async fn test_users_no_prefetch(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -836,7 +821,6 @@ mod test { #[sqlx::test] async fn test_users_prefetch(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); @@ -874,7 +858,6 @@ mod test { _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let config = DefGuardConfig::new_test_config(); diff --git a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs index f69c2d8b3a..3ee0ea139e 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs @@ -6,16 +6,13 @@ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use crate::enterprise::{ db::models::acl::{AclRule, AclRuleNetwork, RuleState}, firewall::{ - tests::{ - create_test_users_and_devices, lock_enterprise_test_state, set_test_license_business, - }, + tests::{create_test_users_and_devices, set_test_license_business}, try_get_location_firewall_config, }, }; #[sqlx::test] async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = lock_enterprise_test_state().await; let pool = setup_pool(options).await; let mut rng = thread_rng(); set_test_license_business(); @@ -115,7 +112,6 @@ async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectO #[sqlx::test] async fn test_acl_rules_all_locations_ipv6(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = lock_enterprise_test_state().await; set_test_license_business(); let pool = setup_pool(options).await; let mut rng = thread_rng(); @@ -215,7 +211,6 @@ async fn test_acl_rules_all_locations_ipv6(_: PgPoolOptions, options: PgConnectO #[sqlx::test] async fn test_acl_rules_all_locations_ipv4_and_ipv6(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = lock_enterprise_test_state().await; set_test_license_business(); let pool = setup_pool(options).await; let mut rng = thread_rng(); diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index 106c246201..213bbd1a18 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -10,10 +10,7 @@ use defguard_proto::enterprise::firewall::{ use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use super::{ - create_acl_rule, create_test_users_and_devices, lock_enterprise_test_state, - set_test_license_business, -}; +use super::{create_acl_rule, create_test_users_and_devices, set_test_license_business}; use crate::enterprise::{ db::models::acl::{ AclAlias, AclAliasDestinationRange, AclRule, AclRuleDestinationRange, AliasKind, RuleState, @@ -148,7 +145,6 @@ async fn test_any_address_overwrites_manual_destination( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = lock_enterprise_test_state().await; set_test_license_business(); let pool = setup_pool(options).await; @@ -230,7 +226,6 @@ async fn test_any_address_overwrites_destination_alias_addrs( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = lock_enterprise_test_state().await; set_test_license_business(); let pool = setup_pool(options).await; @@ -330,7 +325,6 @@ async fn test_manual_destination_includes_component_alias_address_range( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = lock_enterprise_test_state().await; set_test_license_business(); let pool = setup_pool(options).await; @@ -433,7 +427,6 @@ async fn test_manual_destination_merges_rule_and_component_alias_address_ranges( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = lock_enterprise_test_state().await; set_test_license_business(); let pool = setup_pool(options).await; diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index 9e58f0d1ab..6a78e4b09e 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -28,7 +28,6 @@ use crate::enterprise::{ }, firewall::try_get_location_firewall_config, license::{License, LicenseTier, set_cached_license}, - test_state_lock, }; mod all_locations; @@ -63,10 +62,6 @@ fn set_test_license_business() { set_cached_license(Some(license)); } -pub(super) async fn lock_enterprise_test_state() -> tokio::sync::OwnedMutexGuard<()> { - test_state_lock().lock_owned().await -} - fn random_user_with_id(rng: &mut R, id: Id) -> User { let mut user: User = rng.r#gen(); user.id = id; diff --git a/crates/defguard_core/src/enterprise/ldap/tests.rs b/crates/defguard_core/src/enterprise/ldap/tests.rs index 43023432c2..ccc7d3b1a9 100644 --- a/crates/defguard_core/src/enterprise/ldap/tests.rs +++ b/crates/defguard_core/src/enterprise/ldap/tests.rs @@ -17,7 +17,6 @@ use crate::{ enterprise::{ license::{License, LicenseTier, set_cached_license}, limits::get_counts, - test_state_lock, }, grpc::proto::enterprise::license::LicenseLimits, }; @@ -262,7 +261,6 @@ fn test_using_username_as_rdn() { #[sqlx::test] async fn test_update_users_state(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let mut ldap_conn = LDAPConnection::create().await.unwrap(); let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -1583,7 +1581,6 @@ fn test_extract_intersecting_users_no_matches(_: PgPoolOptions, options: PgConne #[sqlx::test] async fn test_fix_missing_user_path(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -1722,7 +1719,6 @@ async fn test_sync_users_with_empty_paths_and_nested_ous( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -1992,7 +1988,6 @@ async fn test_sync_users_with_empty_paths_and_nested_ous( #[sqlx::test] async fn test_sync_simple_nested_ou_changes(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -2071,7 +2066,6 @@ async fn test_sync_incremental_with_nested_ou_conflicts( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -2161,7 +2155,6 @@ async fn test_sync_defguard_authority_with_complex_nested_ous( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -2291,7 +2284,6 @@ async fn test_sync_defguard_authority_with_complex_nested_ous( #[sqlx::test] async fn test_sync_with_ou_path_edge_cases(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; let mut ldap_conn = super::LDAPConnection::create().await.unwrap(); @@ -2381,7 +2373,6 @@ async fn test_sync_group_membership_with_intersecting_users( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -2462,7 +2453,6 @@ async fn test_sync_ldap_to_defguard_does_not_exceed_user_license_limit( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -2520,7 +2510,6 @@ async fn test_ldap_login_does_not_create_user_when_user_license_limit_is_reached _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -2581,7 +2570,6 @@ async fn test_ldap_login_does_not_create_user_when_user_license_limit_is_reached #[sqlx::test] async fn test_get_empty_user_path(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; let user = make_test_user("testuser", None, None); @@ -3227,7 +3215,6 @@ async fn test_ldap_sync_allowed_with_empty_sync_groups( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -3243,7 +3230,6 @@ async fn test_ldap_sync_allowed_with_empty_sync_groups( #[sqlx::test] async fn test_ldap_sync_allowed_with_inactive_user(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3258,7 +3244,6 @@ async fn test_ldap_sync_allowed_with_inactive_user(_: PgPoolOptions, options: Pg #[sqlx::test] async fn test_ldap_sync_allowed_with_unenrolled_user(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3278,7 +3263,6 @@ async fn test_ldap_sync_allowed_with_sync_groups_user_in_group( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3303,7 +3287,6 @@ async fn test_ldap_sync_allowed_with_sync_groups_user_not_in_group( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3329,7 +3312,6 @@ async fn test_ldap_sync_allowed_with_multiple_sync_groups( _: PgPoolOptions, options: PgConnectOptions, ) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3358,7 +3340,6 @@ async fn test_ldap_sync_allowed_with_multiple_sync_groups( #[sqlx::test] async fn test_ldap_sync_allowed_enrolled_via_openid(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; set_test_license_business(); @@ -3376,7 +3357,6 @@ async fn test_ldap_sync_allowed_enrolled_via_openid(_: PgPoolOptions, options: P #[sqlx::test] async fn test_ldap_sync_allowed_enrolled_via_ldap(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; @@ -3393,7 +3373,6 @@ async fn test_ldap_sync_allowed_enrolled_via_ldap(_: PgPoolOptions, options: PgC #[sqlx::test] async fn test_ldap_sync_allowed_all_conditions_false(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = test_state_lock().lock_owned().await; let pool = setup_pool(options).await; let _ = initialize_current_settings(&pool).await; diff --git a/crates/defguard_core/src/enterprise/mod.rs b/crates/defguard_core/src/enterprise/mod.rs index 53943c1020..1c4855ae82 100644 --- a/crates/defguard_core/src/enterprise/mod.rs +++ b/crates/defguard_core/src/enterprise/mod.rs @@ -10,15 +10,9 @@ pub mod limits; pub mod snat; mod utils; -#[cfg(test)] -use std::sync::{Arc, OnceLock}; - use license::{get_cached_license, validate_license}; use limits::get_counts; -#[cfg(test)] -use tokio::sync::Mutex; - use crate::enterprise::license::LicenseTier; /// Helper function to gate features which require a base license (Team or Business tier) @@ -46,15 +40,6 @@ fn is_license_tier_active(tier: LicenseTier) -> bool { validation_result.is_ok() } -#[cfg(test)] -pub(crate) fn test_state_lock() -> Arc> { - static TEST_STATE_LOCK: OnceLock>> = OnceLock::new(); - - TEST_STATE_LOCK - .get_or_init(|| Arc::new(Mutex::new(()))) - .clone() -} - #[cfg(test)] mod test { use chrono::{TimeDelta, Utc}; diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index b640a4de0e..40649370c7 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -1,7 +1,7 @@ use std::sync::{Arc, Mutex}; use defguard_common::{ - auth::claims::{Claims, ClaimsType, test_support::initialize_jwt_secret_overrides}, + auth::claims::{Claims, ClaimsType}, db::{models::settings::initialize_current_settings, setup_pool}, }; use defguard_core::{ @@ -92,7 +92,6 @@ pub(crate) async fn create_client_channel(client_stream: DuplexStream) -> Channe } pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { - initialize_jwt_secrets(); initialize_users(pool).await; initialize_current_settings(pool) .await @@ -124,7 +123,6 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { } pub(crate) fn create_yubibridge_jwt(username: &str) -> String { - initialize_jwt_secrets(); Claims::new( ClaimsType::YubiBridge, username.to_string(), @@ -136,7 +134,6 @@ pub(crate) fn create_yubibridge_jwt(username: &str) -> String { } pub(crate) fn create_gateway_jwt(username: &str, client_id: &str) -> String { - initialize_jwt_secrets(); Claims::new( ClaimsType::Gateway, username.to_string(), @@ -163,11 +160,3 @@ pub(crate) fn worker_request(message: T, username: &str) -> Request { add_worker_auth_metadata(&mut request, username); request } - -fn initialize_jwt_secrets() { - initialize_jwt_secret_overrides( - "defguard-test-auth-secret", - "defguard-test-gateway-secret", - "defguard-test-yubibridge-secret", - ); -} From 62d2d6b047588097c77f676a13caee03af65e215 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 11:47:28 +0100 Subject: [PATCH 29/36] update dependencies --- Cargo.lock | 24 ++++++++++++------------ flake.lock | 12 ++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7128bb1124..fcdc270ece 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3086,9 +3086,9 @@ dependencies = [ [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", @@ -4681,9 +4681,9 @@ dependencies = [ [[package]] name = "pulldown-cmark" -version = "0.13.1" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6" +checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" dependencies = [ "bitflags 2.11.0", "getopts", @@ -4803,9 +4803,9 @@ dependencies = [ [[package]] name = "quoted_printable" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "640c9bd8497b02465aeef5375144c26062e0dcd5939dfcbb0f5db76cb8c17c73" +checksum = "478e0585659a122aa407eb7e3c0e1fa51b1d8a870038bd29f0cf4a8551eea972" [[package]] name = "r-efi" @@ -5273,9 +5273,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "aws-lc-rs", "ring", @@ -5509,9 +5509,9 @@ dependencies = [ [[package]] name = "serde_qs" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac22439301a0b6f45a037681518e3169e8db1db76080e2e9600a08d1027df037" +checksum = "3c742cd44662647326f86b514eadcc227fff4ce684dbbdaf1943f758d5ea058c" dependencies = [ "itoa", "percent-encoding", @@ -7951,9 +7951,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec5f41c76397b7da451efd19915684f727d7e1d516384ca6bd0ec43ec94de23c" +checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6" dependencies = [ "zune-core", ] diff --git a/flake.lock b/flake.lock index 1b6bbdcda0..4db212dbcd 100644 --- a/flake.lock +++ b/flake.lock @@ -32,11 +32,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1773821835, - "narHash": "sha256-TJ3lSQtW0E2JrznGVm8hOQGVpXjJyXY2guAxku2O9A4=", + "lastModified": 1774106199, + "narHash": "sha256-US5Tda2sKmjrg2lNHQL3jRQ6p96cgfWh3J1QBliQ8Ws=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "b40629efe5d6ec48dd1efba650c797ddbd39ace0", + "rev": "6c9a78c09ff4d6c21d0319114873508a6ec01655", "type": "github" }, "original": { @@ -74,11 +74,11 @@ ] }, "locked": { - "lastModified": 1773889863, - "narHash": "sha256-tSsmZOHBgq4qfu5MNCAEsKZL1cI4avNLw2oUTXWeb74=", + "lastModified": 1774235565, + "narHash": "sha256-D8OOwvq3zDDCtIhMcNueb9tGSZaZUanKpWDleRgQ80U=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "dbfd51be2692cb7022e301d14c139accb4ee63f0", + "rev": "dc00324a2438762582b49954373112b8eab29cab", "type": "github" }, "original": { From 51c721a705533448a060cb5476bd27fce2029343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 12:29:53 +0100 Subject: [PATCH 30/36] remove unnecessary test guards --- crates/defguard_gateway_manager/src/lib.rs | 3 +- .../tests/auto_adoption_wizard.rs | 10 ++-- crates/defguard_setup/tests/common/mod.rs | 46 +++---------------- crates/defguard_setup/tests/initial_setup.rs | 25 ++++------ .../defguard_setup/tests/migration_wizard.rs | 7 +-- crates/defguard_setup/tests/session_info.rs | 10 ++-- crates/defguard_setup/tests/wizard_state.rs | 6 +-- 7 files changed, 25 insertions(+), 82 deletions(-) diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 393efc7388..ddb9cc53cf 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -413,12 +413,11 @@ impl GatewayManager { Some(id) } TriggerOperation::Update => { - let (Some(old), Some(new)) = + let (Some(mut old), Some(new)) = (gateway_notification.old, gateway_notification.new) else { continue; }; - let mut old = old; let id = new.id; if old.address == new.address diff --git a/crates/defguard_setup/tests/auto_adoption_wizard.rs b/crates/defguard_setup/tests/auto_adoption_wizard.rs index a4b7f10114..4d48a48d59 100644 --- a/crates/defguard_setup/tests/auto_adoption_wizard.rs +++ b/crates/defguard_setup/tests/auto_adoption_wizard.rs @@ -76,7 +76,6 @@ async fn seed_wireguard_network(pool: &sqlx::PgPool) -> WireguardNetwork { #[sqlx::test] async fn test_auto_adoption_full_flow(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -89,7 +88,7 @@ async fn test_auto_adoption_full_flow(_: PgPoolOptions, options: PgConnectOption .await .expect("Failed to init wizard"); - let (client, shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, shutdown_rx) = make_setup_test_client(pool.clone()).await; assert_auto_adoption_step(&pool, AutoAdoptionWizardStep::Welcome).await; @@ -209,7 +208,6 @@ async fn test_auto_adoption_full_flow(_: PgPoolOptions, options: PgConnectOption #[sqlx::test] async fn test_auto_adoption_auth_enforcement(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -229,8 +227,7 @@ async fn test_auto_adoption_auth_enforcement(_: PgPoolOptions, options: PgConnec .expect("Failed to build unauthenticated reqwest client") }; - let (client_with_session, _shutdown_rx) = - make_setup_test_client(pool.clone(), test_guard).await; + let (client_with_session, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let base_url = client_with_session.base_url(); let resp = unauthenticated_client @@ -340,7 +337,6 @@ async fn test_auto_adoption_vpn_settings_missing_network( _: PgPoolOptions, options: PgConnectOptions, ) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -350,7 +346,7 @@ async fn test_auto_adoption_vpn_settings_missing_network( .await .expect("Failed to init wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; // Create admin (no auth required yet) let resp = client diff --git a/crates/defguard_setup/tests/common/mod.rs b/crates/defguard_setup/tests/common/mod.rs index a0e469b69e..2abf9c548e 100644 --- a/crates/defguard_setup/tests/common/mod.rs +++ b/crates/defguard_setup/tests/common/mod.rs @@ -1,6 +1,6 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::{Arc, OnceLock}, + sync::Arc, }; use axum::serve; @@ -24,43 +24,21 @@ use reqwest::{ }; use semver::Version; use sqlx::PgPool; -use tokio::{ - net::TcpListener, - sync::{Mutex, OwnedMutexGuard, oneshot}, - task::JoinHandle, -}; +use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle}; #[allow(dead_code)] pub const TEST_SECRET_KEY: &str = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; -#[allow(dead_code)] -pub async fn setup_test_guard() -> OwnedMutexGuard<()> { - setup_test_lock().lock_owned().await -} - pub struct TestClient { pub client: Client, pub _jar: Arc, pub port: u16, - pub _test_guard: OwnedMutexGuard<()>, pub _task: JoinHandle<()>, } -fn setup_test_lock() -> Arc> { - static SETUP_TEST_LOCK: OnceLock>> = OnceLock::new(); - - SETUP_TEST_LOCK - .get_or_init(|| Arc::new(Mutex::new(()))) - .clone() -} - impl TestClient { - pub fn new( - router: axum::Router, - listener: TcpListener, - test_guard: OwnedMutexGuard<()>, - ) -> Self { + pub fn new(router: axum::Router, listener: TcpListener) -> Self { let port = listener.local_addr().unwrap().port(); let task = tokio::spawn(async move { serve( @@ -84,7 +62,6 @@ impl TestClient { client, _jar: jar, port, - _test_guard: test_guard, _task: task, } } @@ -108,10 +85,7 @@ impl TestClient { } #[allow(dead_code)] -pub async fn make_setup_test_client( - pool: PgPool, - test_guard: OwnedMutexGuard<()>, -) -> (TestClient, oneshot::Receiver<()>) { +pub async fn make_setup_test_client(pool: PgPool) -> (TestClient, oneshot::Receiver<()>) { let (setup_shutdown_tx, setup_shutdown_rx) = oneshot::channel::<()>(); let app = build_setup_webapp( pool, @@ -122,16 +96,12 @@ pub async fn make_setup_test_client( let listener = TcpListener::bind(addr) .await .expect("Could not bind ephemeral socket"); - ( - TestClient::new(app, listener, test_guard), - setup_shutdown_rx, - ) + (TestClient::new(app, listener), setup_shutdown_rx) } #[allow(dead_code)] pub async fn make_migration_test_client( pool: PgPool, - test_guard: OwnedMutexGuard<()>, ) -> ( TestClient, oneshot::Receiver<()>, @@ -151,11 +121,7 @@ pub async fn make_migration_test_client( let listener = TcpListener::bind(addr) .await .expect("Could not bind ephemeral socket"); - ( - TestClient::new(router, listener, test_guard), - setup_shutdown_rx, - webapp, - ) + (TestClient::new(router, listener), setup_shutdown_rx, webapp) } /// Initialise settings with a known secret key so `build_migration_webapp` can diff --git a/crates/defguard_setup/tests/initial_setup.rs b/crates/defguard_setup/tests/initial_setup.rs index a88648a866..013b5c5a20 100644 --- a/crates/defguard_setup/tests/initial_setup.rs +++ b/crates/defguard_setup/tests/initial_setup.rs @@ -47,7 +47,6 @@ async fn assert_setup_step(pool: &sqlx::PgPool, expected: InitialSetupStep) { #[sqlx::test] async fn test_create_admin(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -56,7 +55,7 @@ async fn test_create_admin(_: PgPoolOptions, options: PgConnectOptions) { .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let payload = json!({ "first_name": "Admin", @@ -105,7 +104,6 @@ async fn test_create_admin_with_automatic_group_assignment( _: PgPoolOptions, options: PgConnectOptions, ) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -114,7 +112,7 @@ async fn test_create_admin_with_automatic_group_assignment( .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let default_admin_group_name = Settings::get_current_settings().default_admin_group_name; let payload = json!({ @@ -153,7 +151,6 @@ async fn test_create_admin_with_automatic_group_assignment( #[sqlx::test] async fn test_setup_login_too_many_attempts(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -162,7 +159,7 @@ async fn test_setup_login_too_many_attempts(_: PgPoolOptions, options: PgConnect .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let response = client .post("/api/v1/initial_setup/admin") @@ -204,7 +201,6 @@ async fn test_setup_login_too_many_attempts(_: PgPoolOptions, options: PgConnect #[sqlx::test] async fn test_set_general_config(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -213,7 +209,7 @@ async fn test_set_general_config(_: PgPoolOptions, options: PgConnectOptions) { .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let response = client .post("/api/v1/initial_setup/admin") @@ -276,7 +272,6 @@ async fn test_set_general_config(_: PgPoolOptions, options: PgConnectOptions) { #[sqlx::test] async fn test_create_ca(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -285,7 +280,7 @@ async fn test_create_ca(_: PgPoolOptions, options: PgConnectOptions) { .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let response = client .post("/api/v1/initial_setup/admin") @@ -328,7 +323,6 @@ async fn test_create_ca(_: PgPoolOptions, options: PgConnectOptions) { #[sqlx::test] async fn test_upload_ca(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -337,7 +331,7 @@ async fn test_upload_ca(_: PgPoolOptions, options: PgConnectOptions) { .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let response = client .post("/api/v1/initial_setup/admin") @@ -378,7 +372,6 @@ async fn test_upload_ca(_: PgPoolOptions, options: PgConnectOptions) { #[sqlx::test] async fn test_get_ca(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -387,7 +380,7 @@ async fn test_get_ca(_: PgPoolOptions, options: PgConnectOptions) { .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let response = client .post("/api/v1/initial_setup/admin") @@ -433,7 +426,6 @@ async fn test_get_ca(_: PgPoolOptions, options: PgConnectOptions) { #[sqlx::test] async fn test_finish_setup(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -442,7 +434,7 @@ async fn test_finish_setup(_: PgPoolOptions, options: PgConnectOptions) { .await .expect("Failed to initialize wizard"); - let (client, shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, shutdown_rx) = make_setup_test_client(pool.clone()).await; let response = client .post("/api/v1/initial_setup/admin") @@ -486,7 +478,6 @@ async fn test_finish_setup(_: PgPoolOptions, options: PgConnectOptions) { #[sqlx::test] async fn test_setup_flow(_: PgPoolOptions, options: PgConnectOptions) { - let _test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await diff --git a/crates/defguard_setup/tests/migration_wizard.rs b/crates/defguard_setup/tests/migration_wizard.rs index 77db4d8128..e1903f1dfe 100644 --- a/crates/defguard_setup/tests/migration_wizard.rs +++ b/crates/defguard_setup/tests/migration_wizard.rs @@ -32,7 +32,6 @@ async fn assert_migration_step(pool: &sqlx::PgPool, expected_variant: &str) { #[sqlx::test] async fn test_migration_full_flow(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; init_settings_with_secret_key(&pool).await; @@ -45,7 +44,7 @@ async fn test_migration_full_flow(_: PgPoolOptions, options: PgConnectOptions) { let wizard = Wizard::get(&pool).await.expect("Failed to get wizard"); assert_eq!(wizard.active_wizard, ActiveWizard::Migration); - let (client, shutdown_rx, _webapp) = make_migration_test_client(pool.clone(), test_guard).await; + let (client, shutdown_rx, _webapp) = make_migration_test_client(pool.clone()).await; let resp = client .get("/api/v1/session-info") @@ -179,7 +178,6 @@ async fn test_migration_full_flow(_: PgPoolOptions, options: PgConnectOptions) { #[sqlx::test] async fn test_migration_auth_enforcement(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; init_settings_with_secret_key(&pool).await; @@ -188,8 +186,7 @@ async fn test_migration_auth_enforcement(_: PgPoolOptions, options: PgConnectOpt .await .expect("Failed to init wizard"); - let (client, _shutdown_rx, _webapp) = - make_migration_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx, _webapp) = make_migration_test_client(pool.clone()).await; let unauth = { let mut headers = HeaderMap::new(); diff --git a/crates/defguard_setup/tests/session_info.rs b/crates/defguard_setup/tests/session_info.rs index e3ddff4868..f27d8d034a 100644 --- a/crates/defguard_setup/tests/session_info.rs +++ b/crates/defguard_setup/tests/session_info.rs @@ -16,7 +16,6 @@ use common::{init_settings_with_secret_key, make_migration_test_client, make_set #[sqlx::test] async fn test_session_info_setup_server(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -25,7 +24,7 @@ async fn test_session_info_setup_server(_: PgPoolOptions, options: PgConnectOpti .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let resp = client .get("/api/v1/session-info") @@ -81,7 +80,6 @@ async fn test_session_info_setup_server(_: PgPoolOptions, options: PgConnectOpti #[sqlx::test] async fn test_session_info_auto_adoption_wizard(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -91,7 +89,7 @@ async fn test_session_info_auto_adoption_wizard(_: PgPoolOptions, options: PgCon .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let resp = client .get("/api/v1/session-info") @@ -109,7 +107,6 @@ async fn test_session_info_auto_adoption_wizard(_: PgPoolOptions, options: PgCon #[sqlx::test] async fn test_session_info_migration_server(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; init_settings_with_secret_key(&pool).await; @@ -140,8 +137,7 @@ async fn test_session_info_migration_server(_: PgPoolOptions, options: PgConnect .await .expect("Failed to initialize wizard"); - let (client, _shutdown_rx, _webapp) = - make_migration_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx, _webapp) = make_migration_test_client(pool.clone()).await; let resp = client .get("/api/v1/session-info") diff --git a/crates/defguard_setup/tests/wizard_state.rs b/crates/defguard_setup/tests/wizard_state.rs index 6103ea850a..1640fa448b 100644 --- a/crates/defguard_setup/tests/wizard_state.rs +++ b/crates/defguard_setup/tests/wizard_state.rs @@ -16,7 +16,6 @@ use common::make_setup_test_client; #[sqlx::test] async fn test_wizard_state_initial(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -25,7 +24,7 @@ async fn test_wizard_state_initial(_: PgPoolOptions, options: PgConnectOptions) .await .expect("Failed to init wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let resp = client .get("/api/v1/wizard") @@ -134,7 +133,6 @@ async fn test_wizard_state_initial(_: PgPoolOptions, options: PgConnectOptions) #[sqlx::test] async fn test_wizard_state_auto_adoption(_: PgPoolOptions, options: PgConnectOptions) { - let test_guard = common::setup_test_guard().await; let pool = setup_pool(options).await; initialize_current_settings(&pool) .await @@ -164,7 +162,7 @@ async fn test_wizard_state_auto_adoption(_: PgPoolOptions, options: PgConnectOpt .await .expect("Failed to init wizard"); - let (client, _shutdown_rx) = make_setup_test_client(pool.clone(), test_guard).await; + let (client, _shutdown_rx) = make_setup_test_client(pool.clone()).await; let state: serde_json::Value = client .get("/api/v1/wizard") From 06888ad7fbbd6ab1ccd416126ab0a662f2008284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 12:38:44 +0100 Subject: [PATCH 31/36] reuse transaction --- crates/defguard_core/src/handlers/user.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index 3304c3c4f6..850b0db81c 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -17,7 +17,7 @@ use defguard_common::{ use defguard_mail::{Mail, templates}; use humantime::parse_duration; use serde_json::json; -use sqlx::{Acquire, PgPool}; +use sqlx::PgPool; use utoipa::ToSchema; use super::{ @@ -743,10 +743,7 @@ pub(crate) async fn modify_user( let status_changing = user_info.is_active != user.is_active; let mut transaction = appstate.pool.begin().await?; - let ldap_sync_allowed = { - let transaction_connection = transaction.acquire().await?; - ldap_sync_allowed_for_user(&user, transaction_connection).await? - }; + let ldap_sync_allowed = ldap_sync_allowed_for_user(&user, &mut *transaction).await?; // remove authorized apps if needed let request_app_ids: Vec = user_info From b71a5a17ad5671d1ab0214580818e93c2d7fa843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 12:46:19 +0100 Subject: [PATCH 32/36] remove more test guards --- .../tests/integration/api/acl/mod.rs | 7 +++---- .../tests/integration/api/common/client.rs | 8 +------- .../tests/integration/api/common/mod.rs | 19 +++---------------- 3 files changed, 7 insertions(+), 27 deletions(-) diff --git a/crates/defguard_core/tests/integration/api/acl/mod.rs b/crates/defguard_core/tests/integration/api/acl/mod.rs index b88c81f2ff..6f43cc58c5 100644 --- a/crates/defguard_core/tests/integration/api/acl/mod.rs +++ b/crates/defguard_core/tests/integration/api/acl/mod.rs @@ -31,8 +31,8 @@ use sqlx::{ use tokio::net::TcpListener; use super::common::{ - acquire_api_test_guard, authenticate_admin, client::TestClient, exceed_enterprise_limits, - make_base_client, make_test_client, setup_pool, + authenticate_admin, client::TestClient, exceed_enterprise_limits, make_base_client, + make_test_client, setup_pool, }; use crate::common::{init_config, initialize_users}; @@ -41,7 +41,6 @@ mod destinations; mod rules; async fn make_client_v2(pool: PgPool, config: DefGuardConfig) -> TestClient { - let test_guard = acquire_api_test_guard().await; let listener = TcpListener::bind("127.0.0.1:0") .await .expect("Could not bind ephemeral socket"); @@ -49,7 +48,7 @@ async fn make_client_v2(pool: PgPool, config: DefGuardConfig) -> TestClient { initialize_current_settings(&pool) .await .expect("Could not initialize settings"); - let (client, _) = make_base_client(pool, config, listener, test_guard).await; + let (client, _) = make_base_client(pool, config, listener).await; client } diff --git a/crates/defguard_core/tests/integration/api/common/client.rs b/crates/defguard_core/tests/integration/api/common/client.rs index f2cccc6913..2203c0c899 100644 --- a/crates/defguard_core/tests/integration/api/common/client.rs +++ b/crates/defguard_core/tests/integration/api/common/client.rs @@ -15,10 +15,7 @@ use reqwest::{ }; use tokio::{ net::TcpListener, - sync::{ - OwnedMutexGuard, - mpsc::{UnboundedReceiver, error::TryRecvError}, - }, + sync::mpsc::{UnboundedReceiver, error::TryRecvError}, task::JoinHandle, }; @@ -27,7 +24,6 @@ pub struct TestClient { jar: Arc, port: u16, api_event_rx: UnboundedReceiver, - _test_guard: OwnedMutexGuard<()>, // Has to live during whole test api_task_handle: JoinHandle<()>, } @@ -38,7 +34,6 @@ impl TestClient { app: Router, listener: TcpListener, api_event_rx: UnboundedReceiver, - test_guard: OwnedMutexGuard<()>, ) -> Self { let port = listener.local_addr().unwrap().port(); @@ -67,7 +62,6 @@ impl TestClient { jar, port, api_event_rx, - _test_guard: test_guard, api_task_handle, } } diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 56c1ba2966..a30aac97ab 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -2,7 +2,7 @@ pub(crate) mod client; use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::{Arc, Mutex, OnceLock}, + sync::{Arc, Mutex}, }; use axum_extra::extract::cookie::Key; @@ -31,7 +31,6 @@ use sqlx::PgPool; use tokio::{ net::TcpListener, sync::{ - Mutex as AsyncMutex, OwnedMutexGuard, broadcast::{self, Receiver}, mpsc::{channel, unbounded_channel}, }, @@ -58,16 +57,6 @@ pub(crate) struct ClientState { pub config: DefGuardConfig, } -static API_TEST_LOCK: OnceLock>> = OnceLock::new(); - -pub(crate) async fn acquire_api_test_guard() -> OwnedMutexGuard<()> { - API_TEST_LOCK - .get_or_init(|| Arc::new(AsyncMutex::new(()))) - .clone() - .lock_owned() - .await -} - impl ClientState { pub fn new( pool: PgPool, @@ -90,7 +79,6 @@ pub(crate) async fn make_base_client( pool: PgPool, config: DefGuardConfig, listener: TcpListener, - test_guard: OwnedMutexGuard<()>, ) -> (TestClient, ClientState) { let (api_event_tx, api_event_rx) = unbounded_channel::(); let (tx, rx) = unbounded_channel::(); @@ -158,13 +146,12 @@ pub(crate) async fn make_base_client( ); ( - TestClient::new(webapp, listener, api_event_rx, test_guard), + TestClient::new(webapp, listener, api_event_rx), client_state, ) } pub(crate) async fn make_test_client(pool: PgPool) -> (TestClient, ClientState) { - let test_guard = acquire_api_test_guard().await; let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0); let listener = TcpListener::bind(addr) .await @@ -175,7 +162,7 @@ pub(crate) async fn make_test_client(pool: PgPool) -> (TestClient, ClientState) initialize_current_settings(&pool) .await .expect("Could not initialize settings"); - make_base_client(pool, config, listener, test_guard).await + make_base_client(pool, config, listener).await } pub(crate) async fn fetch_user_details(client: &TestClient, username: &str) -> UserDetails { From a46fcdf6d15026620e338eaea53191bd5360520d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 12:52:37 +0100 Subject: [PATCH 33/36] more unused test tooling --- crates/defguard_core/src/grpc/mod.rs | 23 +------------------ .../tests/integration/grpc/common/mod.rs | 2 +- 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 61b2638168..dea48b23b0 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -96,7 +96,7 @@ pub async fn run_grpc_server( Ok(()) } -pub(crate) async fn build_grpc_service_router( +pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, @@ -128,27 +128,6 @@ pub(crate) async fn build_grpc_service_router( Ok(router) } -#[doc(hidden)] -pub mod test_support { - use std::sync::{Arc, Mutex}; - - use sqlx::PgPool; - use tonic::transport::{Server, server::Router}; - - use crate::auth::failed_login::FailedLoginMap; - - use super::WorkerState; - - pub async fn build_grpc_service_router( - server: Server, - pool: PgPool, - worker_state: Arc>, - failed_logins: Arc>, - ) -> Result { - super::build_grpc_service_router(server, pool, worker_state, failed_logins).await - } -} - pub struct Job { id: u32, first_name: String, diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 40649370c7..3cd27f853b 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -7,7 +7,7 @@ use defguard_common::{ use defguard_core::{ auth::failed_login::FailedLoginMap, db::AppEvent, - grpc::{AUTHORIZATION_HEADER, WorkerState, test_support::build_grpc_service_router}, + grpc::{AUTHORIZATION_HEADER, WorkerState, build_grpc_service_router}, }; use hyper_util::rt::TokioIo; use sqlx::{ From 08d58092e63bd949c1bdb98f94af6ee47fab894d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 13:32:14 +0100 Subject: [PATCH 34/36] move test helper methods to a dedicated impl block --- .../defguard_gateway_manager/src/handler.rs | 113 +++++++++--------- 1 file changed, 55 insertions(+), 58 deletions(-) diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 691905df3f..9bd315a228 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -123,68 +123,11 @@ impl GatewayHandler { }) } - #[cfg(test)] - pub(crate) fn new_with_test_socket( - gateway: Gateway, - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, - certs_rx: watch::Receiver>>, - socket_path: PathBuf, - ) -> Result { - let mut handler = Self::new(gateway, pool, events_tx, peer_stats_tx, certs_rx)?; - handler.test_transport = GatewayTestTransport::with_socket_path(socket_path); - Ok(handler) - } - - #[cfg(test)] - pub(crate) fn attach_test_support(&mut self, test_support: GatewayManagerTestSupport) { - self.test_support = Some(test_support); - } - - #[cfg(test)] - fn note_handler_connection_attempt_for_tests(&self) { - if let Some(test_support) = &self.test_support { - test_support.note_handler_connection_attempt(self.gateway.id); - } - } - - #[cfg(not(test))] - fn note_handler_connection_attempt_for_tests(&self) {} - - #[cfg(test)] - fn handler_retry_delay(&self) -> std::time::Duration { - self.test_support - .as_ref() - .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay) - } - #[cfg(not(test))] fn handler_retry_delay(&self) -> std::time::Duration { TEN_SECS } - #[cfg(test)] - fn connect_channel( - &self, - endpoint: Endpoint, - ) -> Result { - if let Some(socket_path) = self.test_transport.socket_path().cloned() { - return Ok(endpoint.connect_with_connector_lazy(tower::service_fn( - move |_: tonic::transport::Uri| { - let socket_path = socket_path.clone(); - async move { - Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(socket_path).await?, - )) - } - }, - ))); - } - - self.connect_tls_channel(endpoint) - } - #[cfg(not(test))] fn connect_channel( &self, @@ -439,7 +382,10 @@ impl GatewayHandler { Version::parse(VERSION).expect("failed to parse self version"), ); let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); + + #[cfg(test)] self.note_handler_connection_attempt_for_tests(); + let (tx, rx) = mpsc::unbounded_channel(); let retry_delay = self.handler_retry_delay(); let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { @@ -578,8 +524,59 @@ impl GatewayHandler { .await?; } } +} + +#[cfg(test)] +impl GatewayHandler { + pub(crate) fn new_with_test_socket( + gateway: Gateway, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: watch::Receiver>>, + socket_path: PathBuf, + ) -> Result { + let mut handler = Self::new(gateway, pool, events_tx, peer_stats_tx, certs_rx)?; + handler.test_transport = GatewayTestTransport::with_socket_path(socket_path); + Ok(handler) + } + + pub(crate) fn attach_test_support(&mut self, test_support: GatewayManagerTestSupport) { + self.test_support = Some(test_support); + } + + fn note_handler_connection_attempt_for_tests(&self) { + if let Some(test_support) = &self.test_support { + test_support.note_handler_connection_attempt(self.gateway.id); + } + } + + fn handler_retry_delay(&self) -> std::time::Duration { + self.test_support + .as_ref() + .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay) + } + + fn connect_channel( + &self, + endpoint: Endpoint, + ) -> Result { + if let Some(socket_path) = self.test_transport.socket_path().cloned() { + return Ok(endpoint.connect_with_connector_lazy(tower::service_fn( + move |_: tonic::transport::Uri| { + let socket_path = socket_path.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(socket_path).await?, + )) + } + }, + ))); + } + + self.connect_tls_channel(endpoint) + } - #[cfg(test)] pub(crate) async fn handle_connection_once(&mut self) -> anyhow::Result<()> { let clients = Arc::>>::default(); self.handle_connection_iteration(clients, false) From e1751131113db70a6e1fe86edf95c6eded5f8c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 13:33:33 +0100 Subject: [PATCH 35/36] review fixes --- crates/defguard_gateway_manager/src/handler.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 9bd315a228..8be72a7020 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -9,7 +9,7 @@ use std::{ }; #[cfg(test)] -use std::path::PathBuf; +use std::{path::PathBuf, time::Duration}; use chrono::{DateTime, TimeDelta}; use defguard_common::{ @@ -551,7 +551,7 @@ impl GatewayHandler { } } - fn handler_retry_delay(&self) -> std::time::Duration { + fn handler_retry_delay(&self) -> Duration { self.test_support .as_ref() .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay) From 6b0bcc5d259c101a4f8db76d458942358f7c4117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Mon, 23 Mar 2026 18:39:34 +0100 Subject: [PATCH 36/36] add missing secret key initialization in grpc tests --- .../tests/integration/grpc/common/mod.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 3cd27f853b..c6d19a25a7 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -2,7 +2,13 @@ use std::sync::{Arc, Mutex}; use defguard_common::{ auth::claims::{Claims, ClaimsType}, - db::{models::settings::initialize_current_settings, setup_pool}, + db::{ + models::{ + Settings, + settings::{initialize_current_settings, update_current_settings}, + }, + setup_pool, + }, }; use defguard_core::{ auth::failed_login::FailedLoginMap, @@ -27,6 +33,9 @@ use tower::service_fn; use crate::common::initialize_users; +pub const TEST_SECRET_KEY: &str = + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; + pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, pub worker_state: Arc>, @@ -97,6 +106,13 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { .await .expect("failed to initialize current settings for gRPC tests"); + // set test secret for generating JWT tokens + let mut settings = Settings::get_current_settings(); + settings.secret_key = Some(TEST_SECRET_KEY.to_string()); + update_current_settings(pool, settings) + .await + .expect("Failed to update settings"); + let (client_stream, server_stream) = tokio::io::duplex(1024); let client_channel = create_client_channel(client_stream).await;