diff --git a/Cargo.lock b/Cargo.lock index 7f107e16f5..fcdc270ece 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1385,6 +1385,7 @@ dependencies = [ "defguard_web_ui", "futures", "humantime", + "hyper-util", "ipnetwork", "jsonwebkey", "jsonwebtoken", @@ -3085,9 +3086,9 @@ dependencies = [ [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", @@ -4680,9 +4681,9 @@ dependencies = [ [[package]] name = "pulldown-cmark" -version = "0.13.1" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6" +checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" dependencies = [ "bitflags 2.11.0", "getopts", @@ -4802,9 +4803,9 @@ dependencies = [ [[package]] name = "quoted_printable" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "640c9bd8497b02465aeef5375144c26062e0dcd5939dfcbb0f5db76cb8c17c73" +checksum = "478e0585659a122aa407eb7e3c0e1fa51b1d8a870038bd29f0cf4a8551eea972" [[package]] name = "r-efi" @@ -5508,9 +5509,9 @@ dependencies = [ [[package]] name = "serde_qs" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac22439301a0b6f45a037681518e3169e8db1db76080e2e9600a08d1027df037" +checksum = "3c742cd44662647326f86b514eadcc227fff4ce684dbbdaf1943f758d5ea058c" dependencies = [ "itoa", "percent-encoding", @@ -7950,9 +7951,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec5f41c76397b7da451efd19915684f727d7e1d516384ca6bd0ec43ec94de23c" 
+checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6" dependencies = [ "zune-core", ] diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index 40be058538..13a5cea975 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -230,7 +230,48 @@ impl DefGuardConfig { // this is an ugly workaround to avoid `cargo test` args being captured by `clap` #[must_use] pub fn new_test_config() -> Self { - Self::parse_from::<[_; 0], String>([]) + #[expect( + deprecated, + reason = "Test config still initializes compatibility-only deprecated fields" + )] + let config = Self { + log_level: "info".to_string(), + log_file: None, + secret_key: None, + database_host: "localhost".to_string(), + database_port: 5432, + database_name: "defguard".to_string(), + database_user: "defguard".to_string(), + database_password: SecretString::from(String::new()), + http_port: 8000, + grpc_port: 50055, + grpc_cert: None, + grpc_key: None, + openid_signing_key: None, + url: None, + disable_stats_purge: None, + stats_purge_frequency: None, + stats_purge_threshold: None, + enrollment_url: None, + enrollment_token_timeout: None, + mfa_code_timeout: None, + session_timeout: None, + password_reset_token_timeout: None, + enrollment_session_timeout: None, + password_reset_session_timeout: None, + cookie_domain: None, + cookie_insecure: false, + cmd: None, + check_period: std::time::Duration::from_secs(12 * 3600).into(), + check_period_no_license: std::time::Duration::from_secs(24 * 3600).into(), + check_period_renewal_window: std::time::Duration::from_secs(3600).into(), + http_bind_address: None, + grpc_bind_address: None, + adopt_gateway: None, + adopt_edge: None, + }; + + config } /// Validate that the auto-adoption flags are consistent. 
diff --git a/crates/defguard_common/src/db/models/gateway.rs b/crates/defguard_common/src/db/models/gateway.rs index cb32ed6864..79c095467c 100644 --- a/crates/defguard_common/src/db/models/gateway.rs +++ b/crates/defguard_common/src/db/models/gateway.rs @@ -34,6 +34,12 @@ impl Gateway { self.connected_at.is_some() } } + + /// Return address and port as URL with HTTP scheme. + #[must_use] + pub fn url(&self) -> String { + format!("http://{}:{}", self.address, self.port) + } } impl Gateway { @@ -168,12 +174,6 @@ impl Gateway { Ok(record) } - /// Return address and port as URL with HTTP scheme. - #[must_use] - pub fn url(&self) -> String { - format!("http://{}:{}", self.address, self.port) - } - /// Disable all Gateways except one. Used for expired licence. pub async fn leave_one_enabled<'e, E>(executor: E) -> sqlx::Result<()> where diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 6a4c176b61..119d4a8c7c 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -87,6 +87,7 @@ async-stream = "0.3" [dev-dependencies] claims.workspace = true +hyper-util = "0.1" matches.workspace = true reqwest = { version = "0.12", features = [ "cookies", diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 204b836717..dea48b23b0 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -96,7 +96,7 @@ pub async fn run_grpc_server( Ok(()) } -pub(crate) async fn build_grpc_service_router( +pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, @@ -114,6 +114,9 @@ pub(crate) async fn build_grpc_service_router( health_reporter .set_serving::>() .await; + health_reporter + .set_serving::>() + .await; let router = server .http2_keepalive_interval(Some(TEN_SECS)) diff --git a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs b/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs deleted 
file mode 100644 index 11bcdafbfd..0000000000 --- a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::time::Duration; - -use defguard_core::grpc::{AUTHORIZATION_HEADER, HOSTNAME_HEADER}; -use defguard_proto::gateway::{ - Configuration, ConfigurationRequest, Update, - -}; -use defguard_version::{Version, client::ClientVersionInterceptor}; -use tokio::{ - sync::mpsc::{UnboundedSender, unbounded_channel}, - task::JoinHandle, - time::timeout, -}; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tonic::{ - Request, Response, Status, Streaming, - metadata::MetadataValue, - service::{Interceptor, InterceptorLayer, interceptor::InterceptedService}, - transport::Channel, -}; -use tower::ServiceBuilder; - -pub(crate) struct MockGateway { - client: GatewayServiceClient< - InterceptedService, ClientVersionInterceptor>, - >, - hostname: Option, - stats_update_thread_handle: Option>, - updates_stream: Option>, -} - -impl Drop for MockGateway { - fn drop(&mut self) { - if let Some(handle) = &self.stats_update_thread_handle { - handle.abort(); - } - } -} - -#[derive(Clone)] -struct AuthInterceptor { - auth_token: Option, - hostname: Option, -} - -impl AuthInterceptor { - pub(crate) fn new(auth_token: Option, hostname: Option) -> Self { - Self { - auth_token, - hostname, - } - } -} - -impl Interceptor for AuthInterceptor { - fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { - // add authorization token - if let Some(token) = &self.auth_token { - request.metadata_mut().insert( - AUTHORIZATION_HEADER, - MetadataValue::try_from(token).expect("failed to convert token into metadata"), - ); - }; - - // add gateway hostname - if let Some(hostname) = &self.hostname { - request.metadata_mut().insert( - HOSTNAME_HEADER, - MetadataValue::try_from(hostname) - .expect("failed to convert hostname into metadata"), - ); - }; - - Ok(request) - } -} - -impl MockGateway { - #[must_use] - pub(crate) async fn 
new( - client_channel: Channel, - version: Version, - auth_token: Option, - hostname: Option, - ) -> Self { - let intercepted_channel = ServiceBuilder::new() - .layer(InterceptorLayer::new(ClientVersionInterceptor::new( - version, - ))) - .layer(InterceptorLayer::new(AuthInterceptor::new( - auth_token, - hostname.clone(), - ))) - .service(client_channel); - - let client = GatewayServiceClient::new(intercepted_channel); - - Self { - client, - hostname, - stats_update_thread_handle: None, - updates_stream: None, - } - } - - // Fetch gateway config from core - pub(crate) async fn get_gateway_config(&mut self) -> Result, Status> { - let request = Request::new(ConfigurationRequest { - name: self.hostname.clone(), - }); - - self.client.config(request).await - } - - pub(crate) async fn connect_to_updates_stream(&mut self) { - let request = Request::new(()); - - let updates_stream = self.client.updates(request).await.unwrap().into_inner(); - - self.updates_stream = Some(updates_stream); - } - - pub(crate) fn disconnect_from_updates_stream(&mut self) { - self.updates_stream = None; - } - - #[must_use] - pub(crate) async fn receive_next_update(&mut self) -> Option { - match &mut self.updates_stream { - Some(stream) => match timeout(Duration::from_millis(100), stream.message()).await { - Ok(result) => result.expect("failed to reveive update message"), - Err(_) => None, - }, - None => None, - } - } - - // Connect to interface stats update endpoint - // and return a tx which can be used to send stats updates to test gRPC server - #[must_use] - pub(crate) async fn setup_stats_update_stream(&mut self) -> UnboundedSender { - let (tx, rx) = unbounded_channel(); - - let request = Request::new(UnboundedReceiverStream::new(rx)); - - let mut client = self.client.clone(); - let task_handle = tokio::spawn(async move { - client.stats(request).await.expect("stats stream closed"); - }); - - self.stats_update_thread_handle = Some(task_handle); - - tx - } - - pub(crate) fn hostname(&self) -> 
String { - self.hostname.clone().unwrap_or_default() - } -} diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index d5d3794685..c6d19a25a7 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -1,45 +1,46 @@ use std::sync::{Arc, Mutex}; -use axum::http::Uri; use defguard_common::{ - db::models::settings::initialize_current_settings, messages::peer_stats_update::PeerStatsUpdate, + auth::claims::{Claims, ClaimsType}, + db::{ + models::{ + Settings, + settings::{initialize_current_settings, update_current_settings}, + }, + setup_pool, + }, }; use defguard_core::{ auth::failed_login::FailedLoginMap, db::AppEvent, - enterprise::license::{License, LicenseTier, set_cached_license}, - events::GrpcEvent, - grpc::{ - WorkerState, build_grpc_service_router, - gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, - }, + grpc::{AUTHORIZATION_HEADER, WorkerState, build_grpc_service_router}, }; -use defguard_mail::Mail; use hyper_util::rt::TokioIo; -use sqlx::PgPool; +use sqlx::{ + PgPool, + postgres::{PgConnectOptions, PgPoolOptions}, +}; use tokio::{ io::DuplexStream, - sync::{ - broadcast::{self, Sender}, - mpsc::{UnboundedReceiver, unbounded_channel}, - }, + sync::mpsc::{UnboundedReceiver, unbounded_channel}, task::JoinHandle, }; -use tonic::transport::{Channel, Endpoint, Server, server::Router}; +use tonic::{ + Request, + transport::{Channel, Endpoint, Server, Uri, server::Router}, +}; use tower::service_fn; -use crate::common::{init_config, initialize_users}; +use crate::common::initialize_users; -pub mod mock_gateway; +pub const TEST_SECRET_KEY: &str = + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, - pub grpc_event_rx: UnboundedReceiver, - wireguard_tx: Sender, - client_state: Arc>, + pub 
worker_state: Arc>, pub client_channel: Channel, - #[allow(dead_code)] - peer_stats_rx: UnboundedReceiver, + pub app_event_rx: UnboundedReceiver, } impl TestGrpcServer { @@ -47,13 +48,10 @@ impl TestGrpcServer { pub async fn new( server_stream: DuplexStream, grpc_router: Router, - grpc_event_rx: UnboundedReceiver, - wireguard_tx: Sender, - client_state: Arc>, + worker_state: Arc>, client_channel: Channel, - peer_stats_rx: UnboundedReceiver, + app_event_rx: UnboundedReceiver, ) -> Self { - // spawn test gRPC server let grpc_server_task_handle = tokio::spawn(async move { grpc_router .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server_stream))) @@ -64,103 +62,117 @@ impl TestGrpcServer { Self { grpc_server_task_handle, - grpc_event_rx, - wireguard_tx, - client_state, + worker_state, client_channel, - peer_stats_rx, + app_event_rx, } } - - pub fn get_client_map(&self) -> std::sync::MutexGuard<'_, ClientMap> { - self.client_state - .lock() - .expect("failed to acquire lock on client state") - } - - pub fn send_wireguard_event(&self, event: GatewayEvent) { - self.wireguard_tx - .send(event) - .expect("failed to send gateway event"); - } } impl Drop for TestGrpcServer { fn drop(&mut self) { - // explicitly stop spawned gRPC server task self.grpc_server_task_handle.abort(); } } +pub(crate) async fn setup_grpc_pool(_: PgPoolOptions, options: PgConnectOptions) -> PgPool { + setup_pool(options).await +} + pub(crate) async fn create_client_channel(client_stream: DuplexStream) -> Channel { - // Move client to an option so we can _move_ the inner value - // on the first attempt to connect. All other attempts will fail. 
- // reference: https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs#L31 let mut client = Some(client_stream); + let connector = service_fn(move |_: Uri| { + let client = client.take(); + + async move { + if let Some(client) = client { + Ok::<_, std::io::Error>(TokioIo::new(client)) + } else { + Err(std::io::Error::other("Client already taken")) + } + } + }); + Endpoint::try_from("http://[::]:50051") .expect("Failed to create channel") - .connect_with_connector(service_fn(move |_: Uri| { - let client = client.take(); - - async move { - if let Some(client) = client { - Ok(TokioIo::new(client)) - } else { - Err(std::io::Error::other("Client already taken")) - } - } - })) + .connect_with_connector(connector) .await .expect("Failed to create client channel") } pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { - // create communication channel for clients - let (client_stream, server_stream) = tokio::io::duplex(1024); - let client_channel = create_client_channel(client_stream).await; - - // setup helper structs - let (grpc_event_tx, grpc_event_rx) = unbounded_channel::(); - let (app_event_tx, _app_event_rx) = unbounded_channel::(); - let worker_state = Arc::new(Mutex::new(WorkerState::new(app_event_tx.clone()))); - let (wg_tx, _wg_rx) = broadcast::channel::(16); - let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); - let gateway_state = Arc::new(Mutex::new(GatewayMap::new())); - let client_state = Arc::new(Mutex::new(ClientMap::new())); - - let failed_logins = FailedLoginMap::new(); - let failed_logins = Arc::new(Mutex::new(failed_logins)); - initialize_users(pool).await; initialize_current_settings(pool) .await - .expect("Could not initialize settings"); - - let license = License::new( - "test_customer".to_string(), - false, - // Permanent license - None, - None, - None, - LicenseTier::Business, - ); + .expect("failed to initialize current settings for gRPC tests"); - set_cached_license(Some(license)); - let server 
= Server::builder(); - - let grpc_router = build_grpc_service_router(server, pool.clone(), worker_state, failed_logins) + // set test secret for generating JWT tokens + let mut settings = Settings::get_current_settings(); + settings.secret_key = Some(TEST_SECRET_KEY.to_string()); + update_current_settings(pool, settings) .await - .unwrap(); + .expect("Failed to update settings"); + + let (client_stream, server_stream) = tokio::io::duplex(1024); + let client_channel = create_client_channel(client_stream).await; + + let (app_event_tx, app_event_rx) = unbounded_channel::(); + let worker_state = Arc::new(Mutex::new(WorkerState::new(app_event_tx))); + let failed_logins = Arc::new(Mutex::new(FailedLoginMap::new())); + let grpc_router = build_grpc_service_router( + Server::builder(), + pool.clone(), + worker_state.clone(), + failed_logins, + ) + .await + .expect("failed to build gRPC router"); TestGrpcServer::new( server_stream, grpc_router, - grpc_event_rx, - wg_tx, - client_state, + worker_state, client_channel, - peer_stats_rx, + app_event_rx, ) .await } + +pub(crate) fn create_yubibridge_jwt(username: &str) -> String { + Claims::new( + ClaimsType::YubiBridge, + username.to_string(), + String::new(), + u32::MAX.into(), + ) + .to_jwt() + .expect("failed to generate YubiBridge token") +} + +pub(crate) fn create_gateway_jwt(username: &str, client_id: &str) -> String { + Claims::new( + ClaimsType::Gateway, + username.to_string(), + client_id.to_string(), + u32::MAX.into(), + ) + .to_jwt() + .expect("failed to generate gateway token") +} + +pub(crate) fn add_authorization_metadata(request: &mut Request, token: &str) { + request.metadata_mut().insert( + AUTHORIZATION_HEADER, + token.parse().expect("failed to encode authorization token"), + ); +} + +pub(crate) fn add_worker_auth_metadata(request: &mut Request, username: &str) { + add_authorization_metadata(request, &create_yubibridge_jwt(username)); +} + +pub(crate) fn worker_request(message: T, username: &str) -> Request { + 
let mut request = Request::new(message); + add_worker_auth_metadata(&mut request, username); + request +} diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs deleted file mode 100644 index 14d66342df..0000000000 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ /dev/null @@ -1,671 +0,0 @@ -use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - time::Duration, -}; - -use chrono::{Days, Utc}; -use claims::{assert_err_eq, assert_matches}; -use defguard_common::db::{ - Id, NoId, - models::{ - Device, DeviceType, User, WireguardNetwork, - wireguard::{LocationMfaMode, ServiceLocationMode}, - wireguard_peer_stats::WireguardPeerStats, - }, - setup_pool, -}; -use defguard_core::{ - enterprise::{license::set_cached_license, limits::update_counts}, - events::GrpcEvent, - grpc::{MIN_GATEWAY_VERSION, gateway::events::GatewayEvent}, -}; -use defguard_proto::{ - enterprise::firewall::FirewallPolicy, - gateway::{Configuration, PeerStats, Update, stats_update::Payload, update}, -}; -use semver::Version; -use sqlx::{ - PgPool, - postgres::{PgConnectOptions, PgPoolOptions}, -}; -use tokio::{sync::mpsc::error::TryRecvError, time::sleep}; -use tonic::Code; - -use crate::grpc::common::{TestGrpcServer, make_grpc_test_server, mock_gateway::MockGateway}; - -fn generate_gateway_token(location: &Location) -> String { - Claims::new( - ClaimsType::Gateway, - format!("DEFGUARD-NETWORK-{location_id}"), - location.id.to_string(), - u32::MAX.into(), - ) - .to_jwt() - .expect("failed to generate gateway token") -} -async fn setup_test_server( - pool: PgPool, -) -> (TestGrpcServer, MockGateway, WireguardNetwork, User) { - let test_server = make_grpc_test_server(&pool).await; - - // create a test location - let mut location = WireguardNetwork::new( - "test location".to_string(), - 1000, - "endpoint1".to_string(), - None, - Vec::new(), - false, - false, - false, - LocationMfaMode::Disabled, - 
ServiceLocationMode::Disabled, - ); - location.keepalive_interval = 100; - location.peer_disconnect_threshold = 100; - let location = location.save(&pool).await.unwrap(); - - // set auth token for gateway - let token = generate_gateway_token(&location); - - // setup mock gateway - let gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test gateway".into()), - ) - .await; - - // get test user - let test_user = User::find_by_username(&pool, "hpotter") - .await - .unwrap() - .unwrap(); - - (test_server, gateway, location, test_user) -} - -#[sqlx::test] -async fn test_gateway_authorization(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, _gateway, test_location, _test_user) = setup_test_server(pool).await; - - // setup another test gateway without a token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - None, - Some("test gateway".into()), - ) - .await; - - // make a request without auth token - let response = test_gateway.get_gateway_config().await; - - // check that response code is `Code::Unauthenticated` - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::Unauthenticated); - - // setup another test gateway with an invalid token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some("invalid_token".into()), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::Unauthenticated); - - // use valid token and retry - let token = generate_gateway_token(&test_location); - // setup another test gateway without a token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - 
Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - assert!(response.is_ok()); -} - -#[sqlx::test] -async fn test_gateway_hostname_is_required(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, _gateway, test_location, _test_user) = setup_test_server(pool).await; - - // setup gateway without hostname - let token = generate_gateway_token(&test_location); - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token.clone()), - None, - ) - .await; - - // make a request without hostname - let response = test_gateway.get_gateway_config().await; - - // check that response code is `Code::Internal` - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::Internal); - - // set hostname and retry - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - assert!(response.is_ok()); -} - -#[sqlx::test] -async fn test_gateway_status(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, mut gateway, test_location, _test_user) = setup_test_server(pool).await; - - // initial gateway map is empty - { - let gateway_map = test_server.get_gateway_map(); - assert!(gateway_map.is_empty()) - } - - // gateway request initial config - // it should be added to status map as disconnected - let response = gateway.get_gateway_config().await; - assert!(response.is_ok()); - { - let gateway_map = test_server.get_gateway_map(); - let location_gateways = gateway_map.get_network_gateway_status(test_location.id); - assert_eq!(location_gateways.len(), 1); - let gateway_state = location_gateways.first().unwrap(); - assert!(!gateway_state.connected); - 
assert!(gateway_state.connected_at.is_none()); - assert!(gateway_state.disconnected_at.is_none()); - assert_eq!(gateway_state.hostname, gateway.hostname()); - } - - // gateway connects to updates stream - // it should be marked as connected - gateway.connect_to_updates_stream().await; - { - let gateway_map = test_server.get_gateway_map(); - let location_gateways = gateway_map.get_network_gateway_status(test_location.id); - assert_eq!(location_gateways.len(), 1); - let gateway_state = location_gateways.first().unwrap(); - assert!(gateway_state.connected); - assert!(gateway_state.connected_at.is_some()); - assert!(gateway_state.disconnected_at.is_none()); - assert_eq!(gateway_state.hostname, gateway.hostname()); - } - - // gateway disconnect from updates stream - // it should be marked as disconnected - gateway.disconnect_from_updates_stream(); - // wait for the background thread to handle the disconnect - sleep(Duration::from_millis(100)).await; - - { - let gateway_map = test_server.get_gateway_map(); - let location_gateways = gateway_map.get_network_gateway_status(test_location.id); - assert_eq!(location_gateways.len(), 1); - let gateway_state = location_gateways.first().unwrap(); - assert!(!gateway_state.connected); - assert!(gateway_state.connected_at.is_some()); - assert!(gateway_state.disconnected_at.is_some()); - assert_eq!(gateway_state.hostname, gateway.hostname()); - } -} - -#[sqlx::test] -async fn test_vpn_client_connected(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (mut test_server, mut gateway, test_location, test_user) = - setup_test_server(pool.clone()).await; - - // initial client map is empty - { - let client_map = test_server.get_client_map(); - assert!(client_map.is_empty()) - } - - // connect stats stream - let stats_tx = gateway.setup_stats_update_stream().await; - let mut update_id = 1; - - // add user device - let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg="; - let test_device = 
Device::new( - "test device".into(), - device_pubkey.into(), - test_user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // send stats update for existing device with old handshake - // and verify no gRPC event is emitted - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: 0, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - assert_err_eq!(test_server.grpc_event_rx.try_recv(), TryRecvError::Empty); - - // send stats update with current handshake - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event to be emitted - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - - assert_matches!( - grpc_event, - GrpcEvent::ClientConnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); -} - -#[sqlx::test] -async fn test_vpn_client_disconnected(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (mut test_server, mut gateway, test_location, test_user) = - setup_test_server(pool.clone()).await; - - // add user device - let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg="; - let test_device = Device::new( - "test device".into(), - device_pubkey.into(), - test_user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // insert device into client map with an old handshake - { - let mut client_map = test_server.get_client_map(); - let now = 
Utc::now().naive_utc(); - let stats = WireguardPeerStats { - id: NoId, - device_id: test_device.id, - collected_at: now, - network: test_location.id, - endpoint: None, - upload: 0, - download: 0, - latest_handshake: now.checked_sub_days(Days::new(1)).unwrap(), - allowed_ips: None, - }; - client_map - .connect_vpn_client( - test_location.id, - &gateway.hostname(), - device_pubkey, - &test_device, - &test_user, - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080), - &stats, - ) - .expect("failed to insert connected client"); - } - - // connect stats stream - let stats_tx = gateway.setup_stats_update_stream().await; - let mut update_id = 1; - - // send stats update with old handshake - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: 0, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event to be emitted - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - - assert_matches!( - grpc_event, - GrpcEvent::ClientDisconnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); -} - -#[sqlx::test] -async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, mut gateway_1, test_location, _test_user) = - setup_test_server(pool.clone()).await; - - // setup another test location & gateway - let mut test_location_2 = WireguardNetwork::new( - "test location 2".to_string(), - 1000, - "endpoint2".to_string(), - None, - Vec::new(), - false, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ); - test_location_2.keepalive_interval = 100; - test_location_2.peer_disconnect_threshold = 100; - let test_location_2 = 
test_location_2.save(&pool).await.unwrap(); - - // set auth token for gateway - let token = generate_gateway_token(&test_location_2); - let mut gateway_2 = MockGateway::new( - test_server.client_channel.clone(), - MIN_GATEWAY_VERSION, - Some(token), - Some("test_gateway_2".into()), - ) - .await; - - // register gateways with core - let _config_1 = gateway_1.get_gateway_config().await; - let _config_2 = gateway_2.get_gateway_config().await; - - // connect gateways to the updates stream - gateway_1.connect_to_updates_stream().await; - gateway_2.connect_to_updates_stream().await; - - // send update for location 1 - test_server.send_wireguard_event(GatewayEvent::NetworkDeleted( - test_location.id, - "network name".into(), - )); - - // only one gateway should receive this update - assert!(gateway_2.receive_next_update().await.is_none()); - let update = gateway_1.receive_next_update().await.unwrap(); - let expected_update = Update { - update_type: 2, - update: Some(update::Update::Network(Configuration { - name: "network name".into(), - prvkey: String::new(), - addresses: Vec::new(), - port: 0, - peers: Vec::new(), - firewall_config: None, - })), - }; - assert_eq!(update, expected_update); - - // send update for location 2 - test_server.send_wireguard_event(GatewayEvent::NetworkDeleted( - test_location_2.id, - "network name 2".into(), - )); - - // only one gateway should receive this update - assert!(gateway_1.receive_next_update().await.is_none()); - let update = gateway_2.receive_next_update().await.unwrap(); - let expected_update = Update { - update_type: 2, - update: Some(update::Update::Network(Configuration { - name: "network name 2".into(), - prvkey: String::new(), - addresses: Vec::new(), - port: 0, - peers: Vec::new(), - firewall_config: None, - })), - }; - assert_eq!(update, expected_update); - - // send update for location which does not exist - test_server.send_wireguard_event(GatewayEvent::NetworkDeleted(1234, "does not exist".into())); - - // no gateway 
should receive this update - assert!(gateway_1.receive_next_update().await.is_none()); - assert!(gateway_2.receive_next_update().await.is_none()); -} - -#[sqlx::test] -async fn test_gateway_config(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (_test_server, mut gateway, mut test_location, _test_user) = - setup_test_server(pool.clone()).await; - - // get gateway config - let config = gateway.get_gateway_config().await.unwrap().into_inner(); - - assert_eq!(config.name, test_location.name); - assert!(config.firewall_config.is_none()); - - // enable ACL for test location - test_location.acl_enabled = true; - test_location - .save(&pool) - .await - .expect("failed to update location"); - - // get gateway config - let config = gateway.get_gateway_config().await.unwrap().into_inner(); - assert!(config.firewall_config.is_some()); - assert_eq!( - config.firewall_config.unwrap().default_policy == i32::from(FirewallPolicy::Allow), - test_location.acl_default_allow - ); - - // unset the license and create another location to exceed limits and disable enterprise features - set_cached_license(None); - let mut test_location_2 = WireguardNetwork::new( - "test location 2".to_string(), - 1000, - "endpoint2".to_string(), - None, - Vec::new(), - false, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ); - test_location_2.keepalive_interval = 100; - test_location_2.peer_disconnect_threshold = 100; - let _test_location_2 = test_location_2.save(&pool).await.unwrap(); - update_counts(&pool).await.unwrap(); - - let config = gateway.get_gateway_config().await.unwrap().into_inner(); - assert!(config.firewall_config.is_none()); -} - -#[sqlx::test] -async fn test_gateway_version_validation(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (test_server, _gateway, test_location, _test_user) = setup_test_server(pool.clone()).await; - - // setup gateway with unsupported 
version - let unsupported_version = - Version::new(MIN_GATEWAY_VERSION.major, MIN_GATEWAY_VERSION.minor - 1, 0); - let token = generate_gateway_token(&test_location); - // setup another test gateway without a token - let mut test_gateway = MockGateway::new( - test_server.client_channel.clone(), - unsupported_version, - Some(token), - Some("test gateway".into()), - ) - .await; - let response = test_gateway.get_gateway_config().await; - - // check that response code is `Code::FailedPrecondition` - assert!(response.is_err()); - let status = response.err().unwrap(); - assert_eq!(status.code(), Code::FailedPrecondition); -} - -// https://github.com/DefGuard/defguard/issues/1671 -#[sqlx::test] -async fn test_device_pubkey_change(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let (mut test_server, mut gateway, test_location, test_user) = - setup_test_server(pool.clone()).await; - - // initial client map is empty - { - let client_map = test_server.get_client_map(); - assert!(client_map.is_empty()) - } - - // connect stats stream - let stats_tx = gateway.setup_stats_update_stream().await; - let mut update_id = 1; - - // add user device - let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg="; - let mut test_device = Device::new( - "test device".into(), - device_pubkey.into(), - test_user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // send stats update for existing device - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event to be emitted - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - assert_matches!( - grpc_event, - 
GrpcEvent::ClientConnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); - - // change device pubkey - let new_device_pubkey = "TJG2T6rhndZtk06KnIIOlD6hhd7wpVkBss8sfyvMCAA="; - test_device.wireguard_pubkey = new_device_pubkey.to_owned(); - test_device.save(&pool).await.unwrap(); - - // send stats update with old pubkey - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // no event should be emitted - sleep(Duration::from_millis(100)).await; - assert_err_eq!(test_server.grpc_event_rx.try_recv(), TryRecvError::Empty); - - // send stats update with new pubkey - update_id += 1; - stats_tx - .send(StatsUpdate { - id: update_id, - payload: Some(Payload::PeerStats(PeerStats { - public_key: new_device_pubkey.into(), - endpoint: "1.2.3.4:1234".into(), - latest_handshake: Utc::now().timestamp() as u64, - ..Default::default() - })), - }) - .expect("failed to send stats update"); - - // wait for event - // FIXME: ideally this should not be emitted; we'll fix it once we implement a more robust VPN session logic - sleep(Duration::from_millis(100)).await; - let grpc_event = test_server - .grpc_event_rx - .try_recv() - .expect("failed to receive gRPC event"); - - assert_matches!( - grpc_event, - GrpcEvent::ClientConnected { - context: _, - location, - device - } if ((location.id == test_location.id) & (device.id == test_device.id)) - ); -} diff --git a/crates/defguard_core/tests/integration/grpc/health.rs b/crates/defguard_core/tests/integration/grpc/health.rs new file mode 100644 index 0000000000..c2e1e5a348 --- /dev/null +++ b/crates/defguard_core/tests/integration/grpc/health.rs @@ -0,0 +1,23 @@ +use sqlx::postgres::{PgConnectOptions, 
PgPoolOptions}; +use tonic_health::pb::{ + HealthCheckRequest, health_check_response::ServingStatus, health_client::HealthClient, +}; + +use super::common::{make_grpc_test_server, setup_grpc_pool}; + +#[sqlx::test] +async fn worker_service_health_is_serving(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = HealthClient::new(server.client_channel.clone()); + + let response = client + .check(HealthCheckRequest { + service: "worker.WorkerService".into(), + }) + .await + .expect("health check should succeed") + .into_inner(); + + assert_eq!(response.status, ServingStatus::Serving as i32); +} diff --git a/crates/defguard_core/tests/integration/grpc/mod.rs b/crates/defguard_core/tests/integration/grpc/mod.rs index 5b53a1b0d6..d8d759d58f 100644 --- a/crates/defguard_core/tests/integration/grpc/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/mod.rs @@ -1,2 +1,3 @@ mod common; -mod gateway; +mod health; +mod worker; diff --git a/crates/defguard_core/tests/integration/grpc/worker.rs b/crates/defguard_core/tests/integration/grpc/worker.rs new file mode 100644 index 0000000000..0a319c71b4 --- /dev/null +++ b/crates/defguard_core/tests/integration/grpc/worker.rs @@ -0,0 +1,428 @@ +use claims::assert_matches; +use defguard_common::db::models::{AuthenticationKey, AuthenticationKeyType, User, YubiKey}; +use defguard_core::db::AppEvent; +use defguard_proto::worker::{JobStatus, Worker, worker_service_client::WorkerServiceClient}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::sync::mpsc::error::TryRecvError; +use tonic::Code; + +use super::common::{ + add_authorization_metadata, create_gateway_jwt, make_grpc_test_server, setup_grpc_pool, + worker_request, +}; + +#[sqlx::test] +async fn register_worker_success_and_duplicate(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), 
options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + let response = client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await; + + assert!(response.is_ok()); + + let status = client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect_err("duplicate worker should fail"); + + assert_eq!(status.code(), Code::AlreadyExists); + assert_eq!(status.message(), "Worker already registered"); +} + +#[sqlx::test] +async fn get_job_returns_not_found_for_unknown_worker(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + let status = client + .get_job(worker_request( + Worker { + id: "missing-worker".into(), + }, + "admin", + )) + .await + .expect_err("missing worker should not have jobs"); + + assert_eq!(status.code(), Code::NotFound); + assert_eq!(status.message(), "No more jobs"); +} + +#[sqlx::test] +async fn get_job_returns_not_found_for_registered_worker_without_jobs( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let status = client + .get_job(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect_err("worker without jobs should get not found"); + + assert_eq!(status.code(), Code::NotFound); + assert_eq!(status.message(), "No more jobs"); +} + +#[sqlx::test] +async fn 
get_job_returns_seeded_payload(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let job_id = { + let mut state = server.worker_state.lock().unwrap(); + state.create_job( + "worker-1", + "Minerva".into(), + "McGonagall".into(), + "minerva@hogwart.edu.uk".into(), + "hpotter".into(), + ) + }; + + let response = client + .get_job(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("seeded job should be returned") + .into_inner(); + + assert_eq!(response.job_id, job_id); + assert_eq!(response.first_name, "Minerva"); + assert_eq!(response.last_name, "McGonagall"); + assert_eq!(response.email, "minerva@hogwart.edu.uk"); +} + +#[sqlx::test] +async fn set_job_done_success_removes_job_and_stores_status( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let mut server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let job_id = { + let mut state = server.worker_state.lock().unwrap(); + state.create_job( + "worker-1", + "Harry".into(), + "Potter".into(), + "h.potter@hogwart.edu.uk".into(), + "hpotter".into(), + ) + }; + + client + .set_job_done(worker_request( + JobStatus { + id: "worker-1".into(), + job_id, + success: true, + public_key: "gpg-public-key".into(), + ssh_key: "ssh-public-key".into(), + yubikey_serial: "yk-serial-1".into(), + error: String::new(), + }, + "admin", + )) + 
.await + .expect("job completion should succeed"); + + { + let mut state = server.worker_state.lock().unwrap(); + let status = state + .get_job_status(job_id) + .expect("job status should be recorded"); + assert!(status.success); + assert_eq!(status.serial, "yk-serial-1"); + assert_eq!(status.error, ""); + assert!( + state + .get_job("worker-1", std::net::IpAddr::from([127, 0, 0, 1])) + .is_none() + ); + } + + let user = User::find_by_username(&pool, "hpotter") + .await + .expect("user query should succeed") + .expect("user should exist"); + let yubikeys = YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed"); + let auth_keys = AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed"); + + assert_eq!(yubikeys.len(), 1); + assert_eq!(yubikeys[0].serial, "yk-serial-1"); + assert_eq!(auth_keys.len(), 2); + assert!(auth_keys.iter().any(|key| { + key.key_type == AuthenticationKeyType::Ssh + && key.key == "ssh-public-key" + && key.yubikey_id == Some(yubikeys[0].id) + })); + assert!(auth_keys.iter().any(|key| { + key.key_type == AuthenticationKeyType::Gpg + && key.key == "gpg-public-key" + && key.yubikey_id == Some(yubikeys[0].id) + })); + + let event = server + .app_event_rx + .try_recv() + .expect("success should emit an app event"); + assert_matches!( + event, + AppEvent::HWKeyProvision(data) + if data.username == "hpotter" + && data.email == "h.potter@hogwart.edu.uk" + && data.ssh_key == "ssh-public-key" + && data.pgp_key == "gpg-public-key" + && data.serial == "yk-serial-1" + ); + assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); +} + +#[sqlx::test] +async fn set_job_done_failure_stores_status_without_keys_or_event( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let mut server = make_grpc_test_server(&pool).await; + let mut client = 
WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + let job_id = { + let mut state = server.worker_state.lock().unwrap(); + state.create_job( + "worker-1", + "Harry".into(), + "Potter".into(), + "h.potter@hogwart.edu.uk".into(), + "hpotter".into(), + ) + }; + + client + .set_job_done(worker_request( + JobStatus { + id: "worker-1".into(), + job_id, + success: false, + public_key: "gpg-public-key".into(), + ssh_key: "ssh-public-key".into(), + yubikey_serial: "yk-serial-1".into(), + error: "worker failed".into(), + }, + "admin", + )) + .await + .expect("failed completion should still return ok"); + + { + let state = server.worker_state.lock().unwrap(); + let status = state + .get_job_status(job_id) + .expect("failure status should be recorded"); + assert!(!status.success); + assert_eq!(status.error, "worker failed"); + } + + let user = User::find_by_username(&pool, "hpotter") + .await + .expect("user query should succeed") + .expect("user should exist"); + assert!( + YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed") + .is_empty() + ); + assert!( + AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed") + .is_empty() + ); + assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); +} + +#[sqlx::test] +async fn set_job_done_unknown_job_is_ignored(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let mut server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + client + .register_worker(worker_request( + Worker { + id: "worker-1".into(), + }, + "admin", + )) + .await + .expect("worker registration should succeed"); + + client + .set_job_done(worker_request( + 
JobStatus { + id: "worker-1".into(), + job_id: 999, + success: true, + public_key: "gpg-public-key".into(), + ssh_key: "ssh-public-key".into(), + yubikey_serial: "yk-serial-1".into(), + error: String::new(), + }, + "admin", + )) + .await + .expect("unknown jobs should be ignored"); + + { + let state = server.worker_state.lock().unwrap(); + assert!(state.get_job_status(999).is_none()); + } + + let user = User::find_by_username(&pool, "hpotter") + .await + .expect("user query should succeed") + .expect("user should exist"); + assert!( + YubiKey::find_by_user_id(&pool, user.id) + .await + .expect("yubikey query should succeed") + .is_empty() + ); + assert!( + AuthenticationKey::find_by_user_id(&pool, user.id, None) + .await + .expect("auth key query should succeed") + .is_empty() + ); + assert_matches!(server.app_event_rx.try_recv(), Err(TryRecvError::Empty)); +} + +#[sqlx::test] +async fn worker_interceptor_requires_valid_yubibridge_token( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_grpc_pool(PgPoolOptions::new(), options).await; + let server = make_grpc_test_server(&pool).await; + let mut client = WorkerServiceClient::new(server.client_channel.clone()); + + let missing_auth = client + .register_worker(tonic::Request::new(Worker { + id: "worker-missing-auth".into(), + })) + .await + .expect_err("missing auth should fail"); + assert_eq!(missing_auth.code(), Code::Unauthenticated); + assert_eq!(missing_auth.message(), "Missing authorization header"); + + let mut invalid_request = tonic::Request::new(Worker { + id: "worker-invalid-token".into(), + }); + add_authorization_metadata(&mut invalid_request, "not-a-jwt"); + let invalid_token = client + .register_worker(invalid_request) + .await + .expect_err("invalid token should fail"); + assert_eq!(invalid_token.code(), Code::Unauthenticated); + assert_eq!(invalid_token.message(), "Invalid token"); + + let mut wrong_claims_request = tonic::Request::new(Worker { + id: "worker-wrong-claims".into(), 
+ }); + add_authorization_metadata( + &mut wrong_claims_request, + &create_gateway_jwt("admin", "gateway-network-1"), + ); + let wrong_claims = client + .register_worker(wrong_claims_request) + .await + .expect_err("wrong claims type should fail"); + assert_eq!(wrong_claims.code(), Code::Unauthenticated); + assert_eq!(wrong_claims.message(), "Invalid token"); + + let allowed = client + .get_job(worker_request( + Worker { + id: "worker-valid-token".into(), + }, + "admin", + )) + .await + .expect_err("valid token should reach service logic"); + assert_eq!(allowed.code(), Code::NotFound); + assert_eq!(allowed.message(), "No more jobs"); +} diff --git a/crates/defguard_core/tests/integration/main.rs b/crates/defguard_core/tests/integration/main.rs index 43ffcc1303..f0ca4e17ed 100644 --- a/crates/defguard_core/tests/integration/main.rs +++ b/crates/defguard_core/tests/integration/main.rs @@ -1,4 +1,4 @@ mod api; mod common; -// mod grpc; +mod grpc; mod ldap; diff --git a/crates/defguard_gateway_manager/src/certs.rs b/crates/defguard_gateway_manager/src/certs.rs index a1daf0b53a..5e2367e179 100644 --- a/crates/defguard_gateway_manager/src/certs.rs +++ b/crates/defguard_gateway_manager/src/certs.rs @@ -31,3 +31,152 @@ pub(super) async fn refresh_certs(pool: &PgPool, tx: &watch::Sender, + ) -> Gateway { + let mut gateway = Gateway::new( + location_id, + name.to_string(), + "127.0.0.1".to_string(), + 51820, + "test-admin".to_string(), + ); + gateway.certificate = certificate.map(str::to_owned); + + gateway + .save(pool) + .await + .expect("failed to create gateway for cert refresh tests") + } +} diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index 45dd490ada..8be72a7020 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -8,6 +8,9 @@ use std::{ }, }; +#[cfg(test)] +use std::{path::PathBuf, time::Duration}; + use chrono::{DateTime, TimeDelta}; use 
defguard_common::{ VERSION, @@ -26,7 +29,6 @@ use defguard_core::{ handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, location_management::allowed_peers::get_location_allowed_peers, }; -#[cfg(not(test))] use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::{ enterprise::firewall::FirewallConfig, @@ -36,7 +38,6 @@ use defguard_proto::{ }, }; use defguard_version::client::ClientVersionInterceptor; -#[cfg(not(test))] use hyper_rustls::HttpsConnectorBuilder; use reqwest::Url; use semver::Version; @@ -54,8 +55,30 @@ use tonic::{Code, Status, transport::Endpoint}; use crate::{Client, TEN_SECS, error::GatewayError}; +#[cfg(test)] +use crate::GatewayManagerTestSupport; + +#[cfg(test)] +#[derive(Debug, Default)] +struct GatewayTestTransport { + socket_path: Option, +} + +#[cfg(test)] +impl GatewayTestTransport { + fn with_socket_path(socket_path: PathBuf) -> Self { + Self { + socket_path: Some(socket_path), + } + } + + fn socket_path(&self) -> Option<&PathBuf> { + self.socket_path.as_ref() + } +} + /// One instance per connected Gateway. -pub(super) struct GatewayHandler { +pub(crate) struct GatewayHandler { // Gateway server endpoint URL. 
url: Url, gateway: Gateway, @@ -64,6 +87,10 @@ pub(super) struct GatewayHandler { events_tx: Sender, peer_stats_tx: UnboundedSender, certs_rx: watch::Receiver>>, + #[cfg(test)] + test_transport: GatewayTestTransport, + #[cfg(test)] + test_support: Option, } impl GatewayHandler { @@ -89,9 +116,49 @@ impl GatewayHandler { events_tx, peer_stats_tx, certs_rx, + #[cfg(test)] + test_transport: GatewayTestTransport::default(), + #[cfg(test)] + test_support: None, }) } + #[cfg(not(test))] + fn handler_retry_delay(&self) -> std::time::Duration { + TEN_SECS + } + + #[cfg(not(test))] + fn connect_channel( + &self, + endpoint: Endpoint, + ) -> Result { + self.connect_tls_channel(endpoint) + } + + fn connect_tls_channel( + &self, + endpoint: Endpoint, + ) -> Result { + let settings = Settings::get_current_settings(); + let Some(ca_cert_der) = settings.ca_cert_der else { + return Err(GatewayError::EndpointError( + "Core CA is not setup, can't create a Gateway endpoint.".to_string(), + )); + }; + let tls_config = + tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) + .map_err(|err| GatewayError::EndpointError(err.to_string()))?; + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http2() + .build(); + let connector = HttpsSchemeConnector::new(connector); + + Ok(endpoint.connect_with_connector_lazy(connector)) + } + fn endpoint(&self) -> Result { let mut url = self.url.clone(); @@ -247,179 +314,275 @@ impl GatewayHandler { }); } - /// Connect to Gateway and handle its messages through gRPC. 
- pub(super) async fn handle_connection( + async fn mark_disconnected(&mut self) { + if let Err(err) = self.gateway.touch_disconnected(&self.pool).await { + error!( + "Failed to update disconnection time for {} in the database: {err}", + self.gateway + ); + } + } + + async fn handle_disconnection_error(&mut self) { + if self.gateway.is_connected() { + self.send_disconnect_notification().await; + } + + self.mark_disconnected().await; + } + + async fn mark_connected_and_maybe_notify(&mut self, network_name: &str) { + if let Err(err) = self.gateway.touch_connected(&self.pool).await { + error!( + "Failed to update connection time for {} in the database: {err}", + self.gateway + ); + return; + } + + self.send_reconnect_notification(network_name.to_owned()); + } + + fn remove_client(&self, clients: &Arc>>) { + clients + .lock() + .expect("GatewayHandler failed to lock clients") + .remove(&self.gateway.id); + } + + async fn handle_stream_disconnection( + &mut self, + clients: &Arc>>, + retry_on_connect_failure: bool, + retry_delay: std::time::Duration, + ) { + self.remove_client(clients); + self.handle_disconnection_error().await; + + if !retry_on_connect_failure { + return; + } + + debug!("Waiting {retry_delay:?} to re-establish the connection"); + sleep(retry_delay).await; + } + + async fn handle_connection_iteration( &mut self, clients: Arc>>, + retry_on_connect_failure: bool, ) -> Result<(), GatewayError> { - #[cfg(test)] - let _ = &self.certs_rx; let endpoint = self.endpoint()?; let uri = endpoint.uri().to_string(); - loop { - #[cfg(not(test))] - let channel = { - let settings = Settings::get_current_settings(); - let Some(ca_cert_der) = settings.ca_cert_der else { - return Err(GatewayError::EndpointError( - "Core CA is not setup, can't create a Gateway endpoint.".to_string(), - )); - }; - let tls_config = - tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) - .map_err(|err| GatewayError::EndpointError(err.to_string()))?; - let connector = 
HttpsConnectorBuilder::new() - .with_tls_config(tls_config) - .https_only() - .enable_http2() - .build(); - let connector = HttpsSchemeConnector::new(connector); - endpoint.connect_with_connector_lazy(connector) - }; - #[cfg(test)] - let channel = endpoint.connect_with_connector_lazy(tower::service_fn( - |_: tonic::transport::Uri| async { - Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(super::TONIC_SOCKET).await?, - )) - }, - )); - debug!("Connecting to Gateway {uri}"); - let interceptor = ClientVersionInterceptor::new( - Version::parse(VERSION).expect("failed to parse self version"), - ); - let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); - clients - .lock() - .expect("GatewayHandler failed to lock clients") - .insert(self.gateway.id, client.clone()); - let (tx, rx) = mpsc::unbounded_channel(); - let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { - Ok(response) => response, - Err(err) => { - error!("Failed to connect to Gateway {uri}, retrying: {err}"); - sleep(TEN_SECS).await; - continue; - } - }; - info!("Connected to Defguard Gateway {uri}"); + let channel = self.connect_channel(endpoint)?; - let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); - let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); + debug!("Connecting to Gateway {uri}"); + let interceptor = ClientVersionInterceptor::new( + Version::parse(VERSION).expect("failed to parse self version"), + ); + let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); + + #[cfg(test)] + self.note_handler_connection_attempt_for_tests(); + + let (tx, rx) = mpsc::unbounded_channel(); + let retry_delay = self.handler_retry_delay(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + error!("Failed to connect to Gateway {uri}, retrying: {err}"); + if 
retry_on_connect_failure { + sleep(retry_delay).await; + return Ok(()); + } - if let Some(mut gateway) = Gateway::find_by_id(&self.pool, self.gateway.id).await? { - gateway.version = Some(version.to_string()); - gateway.save(&self.pool).await?; + return Err(err.into()); } + }; + let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); + let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); - let mut resp_stream = response.into_inner(); - let mut config_sent = false; + if let Some(mut gateway) = Gateway::find_by_id(&self.pool, self.gateway.id).await? { + gateway.version = Some(version.to_string()); + gateway.save(&self.pool).await?; + } - 'message: loop { - match resp_stream.message().await { - Ok(None) => { - info!("Stream was closed by the sender."); - break 'message; - } - Ok(Some(received)) => { - info!("Received message from Gateway."); - debug!("Message from Gateway {uri}"); + clients + .lock() + .expect("GatewayHandler failed to lock clients") + .insert(self.gateway.id, client.clone()); + info!("Connected to Defguard Gateway {uri}"); - match received.payload { - Some(core_request::Payload::ConfigRequest(_config_request)) => { - if config_sent { - warn!( - "Ignoring repeated configuration request from {}", - self.gateway + let mut resp_stream = response.into_inner(); + let mut config_sent = false; + + loop { + match resp_stream.message().await { + Ok(None) => { + info!("Stream was closed by the sender."); + self.handle_stream_disconnection( + &clients, + retry_on_connect_failure, + retry_delay, + ) + .await; + return Ok(()); + } + Ok(Some(received)) => { + info!("Received message from Gateway."); + debug!("Message from Gateway {uri}"); + + match received.payload { + Some(core_request::Payload::ConfigRequest(_config_request)) => { + if config_sent { + warn!( + "Ignoring repeated configuration request from {}", + self.gateway + ); + continue; + } + + match self.send_configuration(&tx).await { + Ok(network) => { 
+ info!("Sent configuration to {}", self.gateway); + config_sent = true; + self.mark_connected_and_maybe_notify(&network.name).await; + let mut updates_handler = GatewayUpdatesHandler::new( + self.gateway.location_id, + network, + self.gateway.name.clone(), + self.events_tx.subscribe(), + tx.clone(), ); - continue; + tokio::spawn(async move { + updates_handler.run().await; + }); } - - // Send network configuration to Gateway. - match self.send_configuration(&tx).await { - Ok(network) => { - info!("Sent configuration to {}", self.gateway); - config_sent = true; - match self.gateway.touch_connected(&self.pool).await { - Ok(()) => { - self.send_reconnect_notification( - network.name.clone(), - ); - } - Err(err) => { - error!( - "Failed to update connection time for {} in the database: {err}", - self.gateway - ); - } - } - let mut updates_handler = GatewayUpdatesHandler::new( - self.gateway.location_id, - network, - self.gateway.name.clone(), - self.events_tx.subscribe(), - tx.clone(), - ); - tokio::spawn(async move { - updates_handler.run().await; - }); - } - Err(err) => { - error!( - "Failed to send configuration to {}: {err}", - self.gateway - ); - } + Err(err) => { + error!( + "Failed to send configuration to {}: {err}", + self.gateway + ); } } - Some(core_request::Payload::PeerStats(peer_stats)) => { - if !config_sent { + } + Some(core_request::Payload::PeerStats(peer_stats)) => { + if !config_sent { + warn!( + "Ignoring peer statistics from {} because it hasn't \ + authorized itself", + self.gateway + ); + continue; + } + + match try_protos_into_stats_message( + peer_stats.clone(), + self.gateway.location_id, + self.gateway.id, + ) { + None => { warn!( - "Ignoring peer statistics from {} because it hasn't \ - authorized itself", - self.gateway + "Failed to parse peer stats update. Skipping sending \ + message to session manager." 
); - continue; } - - // convert stats to DB storage format - match try_protos_into_stats_message( - peer_stats.clone(), - self.gateway.location_id, - self.gateway.id, - ) { - None => { - warn!( - "Failed to parse peer stats update. Skipping sending \ - message to session manager." + Some(message) => { + if let Err(err) = self.peer_stats_tx.send(message) { + error!( + "Failed to send peers stats update to session manager: {err}" ); } - Some(message) => { - if let Err(err) = self.peer_stats_tx.send(message) { - error!( - "Failed to send peers stats update to session manager: {err}" - ); - } - } } } - None => (), } + None => (), } - Err(err) => { - error!("Disconnected from Gateway at {uri}, error: {err}"); - // Important: call this funtion before setting disconnection time. - self.send_disconnect_notification().await; - let _ = self.gateway.touch_disconnected(&self.pool).await; - debug!("Waiting 10s to re-establish the connection"); - sleep(TEN_SECS).await; - break 'message; - } + } + Err(err) => { + error!("Disconnected from Gateway at {uri}, error: {err}"); + self.handle_stream_disconnection( + &clients, + retry_on_connect_failure, + retry_delay, + ) + .await; + return Ok(()); } } } } + + /// Connect to Gateway and handle its messages through gRPC. 
+ pub(super) async fn handle_connection( + &mut self, + clients: Arc>>, + ) -> Result<(), GatewayError> { + loop { + self.handle_connection_iteration(Arc::clone(&clients), true) + .await?; + } + } +} + +#[cfg(test)] +impl GatewayHandler { + pub(crate) fn new_with_test_socket( + gateway: Gateway, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: watch::Receiver>>, + socket_path: PathBuf, + ) -> Result { + let mut handler = Self::new(gateway, pool, events_tx, peer_stats_tx, certs_rx)?; + handler.test_transport = GatewayTestTransport::with_socket_path(socket_path); + Ok(handler) + } + + pub(crate) fn attach_test_support(&mut self, test_support: GatewayManagerTestSupport) { + self.test_support = Some(test_support); + } + + fn note_handler_connection_attempt_for_tests(&self) { + if let Some(test_support) = &self.test_support { + test_support.note_handler_connection_attempt(self.gateway.id); + } + } + + fn handler_retry_delay(&self) -> Duration { + self.test_support + .as_ref() + .map_or(TEN_SECS, GatewayManagerTestSupport::handler_reconnect_delay) + } + + fn connect_channel( + &self, + endpoint: Endpoint, + ) -> Result { + if let Some(socket_path) = self.test_transport.socket_path().cloned() { + return Ok(endpoint.connect_with_connector_lazy(tower::service_fn( + move |_: tonic::transport::Uri| { + let socket_path = socket_path.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(socket_path).await?, + )) + } + }, + ))); + } + + self.connect_tls_channel(endpoint) + } + + pub(crate) async fn handle_connection_once(&mut self) -> anyhow::Result<()> { + let clients = Arc::>>::default(); + self.handle_connection_iteration(clients, false) + .await + .map_err(anyhow::Error::from) + } } /// Helper struct for handling gateway events. 
@@ -847,8 +1010,8 @@ fn try_protos_into_stats_message( )) } -fn gen_config( - network: &WireguardNetwork, +fn gen_config( + network: &WireguardNetwork, peers: Vec, maybe_firewall_config: Option, ) -> Configuration { @@ -868,7 +1031,7 @@ fn gen_config( mod tests { use std::{collections::HashMap, net::IpAddr, str::FromStr, sync::Arc}; - use chrono::Utc; + use chrono::{DateTime, Utc}; use defguard_common::db::{ Id, models::{ @@ -881,11 +1044,14 @@ mod tests { setup_pool, }; use defguard_core::grpc::GatewayEvent; - use defguard_proto::gateway::core_response; + use defguard_proto::gateway::{Peer, PeerStats, core_response}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::{broadcast, mpsc::unbounded_channel, watch}; - use super::{GatewayHandler, GatewayUpdatesHandler}; + use super::{ + FirewallConfig, GatewayHandler, GatewayUpdatesHandler, gen_config, + try_protos_into_stats_message, + }; fn test_network(location_mfa_mode: LocationMfaMode) -> WireguardNetwork { WireguardNetwork::new( @@ -903,6 +1069,139 @@ mod tests { .with_id(1) } + fn build_peer_stats(endpoint: &str) -> PeerStats { + PeerStats { + public_key: "peer-public-key".to_string(), + endpoint: endpoint.to_string(), + upload: 123, + download: 456, + keepalive_interval: 25, + latest_handshake: 1_700_000_000, + allowed_ips: "10.10.0.2/32".to_string(), + } + } + + fn build_network() -> WireguardNetwork { + let mut network = WireguardNetwork::new( + "test-network".to_string(), + 51820, + "198.51.100.10".to_string(), + Some("1.1.1.1".to_string()), + ["0.0.0.0/0".parse().expect("valid allowed IP network")], + false, + false, + false, + LocationMfaMode::default(), + ServiceLocationMode::default(), + ) + .set_address([ + "10.10.0.1/24".parse().expect("valid IPv4 network"), + "fd00::1/64".parse().expect("valid IPv6 network"), + ]) + .expect("valid network addresses") + .with_id(1); + network.pubkey = "network-public-key".to_string(); + network.prvkey = "network-private-key".to_string(); + 
network.mtu = 1420; + network.fwmark = 4321; + network.keepalive_interval = 25; + network.peer_disconnect_threshold = 180; + network + } + + #[test] + fn try_protos_into_stats_message_maps_valid_peer_stats() { + let stats = try_protos_into_stats_message(build_peer_stats("203.0.113.10:51820"), 11, 22) + .expect("valid peer stats should be converted"); + + assert_eq!(stats.location_id, 11); + assert_eq!(stats.gateway_id, 22); + assert_eq!(stats.device_pubkey, "peer-public-key"); + assert_eq!(stats.endpoint.to_string(), "203.0.113.10:51820"); + assert_eq!(stats.upload, 123); + assert_eq!(stats.download, 456); + assert_eq!( + stats.latest_handshake, + DateTime::from_timestamp(1_700_000_000, 0) + .expect("valid handshake timestamp") + .naive_utc() + ); + } + + #[test] + fn try_protos_into_stats_message_rejects_invalid_endpoint() { + let stats = try_protos_into_stats_message(build_peer_stats("not-a-socket-address"), 11, 22); + + assert!(stats.is_none()); + } + + #[test] + fn try_protos_into_stats_message_falls_back_to_default_timestamp() { + let stats = try_protos_into_stats_message( + PeerStats { + latest_handshake: i64::MAX as u64, + ..build_peer_stats("203.0.113.10:51820") + }, + 11, + 22, + ) + .expect("valid endpoint should still produce stats"); + + assert_eq!( + stats.latest_handshake, + DateTime::::default().naive_utc() + ); + } + + #[test] + fn gen_config_maps_network_fields() { + let config = gen_config( + &build_network(), + vec![Peer { + pubkey: "peer-public-key".to_string(), + allowed_ips: vec!["10.10.0.2/32".to_string()], + preshared_key: Some("peer-preshared-key".to_string()), + keepalive_interval: Some(25), + }], + Some(FirewallConfig { + default_policy: 0, + rules: Vec::new(), + snat_bindings: Vec::new(), + }), + ); + + assert_eq!(config.name, "test-network"); + assert_eq!(config.port, 51820); + assert_eq!(config.prvkey, "network-private-key"); + assert_eq!(config.addresses, vec!["10.10.0.1/24", "fd00::1/64"]); + assert_eq!(config.mtu, 1420); + 
assert_eq!(config.fwmark, 4321); + + let peer = config + .peers + .first() + .expect("generated config should include peer"); + assert_eq!(peer.pubkey, "peer-public-key"); + assert_eq!(peer.allowed_ips, vec!["10.10.0.2/32"]); + assert_eq!(peer.preshared_key.as_deref(), Some("peer-preshared-key")); + assert_eq!(peer.keepalive_interval, Some(25)); + + let firewall_config = config + .firewall_config + .expect("generated config should include firewall config"); + assert_eq!(firewall_config.default_policy, 0); + assert!(firewall_config.rules.is_empty()); + assert!(firewall_config.snat_bindings.is_empty()); + } + + #[test] + fn gen_config_preserves_absent_firewall_config_and_empty_peers() { + let config = gen_config(&build_network(), Vec::new(), None); + + assert!(config.peers.is_empty()); + assert!(config.firewall_config.is_none()); + } + fn test_handler(location_mfa_mode: LocationMfaMode) -> GatewayUpdatesHandler { let network = test_network(location_mfa_mode); let (events_tx, events_rx) = broadcast::channel(1); diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs index 495138560d..ddb9cc53cf 100644 --- a/crates/defguard_gateway_manager/src/lib.rs +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -4,6 +4,12 @@ use std::{ time::Duration, }; +#[cfg(test)] +use std::{ + path::PathBuf, + sync::atomic::{AtomicBool, Ordering}, +}; + use defguard_common::{ db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, messages::peer_stats_update::PeerStatsUpdate, @@ -18,6 +24,9 @@ use tokio::{ }; use tonic::{Request, service::interceptor::InterceptedService, transport::Channel}; +#[cfg(test)] +use tokio::sync::Notify; + use crate::{error::GatewayError, handler::GatewayHandler}; #[macro_use] @@ -26,21 +35,223 @@ extern crate tracing; mod certs; mod error; mod handler; -// #[cfg(test)] -// mod tests; #[cfg(test)] -static TONIC_SOCKET: &str = "tonic.sock"; +mod tests; + const GATEWAY_TABLE_TRIGGER: &str = 
"gateway_change"; const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); const TEN_SECS: Duration = Duration::from_secs(10); type Client = GatewayClient>; +struct AbortTaskOnDrop { + handle: Option>, +} + +impl AbortTaskOnDrop { + fn new(handle: tokio::task::JoinHandle) -> Self { + Self { + handle: Some(handle), + } + } +} + +impl Drop for AbortTaskOnDrop { + fn drop(&mut self) { + if let Some(handle) = self.handle.take() { + handle.abort(); + } + } +} + +#[cfg(test)] +#[derive(Clone, Default)] +struct GatewayManagerTestSupport { + socket_paths_by_url: Arc>>, + handler_spawn_attempts_by_gateway: Arc>>, + handler_spawn_attempt_notify: Arc, + handler_connection_attempts_by_gateway: Arc>>, + handler_connection_attempt_notify: Arc, + gateway_notifications_by_gateway: Arc>>, + gateway_notification_notify: Arc, + listener_ready: Arc, + listener_ready_notify: Arc, + retry_delay_override: Arc>>, +} + +#[cfg(test)] +impl GatewayManagerTestSupport { + fn register_gateway_url(&self, gateway_url: String, socket_path: PathBuf) { + self.socket_paths_by_url + .lock() + .expect("Failed to lock GatewayManager test socket registry") + .insert(gateway_url, socket_path); + } + + fn socket_path_for(&self, gateway: &Gateway) -> Option { + self.socket_paths_by_url + .lock() + .expect("Failed to lock GatewayManager test socket registry") + .get(&gateway.url()) + .cloned() + } + + fn note_handler_spawn_attempt(&self, gateway_id: Id) { + let mut handler_spawn_attempts = self + .handler_spawn_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler spawn attempts registry"); + *handler_spawn_attempts.entry(gateway_id).or_default() += 1; + self.handler_spawn_attempt_notify.notify_waiters(); + } + + #[cfg(test)] + fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + self.handler_spawn_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler spawn attempts registry") + .get(&gateway_id) + .copied() + .unwrap_or_default() + } + + 
#[cfg(test)] + async fn wait_for_handler_spawn_attempt_count(&self, gateway_id: Id, expected_count: u64) { + loop { + if self.handler_spawn_attempt_count(gateway_id) >= expected_count { + return; + } + + let notified = self.handler_spawn_attempt_notify.notified(); + if self.handler_spawn_attempt_count(gateway_id) >= expected_count { + return; + } + + notified.await; + } + } + + fn note_handler_connection_attempt(&self, gateway_id: Id) { + let mut handler_connection_attempts = self + .handler_connection_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler connection attempts registry"); + *handler_connection_attempts.entry(gateway_id).or_default() += 1; + self.handler_connection_attempt_notify.notify_waiters(); + } + + #[cfg(test)] + fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { + self.handler_connection_attempts_by_gateway + .lock() + .expect("Failed to lock GatewayManager handler connection attempts registry") + .get(&gateway_id) + .copied() + .unwrap_or_default() + } + + #[cfg(test)] + async fn wait_for_handler_connection_attempt_count(&self, gateway_id: Id, expected_count: u64) { + loop { + if self.handler_connection_attempt_count(gateway_id) >= expected_count { + return; + } + + let notified = self.handler_connection_attempt_notify.notified(); + if self.handler_connection_attempt_count(gateway_id) >= expected_count { + return; + } + + notified.await; + } + } + + fn note_gateway_notification(&self, gateway_id: Id) { + let mut gateway_notifications = self + .gateway_notifications_by_gateway + .lock() + .expect("Failed to lock GatewayManager gateway notification registry"); + *gateway_notifications.entry(gateway_id).or_default() += 1; + self.gateway_notification_notify.notify_waiters(); + } + + #[cfg(test)] + fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + self.gateway_notifications_by_gateway + .lock() + .expect("Failed to lock GatewayManager gateway notification registry") + .get(&gateway_id) + 
.copied() + .unwrap_or_default() + } + + #[cfg(test)] + async fn wait_for_gateway_notification_count(&self, gateway_id: Id, expected_count: u64) { + loop { + if self.gateway_notification_count(gateway_id) >= expected_count { + return; + } + + let notified = self.gateway_notification_notify.notified(); + if self.gateway_notification_count(gateway_id) >= expected_count { + return; + } + + notified.await; + } + } + + fn mark_listener_ready(&self) { + self.listener_ready.store(true, Ordering::Release); + self.listener_ready_notify.notify_waiters(); + } + + #[cfg(test)] + async fn wait_until_listener_ready(&self) { + loop { + if self.listener_ready.load(Ordering::Acquire) { + return; + } + + let notified = self.listener_ready_notify.notified(); + if self.listener_ready.load(Ordering::Acquire) { + return; + } + + notified.await; + } + } + + #[cfg(test)] + fn set_retry_delay(&self, retry_delay: Duration) { + *self + .retry_delay_override + .lock() + .expect("Failed to lock GatewayManager retry delay override") = Some(retry_delay); + } + + fn manager_reconnect_delay(&self) -> Duration { + self.retry_delay_override + .lock() + .expect("Failed to lock GatewayManager retry delay override") + .unwrap_or(GATEWAY_RECONNECT_DELAY) + } + + fn handler_reconnect_delay(&self) -> Duration { + self.retry_delay_override + .lock() + .expect("Failed to lock GatewayManager retry delay override") + .unwrap_or(TEN_SECS) + } +} + pub struct GatewayManager { clients: Arc>>, pool: PgPool, handlers: JoinSet>, + #[cfg(test)] + test_support: Option, tx: GatewayTxSet, } @@ -51,23 +262,122 @@ impl GatewayManager { clients: Arc::default(), handlers: JoinSet::new(), pool, + #[cfg(test)] + test_support: None, tx, } } + #[cfg(test)] + #[must_use] + fn new_for_test( + pool: PgPool, + tx: GatewayTxSet, + test_support: GatewayManagerTestSupport, + ) -> Self { + Self { + clients: Arc::default(), + handlers: JoinSet::new(), + pool, + test_support: Some(test_support), + tx, + } + } + + fn 
mark_listener_ready_for_tests(&self) { + #[cfg(test)] + if let Some(test_support) = &self.test_support { + test_support.mark_listener_ready(); + } + } + + fn note_gateway_notification_for_tests(&self, maybe_gateway_id: Option) { + #[cfg(test)] + if let (Some(gateway_id), Some(test_support)) = + (maybe_gateway_id, self.test_support.as_ref()) + { + test_support.note_gateway_notification(gateway_id); + } + + #[cfg(not(test))] + let _ = maybe_gateway_id; + } + + fn manager_reconnect_delay(&self) -> Duration { + #[cfg(test)] + { + self.test_support.as_ref().map_or( + GATEWAY_RECONNECT_DELAY, + GatewayManagerTestSupport::manager_reconnect_delay, + ) + } + + #[cfg(not(test))] + { + GATEWAY_RECONNECT_DELAY + } + } + + fn build_handler( + &self, + gateway: Gateway, + certs_rx: Receiver>>, + ) -> Result { + #[cfg(test)] + if let Some(test_support) = self.test_support.clone() { + test_support.note_handler_spawn_attempt(gateway.id); + + let mut gateway_handler = + if let Some(socket_path) = test_support.socket_path_for(&gateway) { + GatewayHandler::new_with_test_socket( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx, + socket_path, + )? + } else { + GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx, + )? + }; + gateway_handler.attach_test_support(test_support); + + return Ok(gateway_handler); + } + + GatewayHandler::new( + gateway, + self.pool.clone(), + self.tx.events.clone(), + self.tx.peer_stats.clone(), + certs_rx, + ) + } + /// Bi-directional gRPC stream for communication with Defguard Gateway. 
pub async fn run(&mut self) -> Result<(), anyhow::Error> { let (certs_tx, certs_rx) = tokio::sync::watch::channel(Arc::new(HashMap::new())); certs::refresh_certs(&self.pool, &certs_tx).await; let refresh_pool = self.pool.clone(); - tokio::spawn(async move { + let _refresh_certs_task = AbortTaskOnDrop::new(tokio::spawn(async move { loop { certs::refresh_certs(&refresh_pool, &certs_tx).await; tokio::time::sleep(TEN_SECS).await; } - }); + })); let mut abort_handles = HashMap::new(); for gateway in Gateway::all(&self.pool).await? { + if !gateway.enabled { + debug!("Existing Gateway is disabled, so it won't be handled"); + continue; + } + let id = gateway.id; let abort_handle = self.run_handler(gateway, Arc::clone(&self.clients), certs_rx.clone())?; @@ -77,12 +387,17 @@ impl GatewayManager { // Observe gateway URL changes. let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen(GATEWAY_TABLE_TRIGGER).await?; + self.mark_listener_ready_for_tests(); while let Ok(notification) = listener.recv().await { let payload = notification.payload(); match serde_json::from_str::>>(payload) { - Ok(gateway_notification) => match gateway_notification.operation { - TriggerOperation::Insert => { - if let Some(new) = gateway_notification.new { + Ok(gateway_notification) => { + let maybe_gateway_id = match gateway_notification.operation { + TriggerOperation::Insert => { + let Some(new) = gateway_notification.new else { + continue; + }; + let id = new.id; if new.enabled { let abort_handle = self.run_handler( @@ -94,77 +409,94 @@ impl GatewayManager { } else { debug!("New Gateway is disabled, so it won't be handled"); } + + Some(id) } - } - TriggerOperation::Update => { - let (Some(old), Some(new)) = - (gateway_notification.old, gateway_notification.new) - else { - continue; - }; - if old.address == new.address - && old.port == new.port - && old.enabled == new.enabled - { - debug!("Gateway address/port/state didn't change"); - continue; - } - if let Some(abort_handle) 
= abort_handles.remove(&old.id) { - info!( - "Aborting connection to Gateway {old}, it has changed in the \ - database" - ); - abort_handle.abort(); - } else if old.enabled { - warn!("Cannot find Gateway {old} on the list of connected gateways"); - } - if new.enabled { + TriggerOperation::Update => { + let (Some(mut old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + else { + continue; + }; + let id = new.id; - let abort_handle = - self.run_handler(new, Arc::clone(&self.clients), certs_rx.clone())?; - abort_handles.insert(id, abort_handle); - } else { - debug!("Updated Gateway is disabled, so it won't be handled"); - } - } - TriggerOperation::Delete => { - let Some(old) = gateway_notification.old else { - continue; - }; - - // Send purge request to Gateway. - let maybe_client = { - self.clients - .lock() - .expect("Failed to lock GatewayManager::clients") - .remove(&old.id) - }; - - if let Some(mut client) = maybe_client { - debug!("Sending purge request to Gateway {old}"); - if let Err(err) = client.purge(Request::new(())).await { - error!("Error sending purge request to Gateway {old}: {err}"); + if old.address == new.address + && old.port == new.port + && old.enabled == new.enabled + { + debug!("Gateway address/port/state didn't change"); } else { - info!("Sent purge request to Gateway {old}"); + self.remove_client(old.id); + if let Some(abort_handle) = abort_handles.remove(&old.id) { + if let Err(err) = old.touch_disconnected(&self.pool).await { + error!( + "Failed to update disconnection time for Gateway {old} after database change: {err}" + ); + } + info!( + "Aborting connection to Gateway {old}, it has changed in the \ + database" + ); + abort_handle.abort(); + } else if old.enabled { + warn!( + "Cannot find Gateway {old} on the list of connected gateways" + ); + } + if new.enabled { + let abort_handle = self.run_handler( + new, + Arc::clone(&self.clients), + certs_rx.clone(), + )?; + abort_handles.insert(id, abort_handle); + } else { 
+ debug!("Updated Gateway is disabled, so it won't be handled"); + } } - } else { - warn!( - "Cannot find gRPC client for Gateway {old}; skipping purge request" - ); + + Some(id) } + TriggerOperation::Delete => { + let Some(old) = gateway_notification.old else { + continue; + }; + + // Send purge request to Gateway. + let maybe_client = self.remove_client(old.id); + + if let Some(mut client) = maybe_client { + debug!("Sending purge request to Gateway {old}"); + if let Err(err) = client.purge(Request::new(())).await { + error!("Error sending purge request to Gateway {old}: {err}"); + } else { + info!("Sent purge request to Gateway {old}"); + } + } else { + warn!( + "Cannot find gRPC client for Gateway {old}; skipping purge request" + ); + } + + // Kill the `GatewayHandler` and the connection. + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to Gateway {old}, it has disappeared from the \ + database" + ); + abort_handle.abort(); + } else if old.enabled { + warn!( + "Cannot find Gateway {old} on the list of connected gateways" + ); + } - // Kill the `GatewayHandler` and the connection. 
- if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!( - "Aborting connection to Gateway {old}, it has disappeard from the \ - database" - ); - abort_handle.abort(); - } else if old.enabled { - warn!("Cannot find Gateway {old} on the list of connected gateways"); + Some(old.id) } - } - }, + }; + + self.note_gateway_notification_for_tests(maybe_gateway_id); + } Err(err) => error!("Failed to de-serialize database notification object: {err}"), } } @@ -182,26 +514,30 @@ impl GatewayManager { clients: Arc>>, certs_rx: Receiver>>, ) -> Result { - let mut gateway_handler = GatewayHandler::new( - gateway, - self.pool.clone(), - self.tx.events.clone(), - self.tx.peer_stats.clone(), - certs_rx.clone(), - )?; + let mut gateway_handler = self.build_handler(gateway, certs_rx)?; + let manager_reconnect_delay = self.manager_reconnect_delay(); let abort_handle = self.handlers.spawn(async move { loop { if let Err(err) = gateway_handler .handle_connection(Arc::clone(&clients)) .await { - error!("Gateway connection error: {err}, retrying in 5 seconds..."); - tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await; + error!( + "Gateway connection error: {err}, retrying in {manager_reconnect_delay:?}..." 
+ ); + tokio::time::sleep(manager_reconnect_delay).await; } } }); Ok(abort_handle) } + + fn remove_client(&self, gateway_id: Id) -> Option { + self.clients + .lock() + .expect("Failed to lock GatewayManager::clients") + .remove(&gateway_id) + } } /// Shared set of outbound channels that gateway instances use to forward diff --git a/crates/defguard_gateway_manager/src/tests.rs b/crates/defguard_gateway_manager/src/tests.rs deleted file mode 100644 index e6fbe3cfc6..0000000000 --- a/crates/defguard_gateway_manager/src/tests.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::{ - io, - net::{IpAddr, Ipv4Addr}, - sync::{Arc, Mutex}, -}; - -use ipnetwork::IpNetwork; -use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::{ - net::UnixListener, - sync::{broadcast, mpsc::unbounded_channel}, -}; -use tokio_stream::wrappers::{UnboundedReceiverStream, UnixListenerStream}; -use tonic::{Request, Response, Status, Streaming, transport::Server}; - -use defguard_common::db::{ - models::{ - gateway::Gateway, - wireguard::{LocationMfaMode, ServiceLocationMode, WireguardNetwork}, - }, - setup_pool, -}; -use defguard_mail::Mail; -use defguard_proto::gateway::{CoreRequest, CoreResponse, gateway_server}; - -use super::{TONIC_SOCKET, handler::GatewayHandler}; -use crate::grpc::{ClientMap, GrpcEvent, gateway::events::GatewayEvent}; - -// TODO: move to "gateway" repo. 
-struct FakeGateway; - -#[tonic::async_trait] -impl gateway_server::Gateway for FakeGateway { - type BidiStream = UnboundedReceiverStream>; - - async fn bidi( - &self, - request: Request>, - ) -> Result, Status> { - let (_tx, rx) = unbounded_channel(); - let mut stream = request.into_inner(); - tokio::spawn(async move { - loop { - match stream.message().await { - Ok(Some(_response)) => (), - Ok(None) => (), - Err(_err) => (), - } - } - }); - - Ok(Response::new(UnboundedReceiverStream::new(rx))) - } -} - -async fn fake_gateway() -> Result<(), io::Error> { - let gateway = FakeGateway {}; - - let uds = UnixListener::bind(TONIC_SOCKET)?; - let uds_stream = UnixListenerStream::new(uds); - - Server::builder() - .add_service(gateway_server::GatewayServer::new(gateway)) - .serve_with_incoming(uds_stream) - .await - .unwrap(); - - Ok(()) -} - -#[sqlx::test] -async fn test_gateway(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let network = WireguardNetwork::new( - "TestNet".to_string(), - 50051, - "0.0.0.0".to_string(), - None, - [IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], - false, - false, - false, - LocationMfaMode::default(), - ServiceLocationMode::default(), - ) - .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()]) - .expect("test network address should be valid") - .save(&pool) - .await - .unwrap(); - let gateway = Gateway::new(network.id, "http://[::]:50051") - .save(&pool) - .await - .unwrap(); - let client_state = Arc::new(Mutex::new(ClientMap::new())); - let (events_tx, _events_rx) = broadcast::channel::(16); - let (grpc_event_tx, _grpc_event_rx) = unbounded_channel::(); - - let mut gateway_handler = - GatewayHandler::new(gateway, None, pool, client_state, events_tx, grpc_event_tx).unwrap(); - let handle = tokio::spawn(async move { - gateway_handler.handle_connection().await; - }); - handle.abort(); -} diff --git 
a/crates/defguard_gateway_manager/src/tests/common/mod.rs b/crates/defguard_gateway_manager/src/tests/common/mod.rs new file mode 100644 index 0000000000..16ba51f528 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/common/mod.rs @@ -0,0 +1,721 @@ +use std::{ + collections::HashMap, + io, + path::PathBuf, + sync::{ + Arc, Mutex, + atomic::{AtomicU64, Ordering}, + }, + time::Duration, +}; + +use defguard_common::{ + db::{ + Id, NoId, + models::{ + gateway::Gateway, settings::initialize_current_settings, wireguard::WireguardNetwork, + }, + setup_pool, + }, + messages::peer_stats_update::PeerStatsUpdate, +}; +use defguard_core::grpc::GatewayEvent; +use defguard_proto::gateway::{ + ConfigurationRequest, CoreRequest, CoreResponse, PeerStats, core_request, gateway_server, +}; +use sqlx::{PgPool, postgres::PgConnectOptions}; +use tokio::{ + net::UnixListener, + sync::{ + Notify, broadcast, + mpsc::{self, UnboundedReceiver, UnboundedSender}, + oneshot, watch, + }, + task::JoinHandle, + time::timeout, +}; +use tokio_stream::{once, wrappers::UnboundedReceiverStream}; +use tonic::{Request, Response, Status, Streaming, transport::Server}; + +use crate::{GatewayManager, GatewayManagerTestSupport, GatewayTxSet, handler::GatewayHandler}; + +const TEST_TIMEOUT: Duration = Duration::from_secs(2); + +macro_rules! 
assert_some { + ($expr:expr, $message:literal) => { + match $expr { + Some(value) => value, + None => panic!($message), + } + }; +} + +static TEST_ID: AtomicU64 = AtomicU64::new(0); + +fn next_test_id() -> u64 { + TEST_ID.fetch_add(1, Ordering::Relaxed) +} + +fn unique_name(prefix: &str) -> String { + format!("{prefix}-{}", next_test_id()) +} + +fn unique_socket_path() -> PathBuf { + PathBuf::from(format!( + "/tmp/defguard-gateway-manager-{}-{}.sock", + std::process::id(), + next_test_id() + )) +} + +pub(crate) fn unique_mock_gateway_socket_path() -> PathBuf { + unique_socket_path() +} + +#[derive(Clone)] +struct MockGatewayService { + state: Arc, +} + +struct MockGatewayState { + outbound_tx: UnboundedSender, + inbound_rx: Mutex>>>, + connected_tx: Mutex>>, + connection_count: AtomicU64, + connection_notify: Notify, + purge_count: AtomicU64, + purge_notify: Notify, +} + +impl MockGatewayState { + fn notify_connected(&self) { + self.connection_count.fetch_add(1, Ordering::Relaxed); + self.connection_notify.notify_waiters(); + + if let Some(tx) = self + .connected_tx + .lock() + .expect("failed to lock connected notifier") + .take() + { + let _ = tx.send(()); + } + } + + fn take_inbound_rx(&self) -> Result>, Status> { + self.inbound_rx + .lock() + .expect("failed to lock inbound receiver") + .take() + .ok_or_else(|| Status::failed_precondition("mock gateway already connected")) + } + + fn note_purge(&self) { + self.purge_count.fetch_add(1, Ordering::Relaxed); + self.purge_notify.notify_waiters(); + } +} + +#[tonic::async_trait] +impl gateway_server::Gateway for MockGatewayService { + type BidiStream = UnboundedReceiverStream>; + + async fn bidi( + &self, + request: Request>, + ) -> Result, Status> { + let inbound_rx = self.state.take_inbound_rx()?; + self.state.notify_connected(); + + let mut outbound_stream = request.into_inner(); + let outbound_tx = self.state.outbound_tx.clone(); + tokio::spawn(async move { + while let Ok(Some(response)) = 
outbound_stream.message().await { + if outbound_tx.send(response).is_err() { + break; + } + } + }); + + Ok(Response::new(UnboundedReceiverStream::new(inbound_rx))) + } + + async fn purge(&self, _request: Request<()>) -> Result, Status> { + self.state.note_purge(); + Ok(Response::new(())) + } +} + +pub(crate) struct MockGatewayHarness { + state: Arc, + socket_path: PathBuf, + inbound_tx: Option>>, + outbound_rx: UnboundedReceiver, + connected_rx: oneshot::Receiver<()>, + server_task: Option>>, + next_message_id: AtomicU64, +} + +impl MockGatewayHarness { + pub(crate) async fn start() -> Self { + Self::start_at(unique_socket_path()).await + } + + pub(crate) async fn start_at(socket_path: PathBuf) -> Self { + let _ = std::fs::remove_file(&socket_path); + + let listener = + UnixListener::bind(&socket_path).expect("failed to bind mock gateway unix socket"); + let (outbound_tx, outbound_rx) = mpsc::unbounded_channel(); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel(); + let (connected_tx, connected_rx) = oneshot::channel(); + let state = Arc::new(MockGatewayState { + outbound_tx, + inbound_rx: Mutex::new(Some(inbound_rx)), + connected_tx: Mutex::new(Some(connected_tx)), + connection_count: AtomicU64::new(0), + connection_notify: Notify::new(), + purge_count: AtomicU64::new(0), + purge_notify: Notify::new(), + }); + let service = MockGatewayService { + state: Arc::clone(&state), + }; + + let server_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await?; + Server::builder() + .add_service(gateway_server::GatewayServer::new(service)) + .serve_with_incoming(once(Ok::<_, io::Error>(stream))) + .await + .map_err(io::Error::other) + }); + + Self { + state, + socket_path, + inbound_tx: Some(inbound_tx), + outbound_rx, + connected_rx, + server_task: Some(server_task), + next_message_id: AtomicU64::new(1), + } + } + + pub(crate) fn socket_path(&self) -> PathBuf { + self.socket_path.clone() + } + + pub(crate) async fn wait_connected(&mut self) { + 
timeout(TEST_TIMEOUT, &mut self.connected_rx) + .await + .expect("timed out waiting for mock gateway connection") + .expect("mock gateway connection notifier dropped"); + } + + pub(crate) fn connection_count(&self) -> u64 { + self.state.connection_count.load(Ordering::Relaxed) + } + + pub(crate) async fn wait_for_connection_count(&self, expected_count: u64) { + timeout(TEST_TIMEOUT, async { + loop { + if self.connection_count() >= expected_count { + return; + } + + let notified = self.state.connection_notify.notified(); + if self.connection_count() >= expected_count { + return; + } + + notified.await; + } + }) + .await + .expect("timed out waiting for mock gateway connection count"); + } + + pub(crate) async fn wait_purged(&self) { + timeout(TEST_TIMEOUT, async { + loop { + if self.state.purge_count.load(Ordering::Relaxed) > 0 { + return; + } + + let notified = self.state.purge_notify.notified(); + if self.state.purge_count.load(Ordering::Relaxed) > 0 { + return; + } + + notified.await; + } + }) + .await + .expect("timed out waiting for purge request"); + } + + pub(crate) fn send_config_request(&self) { + let request = ConfigurationRequest { + hostname: "mock-gateway".to_string(), + ..Default::default() + }; + self.send_request(CoreRequest { + id: self.next_message_id.fetch_add(1, Ordering::Relaxed), + payload: Some(core_request::Payload::ConfigRequest(request)), + }); + } + + pub(crate) fn send_peer_stats(&self, peer_stats: PeerStats) { + self.send_request(CoreRequest { + id: self.next_message_id.fetch_add(1, Ordering::Relaxed), + payload: Some(core_request::Payload::PeerStats(peer_stats)), + }); + } + + pub(crate) fn send_stream_error(&self, status: Status) { + self.inbound_tx + .as_ref() + .expect("mock gateway inbound channel already closed") + .send(Err(status)) + .expect("failed to inject inbound stream error"); + } + + fn send_request(&self, request: CoreRequest) { + self.inbound_tx + .as_ref() + .expect("mock gateway inbound channel already closed") + 
.send(Ok(request)) + .expect("failed to inject mock gateway request"); + } + + pub(crate) fn close_stream(&mut self) { + self.inbound_tx.take(); + } + + pub(crate) async fn recv_outbound(&mut self) -> CoreResponse { + timeout(TEST_TIMEOUT, self.outbound_rx.recv()) + .await + .expect("timed out waiting for outbound response") + .expect("mock gateway outbound response channel closed unexpectedly") + } + + pub(crate) async fn expect_no_outbound(&mut self) { + if let Ok(Some(_message)) = + timeout(Duration::from_millis(200), self.outbound_rx.recv()).await + { + panic!("unexpected outbound response"); + } + } + + pub(crate) async fn expect_server_finished(mut self) { + let server_task = assert_some!( + self.server_task.take(), + "mock gateway server task already taken" + ); + let server_result = timeout(TEST_TIMEOUT, server_task) + .await + .expect("timed out waiting for mock gateway server to finish") + .expect("mock gateway server task panicked"); + server_result.expect("mock gateway server exited with error"); + } +} + +impl Drop for MockGatewayHarness { + fn drop(&mut self) { + if let Some(server_task) = self.server_task.take() { + server_task.abort(); + } + let _ = std::fs::remove_file(&self.socket_path); + } +} + +pub(crate) struct ManagerTestContext { + pub(crate) pool: PgPool, + control: GatewayManagerTestSupport, + manager_task: Option>>, +} + +impl ManagerTestContext { + pub(crate) async fn new(options: PgConnectOptions) -> Self { + let pool = setup_pool(options).await; + initialize_current_settings(&pool) + .await + .expect("failed to initialize global settings for gateway manager tests"); + + Self { + pool, + control: GatewayManagerTestSupport::default(), + manager_task: None, + } + } + + pub(crate) fn register_gateway_mock( + &self, + gateway: &Gateway, + mock_gateway: &MockGatewayHarness, + ) { + self.register_gateway_url(gateway.url(), mock_gateway); + } + + pub(crate) fn register_gateway_url( + &self, + gateway_url: String, + mock_gateway: 
&MockGatewayHarness, + ) { + self.register_gateway_socket_path(gateway_url, mock_gateway.socket_path()); + } + + pub(crate) fn register_gateway_socket_path(&self, gateway_url: String, socket_path: PathBuf) { + self.control.register_gateway_url(gateway_url, socket_path); + } + + pub(crate) fn handler_spawn_attempt_count(&self, gateway_id: Id) -> u64 { + self.control.handler_spawn_attempt_count(gateway_id) + } + + pub(crate) fn handler_connection_attempt_count(&self, gateway_id: Id) -> u64 { + self.control.handler_connection_attempt_count(gateway_id) + } + + pub(crate) fn gateway_notification_count(&self, gateway_id: Id) -> u64 { + self.control.gateway_notification_count(gateway_id) + } + + pub(crate) async fn wait_for_handler_spawn_attempt_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + timeout( + TEST_TIMEOUT, + self.control + .wait_for_handler_spawn_attempt_count(gateway_id, expected_count), + ) + .await + .expect("timed out waiting for gateway manager handler spawn attempt"); + } + + pub(crate) async fn wait_for_handler_connection_attempt_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + timeout( + TEST_TIMEOUT, + self.control + .wait_for_handler_connection_attempt_count(gateway_id, expected_count), + ) + .await + .expect("timed out waiting for gateway manager handler connection attempt"); + } + + pub(crate) async fn wait_for_gateway_notification_count( + &self, + gateway_id: Id, + expected_count: u64, + ) { + timeout( + TEST_TIMEOUT, + self.control + .wait_for_gateway_notification_count(gateway_id, expected_count), + ) + .await + .expect("timed out waiting for gateway manager database notification"); + } + + pub(crate) async fn start(&mut self) { + assert!( + self.manager_task.is_none(), + "gateway manager already started" + ); + + let (events_tx, _) = broadcast::channel(16); + let (peer_stats_tx, _peer_stats_rx) = mpsc::unbounded_channel(); + let tx = GatewayTxSet::new(events_tx, peer_stats_tx); + let mut manager = 
GatewayManager::new_for_test(self.pool.clone(), tx, self.control.clone()); + let manager_task = tokio::spawn(async move { manager.run().await }); + + timeout(TEST_TIMEOUT, self.control.wait_until_listener_ready()) + .await + .expect("timed out waiting for gateway manager listener to become ready"); + self.manager_task = Some(manager_task); + } + + pub(crate) fn set_retry_delay(&self, retry_delay: Duration) { + self.control.set_retry_delay(retry_delay); + } + + pub(crate) async fn finish(mut self) { + if let Some(manager_task) = self.manager_task.take() { + manager_task.abort(); + + match manager_task.await { + Err(err) if err.is_cancelled() => {} + Err(err) => panic!("gateway manager task panicked: {err}"), + Ok(Ok(())) => {} + Ok(Err(err)) => panic!("gateway manager exited with error: {err}"), + } + } + + self.pool.close().await; + } +} + +impl Drop for ManagerTestContext { + fn drop(&mut self) { + if let Some(manager_task) = self.manager_task.take() { + manager_task.abort(); + } + } +} + +pub(crate) struct HandlerTestContext { + pub(crate) pool: PgPool, + pub(crate) network: WireguardNetwork, + pub(crate) gateway: Gateway, + pub(crate) peer_stats_rx: UnboundedReceiver<PeerStatsUpdate>, + events_tx: Option<broadcast::Sender<GatewayEvent>>, + pub(crate) mock_gateway: Option<MockGatewayHarness>, + handler_task: Option<tokio::task::JoinHandle<Result<(), GatewayError>>>, // NOTE(review): generics reconstructed after markup stripping — confirm error type +} + +impl HandlerTestContext { + pub(crate) async fn new(options: PgConnectOptions) -> Self { + let (events_tx, _) = broadcast::channel(16); + Self::new_with_events_tx(options, events_tx).await + } + + pub(crate) async fn new_with_events_tx( + options: PgConnectOptions, + events_tx: broadcast::Sender<GatewayEvent>, + ) -> Self { + let pool = setup_pool(options).await; + initialize_current_settings(&pool) + .await + .expect("failed to initialize global settings for gateway handler tests"); + let network = create_network(&pool).await; + let gateway = create_gateway(&pool, network.id).await; + let (peer_stats_tx, peer_stats_rx) = mpsc::unbounded_channel(); + let (_, certs_rx) = watch::channel(Arc::new(HashMap::new())); + let mut 
mock_gateway = MockGatewayHarness::start().await; + let mut handler = GatewayHandler::new_with_test_socket( + gateway.clone(), + pool.clone(), + events_tx.clone(), + peer_stats_tx, + certs_rx, + mock_gateway.socket_path(), + ) + .expect("failed to create gateway handler"); + let handler_task = tokio::spawn(async move { handler.handle_connection_once().await }); + + mock_gateway.wait_connected().await; + + Self { + pool, + network, + gateway, + peer_stats_rx, + events_tx: Some(events_tx), + mock_gateway: Some(mock_gateway), + handler_task: Some(handler_task), + } + } + + pub(crate) fn events_tx(&self) -> &broadcast::Sender<GatewayEvent> { + self.events_tx + .as_ref() + .expect("events sender already taken from context") + } + + pub(crate) fn mock_gateway(&self) -> &MockGatewayHarness { + self.mock_gateway + .as_ref() + .expect("mock gateway already taken from context") + } + + pub(crate) fn mock_gateway_mut(&mut self) -> &mut MockGatewayHarness { + self.mock_gateway + .as_mut() + .expect("mock gateway already taken from context") + } + + pub(crate) async fn reload_gateway(&self) -> Gateway { + Gateway::find_by_id(&self.pool, self.gateway.id) + .await + .expect("failed to query gateway from database") + .expect("expected gateway in database") + } + + pub(crate) async fn create_other_network(&self) -> WireguardNetwork { + create_network(&self.pool).await + } + + pub(crate) async fn expect_no_peer_stats(&mut self) { + if let Ok(Some(message)) = + timeout(Duration::from_millis(200), self.peer_stats_rx.recv()).await + { + panic!("unexpected peer stats update: {message:?}"); + } + } + + pub(crate) async fn recv_peer_stats(&mut self) -> PeerStatsUpdate { + timeout(TEST_TIMEOUT, self.peer_stats_rx.recv()) + .await + .expect("timed out waiting for peer stats update") + .expect("peer stats channel unexpectedly closed") + } + + pub(crate) async fn complete_config_handshake(&mut self) -> Gateway { + let initial_event_receivers = self.events_tx().receiver_count(); + 
self.mock_gateway().send_config_request(); + let _ = self.mock_gateway_mut().recv_outbound().await; + let connected_gateway = + wait_for_gateway_connection_state(&self.pool, self.gateway.id, true).await; + timeout(TEST_TIMEOUT, async { + while self.events_tx().receiver_count() <= initial_event_receivers { + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("timed out waiting for gateway updates handler subscription"); + tokio::task::yield_now().await; + connected_gateway + } + + pub(crate) async fn finish(mut self) -> MockGatewayHarness { + let mut mock_gateway = assert_some!( + self.mock_gateway.take(), + "mock gateway already taken from context" + ); + mock_gateway.close_stream(); + let handler_task = self + .handler_task + .as_mut() + .expect("handler task already taken from context"); + let result = timeout(TEST_TIMEOUT, handler_task) + .await + .expect("timed out waiting for handler task to finish") + .expect("gateway handler task panicked"); + result.expect("gateway handler returned an unexpected error"); + self.handler_task.take(); + self.events_tx.take(); + mock_gateway + } + + pub(crate) async fn finish_after_error(mut self) -> MockGatewayHarness { + let mock_gateway = assert_some!( + self.mock_gateway.take(), + "mock gateway already taken from context" + ); + let handler_task = self + .handler_task + .as_mut() + .expect("handler task already taken from context"); + let result = timeout(TEST_TIMEOUT, handler_task) + .await + .expect("timed out waiting for handler task to finish after stream error") + .expect("gateway handler task panicked after stream error"); + result.expect("gateway handler returned an unexpected error after stream error"); + self.handler_task.take(); + self.events_tx.take(); + mock_gateway + } +} + +impl Drop for HandlerTestContext { + fn drop(&mut self) { + if let Some(handler_task) = self.handler_task.take() { + handler_task.abort(); + } + } +} + +pub(crate) async fn reload_gateway(pool: &PgPool, 
gateway_id: Id) -> Gateway { + Gateway::find_by_id(pool, gateway_id) + .await + .expect("failed to query gateway from database") + .expect("expected gateway in database") +} + +pub(crate) async fn wait_for_gateway_connection_state( + pool: &PgPool, + gateway_id: Id, + expected_connected: bool, +) -> Gateway { + timeout(TEST_TIMEOUT, async { + loop { + let gateway = reload_gateway(pool, gateway_id).await; + if gateway.is_connected() == expected_connected { + return gateway; + } + + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("timed out waiting for gateway connection state change") +} + +pub(crate) fn build_peer_stats(endpoint: &str) -> PeerStats { + PeerStats { + public_key: "peer-public-key".to_string(), + endpoint: endpoint.to_string(), + upload: 123, + download: 456, + keepalive_interval: 25, + latest_handshake: 1_700_000_000, + allowed_ips: "10.10.0.2/32".to_string(), + } +} + +pub(crate) async fn create_network(pool: &PgPool) -> WireguardNetwork { + let network = WireguardNetwork::new( + unique_name("network"), + 51820, + "198.51.100.10".to_string(), + None, + Vec::new(), + false, + false, + false, + Default::default(), + Default::default(), + ) + .try_set_address("10.10.0.1/24") + .expect("failed to set network address"); + network + .save(pool) + .await + .expect("failed to create test network") +} + +pub(crate) async fn create_gateway(pool: &PgPool, location_id: Id) -> Gateway { + create_gateway_with_enabled(pool, location_id, true).await +} + +pub(crate) async fn create_gateway_with_enabled( + pool: &PgPool, + location_id: Id, + enabled: bool, +) -> Gateway { + let gateway = build_gateway_with_enabled(location_id, enabled); + gateway + .save(pool) + .await + .expect("failed to create test gateway") +} + +pub(crate) fn build_gateway_with_enabled(location_id: Id, enabled: bool) -> Gateway { + let port = 20_000 + i32::try_from(next_test_id() % 40_000).expect("port offset fits in i32"); + let mut gateway = Gateway::new( + 
location_id, + unique_name("gateway"), + "127.0.0.1".to_string(), + port, + "test-admin".to_string(), + ); + gateway.enabled = enabled; + gateway +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs new file mode 100644 index 0000000000..35896035f6 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler.rs @@ -0,0 +1,28 @@ +#[path = "handler/support.rs"] +mod support; + +use defguard_common::db::models::device::{DeviceInfo, WireguardNetworkDevice}; +use defguard_core::grpc::GatewayEvent; +use defguard_proto::gateway::{UpdateType, core_response}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tonic::Status; + +use self::support::{ + assert_device_event_for_different_network_is_ignored, + assert_device_event_is_ignored_before_config_handshake, assert_firewall_disable_update, + assert_firewall_event_for_different_network_is_ignored, assert_firewall_modify_update, + assert_network_create_update, assert_network_delete_update, assert_network_modify_update, + assert_peer_update, assert_send_ok, build_test_firewall_config, + create_authorized_mfa_device_for_current_network, create_authorized_mfa_device_for_network, + create_device_for_network, create_device_info_for_current_network, + enable_internal_mfa_for_network, expected_keepalive_interval, panic_unexpected, parse_test_ip, +}; +use crate::tests::common::{HandlerTestContext, build_peer_stats, reload_gateway}; + +include!("handler/handshake.rs"); +include!("handler/lifecycle.rs"); +include!("handler/stats.rs"); +include!("handler/network_events.rs"); +include!("handler/firewall_events.rs"); +include!("handler/device_events.rs"); +include!("handler/mfa.rs"); diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs new file mode 100644 index 
0000000000..a15b8cbd14 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/device_events.rs @@ -0,0 +1,217 @@ +#[sqlx::test] +async fn test_device_created_for_network_produces_peer_create_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_current_network( + &context, + "created-peer-device", + "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", + "10.10.0.10", + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceCreated(device_info)), + "failed to broadcast created device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=", + &["10.10.0.10"], + None, + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_created_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "created-before-config-device", + "tND8hJQhYnI8naBTo59He43zYldagfjlwmSxWEc01Cc=", + "10.10.0.11", + GatewayEvent::DeviceCreated, + ) + .await; +} + +#[sqlx::test] +async fn test_device_modified_for_network_produces_peer_modify_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + + let _ = context.complete_config_handshake().await; + let device = create_device_for_network( + &context, + context.network.id, + "modified-peer-device", + "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", 
+ "10.10.0.20", + ) + .await; + + let mut network_device = WireguardNetworkDevice::find(&context.pool, device.id, context.network.id) + .await + .expect("failed to load device network info") + .expect("expected device network info for modified device"); + network_device.wireguard_ips = vec![parse_test_ip("10.10.0.21")]; + network_device + .update(&context.pool) + .await + .expect("failed to update device network info"); + + let device_info = DeviceInfo::from_device(&context.pool, device) + .await + .expect("failed to load modified device info"); + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceModified(device_info)), + "failed to broadcast modified device event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Modify, + "TJgN9JzUF5zdZAPYD96G/Wys2M3TvaT5TIrErUl20nI=", + &["10.10.0.21"], + None, + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_modified_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "modified-before-config-device", + "wyFOHCec/Fi9s+cARikVO71JhyYtYMk0FrQx3fK2PTM=", + "10.10.0.22", + GatewayEvent::DeviceModified, + ) + .await; +} + +#[sqlx::test] +async fn test_device_deleted_for_network_produces_peer_delete_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_current_network( + &context, + "deleted-peer-device", + "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", + "10.10.0.30", + ) + .await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::DeviceDeleted(device_info)), + "failed to broadcast deleted device event" + 
); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Delete, + "PKY3zg5/ecNyMjqLi6yJ3jwb4PvC/SGzjhJ3jrn2vVQ=", + &[], + None, + None, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_device_deleted_before_config_handshake_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_is_ignored_before_config_handshake( + options, + "deleted-before-config-device", + "m84QJmDMkqdCj8AB2NTE8F55W7M/i3CaaD3eQbQdInY=", + "10.10.0.31", + GatewayEvent::DeviceDeleted, + ) + .await; +} + +#[sqlx::test] +async fn test_device_created_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "created-other-network-device", + "W6wBmd8wgTwvCyGqDRXk6Hf4OMqDUbUn2XWKnG5wVVQ=", + "10.11.0.10", + GatewayEvent::DeviceCreated, + ) + .await; +} + +#[sqlx::test] +async fn test_device_modified_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "modified-other-network-device", + "yjuzq0cLk3Ww5oQcqK6YkSKwXnqQ1V9OlSMFAEkr0lU=", + "10.11.0.20", + GatewayEvent::DeviceModified, + ) + .await; +} + +#[sqlx::test] +async fn test_device_deleted_for_different_network_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_device_event_for_different_network_is_ignored( + options, + "deleted-other-network-device", + "Jtp+K8xnFXuF4cae+tVGZNwoSM2fXjJbRl3sI6rdcAQ=", + "10.11.0.30", + GatewayEvent::DeviceDeleted, + ) + .await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs new file mode 100644 index 0000000000..1545aab35a --- /dev/null +++ 
b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/firewall_events.rs @@ -0,0 +1,73 @@ +#[sqlx::test] +async fn test_matching_location_firewall_config_changed_event_produces_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_firewall_config = build_test_firewall_config(); + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::FirewallConfigChanged( + context.network.id, + expected_firewall_config.clone(), + )), + "failed to broadcast firewall config changed event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_firewall_modify_update(outbound, &expected_firewall_config); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_firewall_disabled_event_produces_disable_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context + .events_tx() + .send(GatewayEvent::FirewallDisabled(context.network.id)), + "failed to broadcast firewall disabled event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_firewall_disable_update(outbound); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_different_location_firewall_config_changed_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let expected_firewall_config = build_test_firewall_config(); + + assert_firewall_event_for_different_network_is_ignored(options, move |other_network_id| { + GatewayEvent::FirewallConfigChanged(other_network_id, expected_firewall_config) + }) + .await; +} + +#[sqlx::test] +async fn 
test_different_location_firewall_disabled_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + assert_firewall_event_for_different_network_is_ignored(options, |other_network_id| { + GatewayEvent::FirewallDisabled(other_network_id) + }) + .await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/handshake.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/handshake.rs new file mode 100644 index 0000000000..bd83c1e5d6 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/handshake.rs @@ -0,0 +1,56 @@ +#[sqlx::test] +async fn test_sends_configuration_on_first_config_request( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let outbound = context.mock_gateway_mut().recv_outbound().await; + + match outbound.payload { + Some(core_response::Payload::Config(config)) => { + assert_eq!(config.name, context.network.name); + assert_eq!(config.port, context.network.port as u32); + assert_eq!(config.peers, Vec::new()); + } + _ => panic_unexpected("expected configuration response"), + } + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_does_not_send_configuration_before_gateway_requests_it( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let gateway_before = context.reload_gateway().await; + assert!(!gateway_before.is_connected()); + + context.mock_gateway_mut().expect_no_outbound().await; + + let gateway_after = context.reload_gateway().await; + assert!(!gateway_after.is_connected()); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_ignores_repeated_config_request(_: PgPoolOptions, options: PgConnectOptions) { + let mut context = HandlerTestContext::new(options).await; + + 
context.mock_gateway().send_config_request(); + let first_outbound = context.mock_gateway_mut().recv_outbound().await; + assert!(matches!( + first_outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + context.mock_gateway().send_config_request(); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/lifecycle.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/lifecycle.rs new file mode 100644 index 0000000000..8c86a03789 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/lifecycle.rs @@ -0,0 +1,59 @@ +#[sqlx::test] +async fn test_gateway_is_marked_connected_after_successful_config_handshake( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let gateway_before = context.reload_gateway().await; + assert!(!gateway_before.is_connected()); + + let gateway_after = context.complete_config_handshake().await; + assert!(gateway_after.is_connected()); + assert!(gateway_after.connected_at.is_some()); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_disconnected_when_stream_closes( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let connected_gateway = context.complete_config_handshake().await; + assert!(connected_gateway.is_connected()); + + let pool = context.pool.clone(); + let gateway_id = context.gateway.id; + let mock_gateway = context.finish().await; + let disconnected_gateway = reload_gateway(&pool, gateway_id).await; + assert!(!disconnected_gateway.is_connected()); + assert!(disconnected_gateway.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_gateway_is_marked_disconnected_when_stream_errors( + _: 
PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + context + .mock_gateway() + .send_stream_error(Status::internal("mock gateway stream failure")); + + let pool = context.pool.clone(); + let gateway_id = context.gateway.id; + let mock_gateway = context.finish_after_error().await; + let disconnected_gateway = reload_gateway(&pool, gateway_id).await; + assert!(!disconnected_gateway.is_connected()); + assert!(disconnected_gateway.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs new file mode 100644 index 0000000000..30b6da8895 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/mfa.rs @@ -0,0 +1,120 @@ +#[sqlx::test] +async fn test_matching_location_mfa_session_authorized_produces_peer_create( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let expected_keepalive_interval = expected_keepalive_interval(&context); + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_mfa_device_for_current_network( + &context, + "mfa-authorized-device", + "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", + "10.10.0.40", + Some("mfa-authorized-preshared-key"), + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast MFA session authorized event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Create, + "4v9K9Q4HEdmlX0Mb4uxDLPq3nKjvU8fNnJ9fKjzh4ko=", + &["10.10.0.40"], 
+ Some("mfa-authorized-preshared-key"), + Some(expected_keepalive_interval), + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_mfa_session_authorized_with_mismatched_network_id_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let mut other_network = context.create_other_network().await; + enable_internal_mfa_for_network(&context.pool, &mut other_network).await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + let (device, network_device) = create_authorized_mfa_device_for_network( + &context, + other_network.id, + "mfa-mismatched-network-device", + "Z2UuIvYJvU5fTOp8i3tHfLm4xZ0R8ExY6E3S3l+rqT8=", + "10.11.0.40", + Some("mfa-mismatched-network-preshared-key"), + ) + .await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::MfaSessionAuthorized( + context.network.id, + device, + network_device, + )), + "failed to broadcast mismatched MFA session authorized event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_mfa_session_disconnected_produces_peer_delete( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + enable_internal_mfa_for_network(&context.pool, &mut context.network).await; + + let _ = context.complete_config_handshake().await; + let (device, _) = create_authorized_mfa_device_for_current_network( + &context, + "mfa-disconnected-device", + "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", + "10.10.0.41", + Some("mfa-disconnected-preshared-key"), + ) + .await; + + assert_send_ok!( + context + .events_tx() + 
.send(GatewayEvent::MfaSessionDisconnected( + context.network.id, + device, + )), + "failed to broadcast MFA session disconnected event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_peer_update( + outbound, + UpdateType::Delete, + "2+n8hQ1yA2sPp1z2i6m8lP4VtY7M8W6hYqS3n4uL7qg=", + &[], + None, + None, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs new file mode 100644 index 0000000000..e7aeddcdbe --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/network_events.rs @@ -0,0 +1,236 @@ +#[sqlx::test] +async fn test_matching_location_network_deleted_event_produces_delete_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + context.network.id, + context.network.name.clone(), + )), + "failed to broadcast gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_delete_update(outbound, &context.network.name); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_network_modified_event_produces_modify_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + let mut modified_network = context + .network + .clone() + .set_address([ + "10.20.0.1/24" + .parse() + .expect("failed to parse modified network address"), + ]) + .expect("failed to set modified network address"); + modified_network.name = format!("{}-modified", 
context.network.name); + modified_network.port = 51821; + modified_network.mtu = 1380; + modified_network.fwmark = 42; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkModified( + context.network.id, + modified_network, + Vec::new(), + None, + )), + "failed to broadcast modified gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_modify_update( + outbound, + &format!("{}-modified", context.network.name), + "10.20.0.1/24", + 51821, + 1380, + 42, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_matching_location_network_created_event_produces_create_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + let _ = context.complete_config_handshake().await; + + let mut created_network = context + .network + .clone() + .set_address([ + "10.40.0.1/24" + .parse() + .expect("failed to parse created network address"), + ]) + .expect("failed to set created network address"); + created_network.name = format!("{}-created", context.network.name); + created_network.port = 51841; + created_network.mtu = 1410; + created_network.fwmark = 17; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkCreated( + context.network.id, + created_network, + )), + "failed to broadcast created gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_create_update( + outbound, + &format!("{}-created", context.network.name), + "10.40.0.1/24", + 51841, + 1410, + 17, + ); + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_only_matching_handler_receives_network_modified_update( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let (events_tx, _) = tokio::sync::broadcast::channel(16); + 
let mut matching_context = + HandlerTestContext::new_with_events_tx(options.clone(), events_tx.clone()).await; + let mut unrelated_context = HandlerTestContext::new_with_events_tx(options, events_tx).await; + + assert_ne!(matching_context.network.id, unrelated_context.network.id); + + let _ = matching_context.complete_config_handshake().await; + let _ = unrelated_context.complete_config_handshake().await; + + let mut modified_network = matching_context + .network + .clone() + .set_address([ + "10.30.0.1/24" + .parse() + .expect("failed to parse modified network address"), + ]) + .expect("failed to set modified network address"); + modified_network.name = format!("{}-modified", matching_context.network.name); + modified_network.port = 51831; + modified_network.mtu = 1400; + modified_network.fwmark = 7; + + assert_send_ok!( + matching_context + .events_tx() + .send(GatewayEvent::NetworkModified( + matching_context.network.id, + modified_network, + Vec::new(), + None, + )), + "failed to broadcast modified gateway event" + ); + + let outbound = matching_context.mock_gateway_mut().recv_outbound().await; + assert_network_modify_update( + outbound, + &format!("{}-modified", matching_context.network.name), + "10.30.0.1/24", + 51831, + 1400, + 7, + ); + matching_context.mock_gateway_mut().expect_no_outbound().await; + unrelated_context.mock_gateway_mut().expect_no_outbound().await; + + matching_context + .finish() + .await + .expect_server_finished() + .await; + unrelated_context + .finish() + .await + .expect_server_finished() + .await; +} + +#[sqlx::test] +async fn test_different_location_network_created_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkCreated( 
+ other_network.id, + other_network, + )), + "failed to broadcast unrelated created gateway event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_different_location_network_deleted_event_is_ignored( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + other_network.id, + other_network.name.clone(), + )), + "failed to broadcast unrelated gateway event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + assert_send_ok!( + context.events_tx().send(GatewayEvent::NetworkDeleted( + context.network.id, + context.network.name.clone(), + )), + "failed to broadcast owned gateway event" + ); + + let outbound = context.mock_gateway_mut().recv_outbound().await; + assert_network_delete_update(outbound, &context.network.name); + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/stats.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/stats.rs new file mode 100644 index 0000000000..285fb24fcb --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/stats.rs @@ -0,0 +1,58 @@ +#[sqlx::test] +async fn test_ignores_peer_stats_before_config_handshake( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context + .mock_gateway() + .send_peer_stats(build_peer_stats("203.0.113.10:51820")); + + context.expect_no_peer_stats().await; + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + 
+#[sqlx::test] +async fn test_forwards_valid_peer_stats_after_config(_: PgPoolOptions, options: PgConnectOptions) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let _ = context.mock_gateway_mut().recv_outbound().await; + context + .mock_gateway() + .send_peer_stats(build_peer_stats("203.0.113.10:51820")); + + let forwarded = context.recv_peer_stats().await; + assert_eq!(forwarded.location_id, context.network.id); + assert_eq!(forwarded.gateway_id, context.gateway.id); + assert_eq!(forwarded.device_pubkey, "peer-public-key"); + assert_eq!(forwarded.endpoint.to_string(), "203.0.113.10:51820"); + assert_eq!(forwarded.upload, 123); + assert_eq!(forwarded.download, 456); + + context.finish().await.expect_server_finished().await; +} + +#[sqlx::test] +async fn test_drops_malformed_or_missing_endpoint_peer_stats( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = HandlerTestContext::new(options).await; + + context.mock_gateway().send_config_request(); + let _ = context.mock_gateway_mut().recv_outbound().await; + + context.mock_gateway().send_peer_stats(build_peer_stats("")); + context.expect_no_peer_stats().await; + + context + .mock_gateway() + .send_peer_stats(build_peer_stats("not-a-socket-address")); + context.expect_no_peer_stats().await; + + context.finish().await.expect_server_finished().await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs new file mode 100644 index 0000000000..dab5e580d6 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/handler/support.rs @@ -0,0 +1,474 @@ +use std::net::IpAddr; + +use defguard_common::db::{ + Id, + models::{ + device::{Device, DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice}, + user::User, + vpn_client_session::VpnClientSession, + wireguard::{LocationMfaMode, 
WireguardNetwork}, + }, +}; +use defguard_core::grpc::GatewayEvent; +use defguard_proto::enterprise::firewall::{ + FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpVersion, Port, Protocol, + SnatBinding, ip_address::Address, port::Port as PortInner, +}; +use defguard_proto::gateway::{ + CoreResponse, Update, UpdateType, core_response, + update::{self}, +}; +use sqlx::postgres::PgConnectOptions; + +use crate::tests::common::HandlerTestContext; + +macro_rules! assert_send_ok { + ($result:expr, $message:literal) => { + match $result { + Ok(value) => value, + Err(_) => panic!($message), + } + }; +} + +pub(crate) use assert_send_ok; + +pub(crate) fn panic_unexpected(message: &str) -> ! { + panic!("{message}") +} + +pub(crate) async fn create_device_info_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, +) -> DeviceInfo { + create_device_info_for_network( + context, + context.network.id, + device_name, + device_pubkey, + device_ip, + ) + .await +} + +pub(crate) async fn create_authorized_mfa_device_for_current_network( + context: &HandlerTestContext, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> (Device, DeviceNetworkInfo) { + create_authorized_mfa_device_for_network( + context, + context.network.id, + device_name, + device_pubkey, + device_ip, + preshared_key, + ) + .await +} + +pub(crate) async fn create_authorized_mfa_device_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + preshared_key: Option<&str>, +) -> (Device, DeviceNetworkInfo) { + let Some(preshared_key) = preshared_key else { + panic!("authorized MFA test device requires a preshared key") + }; + + let device = + create_device_for_network(context, network_id, device_name, device_pubkey, device_ip).await; + let network_device = WireguardNetworkDevice::find(&context.pool, device.id, network_id) + .await + 
.expect("failed to load MFA device network info") + .expect("expected MFA device network info"); + + let network = WireguardNetwork::find_by_id(&context.pool, network_id) + .await + .expect("failed to load MFA test network") + .expect("expected MFA test network"); + + let mut session = VpnClientSession::new(network_id, device.user_id, device.id, None, None); + session.preshared_key = Some(preshared_key.to_owned()); + session + .save(&context.pool) + .await + .expect("failed to persist MFA device session"); + + let device_network_info = network_device + .to_device_network_info_runtime(&context.pool, &network) + .await + .expect("failed to build MFA device network info"); + + assert!(device_network_info.is_authorized); + assert_eq!( + device_network_info.preshared_key.as_deref(), + Some(preshared_key) + ); + + (device, device_network_info) +} + +pub(crate) async fn create_device_info_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, +) -> DeviceInfo { + let device = + create_device_for_network(context, network_id, device_name, device_pubkey, device_ip).await; + + DeviceInfo::from_device(&context.pool, device) + .await + .expect("failed to load device info") +} + +pub(crate) async fn create_device_for_network( + context: &HandlerTestContext, + network_id: Id, + device_name: &str, + device_pubkey: &str, + device_ip: &str, +) -> Device { + let username = format!("{device_name}-user"); + let email = format!("{device_name}@example.com"); + let user = User::new( + username, + Some("pass123"), + "Peer".to_string(), + "Test".to_string(), + email, + None, + ) + .save(&context.pool) + .await + .expect("failed to create test user"); + let device = Device::new( + device_name.to_string(), + device_pubkey.to_string(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&context.pool) + .await + .expect("failed to create test device"); + + WireguardNetworkDevice::new(network_id, device.id, 
vec![parse_test_ip(device_ip)]) + .insert(&context.pool) + .await + .expect("failed to attach device to network"); + + device +} + +pub(crate) async fn enable_internal_mfa_for_network( + pool: &sqlx::PgPool, + network: &mut WireguardNetwork, +) { + network.location_mfa_mode = LocationMfaMode::Internal; + network + .save(pool) + .await + .expect("failed to enable MFA for test network"); + assert!(network.mfa_enabled()); +} + +pub(crate) async fn assert_device_event_is_ignored_before_config_handshake( + options: PgConnectOptions, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + build_event: fn(DeviceInfo) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + assert_eq!(context.events_tx().receiver_count(), 0); + + let _broadcast_guard = context.events_tx().subscribe(); + let device_info = + create_device_info_for_current_network(&context, device_name, device_pubkey, device_ip) + .await; + + assert_send_ok!( + context.events_tx().send(build_event(device_info)), + "failed to broadcast ignored device event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +pub(crate) async fn assert_device_event_for_different_network_is_ignored( + options: PgConnectOptions, + device_name: &str, + device_pubkey: &str, + device_ip: &str, + build_event: fn(DeviceInfo) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + let device_info = create_device_info_for_network( + &context, + other_network.id, + device_name, + device_pubkey, + device_ip, + ) + .await; + + assert_send_ok!( + context.events_tx().send(build_event(device_info)), + "failed to broadcast ignored device event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + 
context.finish().await.expect_server_finished().await; +} + +pub(crate) async fn assert_firewall_event_for_different_network_is_ignored( + options: PgConnectOptions, + build_event: impl FnOnce(Id) -> GatewayEvent, +) { + let mut context = HandlerTestContext::new(options).await; + let other_network = context.create_other_network().await; + assert_ne!(other_network.id, context.network.id); + + let _ = context.complete_config_handshake().await; + + assert_send_ok!( + context.events_tx().send(build_event(other_network.id)), + "failed to broadcast ignored firewall event" + ); + + context.mock_gateway_mut().expect_no_outbound().await; + + context.finish().await.expect_server_finished().await; +} + +pub(crate) fn expected_keepalive_interval(context: &HandlerTestContext) -> u32 { + u32::try_from(context.network.keepalive_interval) + .expect("expected non-negative network keepalive interval") +} + +pub(crate) fn parse_test_ip(ip: &str) -> IpAddr { + ip.parse().expect("failed to parse test peer IP address") +} + +pub(crate) fn assert_peer_update( + outbound: CoreResponse, + expected_update_type: UpdateType, + expected_pubkey: &str, + expected_allowed_ips: &[&str], + expected_preshared_key: Option<&str>, + expected_keepalive_interval: Option<u32>, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Peer(peer)), + })) => { + assert_eq!(update_type, expected_update_type as i32); + assert_eq!(peer.pubkey, expected_pubkey); + assert_eq!( + peer.allowed_ips, + expected_allowed_ips + .iter() + .map(|allowed_ip| allowed_ip.to_string()) + .collect::<Vec<String>>() + ); + assert_eq!(peer.preshared_key.as_deref(), expected_preshared_key); + assert_eq!(peer.keepalive_interval, expected_keepalive_interval); + } + _ => panic_unexpected("expected peer update"), + } +} + +pub(crate) fn assert_network_delete_update(outbound: CoreResponse, expected_network_name: &str) { + match outbound.payload { + 
Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Delete as i32); + assert_eq!(network.name, expected_network_name); + } + _ => panic_unexpected("expected network delete update"), + } +} + +pub(crate) fn assert_network_create_update( + outbound: CoreResponse, + expected_network_name: &str, + expected_address: &str, + expected_port: u32, + expected_mtu: u32, + expected_fwmark: u32, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Create as i32); + assert_eq!(network.name, expected_network_name); + assert_eq!(network.addresses, vec![expected_address.to_string()]); + assert_eq!(network.port, expected_port); + assert_eq!(network.peers, Vec::new()); + assert_eq!(network.firewall_config, None); + assert_eq!(network.mtu, expected_mtu); + assert_eq!(network.fwmark, expected_fwmark); + } + _ => panic_unexpected("expected network create update"), + } +} + +pub(crate) fn assert_network_modify_update( + outbound: CoreResponse, + expected_network_name: &str, + expected_address: &str, + expected_port: u32, + expected_mtu: u32, + expected_fwmark: u32, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::Network(network)), + })) => { + assert_eq!(update_type, UpdateType::Modify as i32); + assert_eq!(network.name, expected_network_name); + assert_eq!(network.addresses, vec![expected_address.to_string()]); + assert_eq!(network.port, expected_port); + assert_eq!(network.peers, Vec::new()); + assert_eq!(network.firewall_config, None); + assert_eq!(network.mtu, expected_mtu); + assert_eq!(network.fwmark, expected_fwmark); + } + _ => panic_unexpected("expected network modify update"), + } +} + +pub(crate) fn build_test_firewall_config() -> FirewallConfig { + 
FirewallConfig { + default_policy: i32::from(FirewallPolicy::Allow), + rules: vec![FirewallRule { + id: 101, + source_addrs: vec![IpAddress { + address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), + }], + destination_addrs: vec![IpAddress { + address: Some(Address::Ip("198.51.100.20".to_string())), + }], + destination_ports: vec![Port { + port: Some(PortInner::SinglePort(443)), + }], + protocols: vec![i32::from(Protocol::Tcp)], + verdict: i32::from(FirewallPolicy::Deny), + comment: Some("block test https destination".to_string()), + ip_version: i32::from(IpVersion::Ipv4), + }], + snat_bindings: vec![SnatBinding { + id: 202, + source_addrs: vec![IpAddress { + address: Some(Address::IpSubnet("10.10.0.0/24".to_string())), + }], + public_ip: "203.0.113.44".to_string(), + comment: Some("test snat binding".to_string()), + }], + } +} + +pub(crate) fn assert_firewall_modify_update( + outbound: CoreResponse, + expected_firewall_config: &FirewallConfig, +) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::FirewallConfig(firewall_config)), + })) => { + assert_eq!(update_type, UpdateType::Modify as i32); + assert_eq!( + firewall_config.default_policy, + expected_firewall_config.default_policy + ); + assert_eq!( + firewall_config.rules.len(), + expected_firewall_config.rules.len() + ); + assert_eq!( + firewall_config.snat_bindings.len(), + expected_firewall_config.snat_bindings.len() + ); + + let firewall_rule = firewall_config + .rules + .first() + .expect("expected firewall rule in update payload"); + let expected_firewall_rule = expected_firewall_config + .rules + .first() + .expect("expected firewall rule in test config"); + assert_eq!(firewall_rule.id, expected_firewall_rule.id); + assert_eq!( + firewall_rule.source_addrs, + expected_firewall_rule.source_addrs + ); + assert_eq!( + firewall_rule.destination_addrs, + expected_firewall_rule.destination_addrs + ); + assert_eq!( + 
firewall_rule.destination_ports, + expected_firewall_rule.destination_ports + ); + assert_eq!(firewall_rule.protocols, expected_firewall_rule.protocols); + assert_eq!(firewall_rule.verdict, expected_firewall_rule.verdict); + assert_eq!(firewall_rule.comment, expected_firewall_rule.comment); + assert_eq!(firewall_rule.ip_version, expected_firewall_rule.ip_version); + + let snat_binding = firewall_config + .snat_bindings + .first() + .expect("expected SNAT binding in update payload"); + let expected_snat_binding = expected_firewall_config + .snat_bindings + .first() + .expect("expected SNAT binding in test config"); + assert_eq!(snat_binding.id, expected_snat_binding.id); + assert_eq!( + snat_binding.source_addrs, + expected_snat_binding.source_addrs + ); + assert_eq!(snat_binding.public_ip, expected_snat_binding.public_ip); + assert_eq!(snat_binding.comment, expected_snat_binding.comment); + } + _ => panic_unexpected("expected firewall config update"), + } +} + +pub(crate) fn assert_firewall_disable_update(outbound: CoreResponse) { + match outbound.payload { + Some(core_response::Payload::Update(Update { + update_type, + update: Some(update::Update::DisableFirewall(())), + })) => { + assert_eq!(update_type, UpdateType::Delete as i32); + } + _ => panic_unexpected("expected firewall disable update"), + } +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/manager.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/manager.rs new file mode 100644 index 0000000000..89481edfa0 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/manager.rs @@ -0,0 +1,425 @@ +use std::time::Duration; + +use defguard_common::db::{Id, models::gateway::Gateway}; +use defguard_proto::gateway::core_response; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tonic::Status; + +use crate::tests::common::{ + ManagerTestContext, MockGatewayHarness, build_gateway_with_enabled, create_gateway, + create_gateway_with_enabled, 
create_network, reload_gateway, unique_mock_gateway_socket_path, + wait_for_gateway_connection_state, +}; + +const FAST_RETRY_DELAY: Duration = Duration::from_millis(20); + +async fn complete_manager_handshake( + context: &ManagerTestContext, + gateway: &Gateway, + mock_gateway: &mut MockGatewayHarness, +) { + mock_gateway.wait_connected().await; + mock_gateway.send_config_request(); + let outbound = mock_gateway.recv_outbound().await; + assert!(matches!( + outbound.payload, + Some(core_response::Payload::Config(_)) + )); + + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, true).await; + assert!(gateway_after.is_connected()); +} + +#[sqlx::test] +async fn test_starts_existing_enabled_gateway_on_startup( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_starts_gateway_after_enabled_update_notification( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway_with_enabled(&context.pool, network.id, false).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + 0, + "disabled gateway handler should not start during manager startup" + ); + + let gateway_before = reload_gateway(&context.pool, gateway.id).await; + assert!(!gateway_before.is_connected()); + + gateway.enabled = 
true; + gateway + .save(&context.pool) + .await + .expect("failed to enable test gateway"); + + context + .wait_for_gateway_notification_count(gateway.id, 1) + .await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, 1) + .await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_noop_gateway_update_does_not_restart_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + gateway = reload_gateway(&context.pool, gateway.id).await; + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_notification_count = context.gateway_notification_count(gateway.id); + let initial_connection_count = mock_gateway.connection_count(); + + gateway.modified_by = "manager-noop-update".to_string(); + gateway + .save(&context.pool) + .await + .expect("failed to save no-op gateway update"); + + context + .wait_for_gateway_notification_count(gateway.id, initial_notification_count + 1) + .await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "no-op gateway update should not restart the handler" + ); + assert_eq!( + mock_gateway.connection_count(), + initial_connection_count, + "no-op gateway update should not reconnect the handler" + ); + + let gateway_after = reload_gateway(&context.pool, gateway.id).await; + assert!(gateway_after.is_connected()); + + context.finish().await; +} + +#[sqlx::test] +async fn test_gateway_address_change_restarts_handler(_: PgPoolOptions, options: PgConnectOptions) { + let mut 
context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway(&context.pool, network.id).await; + let mut original_mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &original_mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut original_mock_gateway).await; + + gateway = reload_gateway(&context.pool, gateway.id).await; + + let replacement_mock_url = { + gateway.address = "127.0.0.2".to_string(); + gateway.modified_by = "manager-address-update".to_string(); + gateway.url() + }; + let mut replacement_mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_url(replacement_mock_url, &replacement_mock_gateway); + + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_notification_count = context.gateway_notification_count(gateway.id); + + gateway + .save(&context.pool) + .await + .expect("failed to save gateway address update"); + + context + .wait_for_gateway_notification_count(gateway.id, initial_notification_count + 1) + .await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, initial_spawn_attempts + 1) + .await; + replacement_mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut replacement_mock_gateway).await; + + let gateway_after = reload_gateway(&context.pool, gateway.id).await; + assert_eq!(gateway_after.address, "127.0.0.2"); + assert!(gateway_after.is_connected()); + original_mock_gateway.expect_server_finished().await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_enabled_gateway_update_to_disabled_stops_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let mut gateway = create_gateway(&context.pool, network.id).await; + let mut 
mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + gateway = reload_gateway(&context.pool, gateway.id).await; + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_notification_count = context.gateway_notification_count(gateway.id); + let initial_connection_count = mock_gateway.connection_count(); + + gateway.enabled = false; + gateway.modified_by = "manager-disable-update".to_string(); + gateway + .save(&context.pool) + .await + .expect("failed to save gateway disable update"); + + context + .wait_for_gateway_notification_count(gateway.id, initial_notification_count + 1) + .await; + let gateway_after = wait_for_gateway_connection_state(&context.pool, gateway.id, false).await; + assert!(!gateway_after.is_connected()); + assert!(gateway_after.disconnected_at.is_some()); + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "disabling the gateway should stop the existing handler without spawning a replacement" + ); + assert_eq!( + mock_gateway.connection_count(), + initial_connection_count, + "disabling the gateway should not create a new gateway connection" + ); + mock_gateway.expect_server_finished().await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_insert_notification_starts_handler_for_enabled_gateway( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let gateway = build_gateway_with_enabled(network.id, true); + let gateway_url = gateway.url(); + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_url(gateway_url, &mock_gateway); + + context.start().await; + + let gateway = gateway + .save(&context.pool) + .await + .expect("failed to insert enabled test 
gateway"); + + context + .wait_for_gateway_notification_count(gateway.id, 1) + .await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, 1) + .await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_delete_notification_purges_and_aborts_gateway_connection( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + Gateway::delete_by_id(&context.pool, gateway.id) + .await + .expect("failed to delete test gateway"); + + mock_gateway.wait_purged().await; + mock_gateway.expect_server_finished().await; + + context.finish().await; +} + +#[sqlx::test] +async fn test_retries_failed_connection_without_notification_or_duplicate_handler( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + context.set_retry_delay(FAST_RETRY_DELAY); + + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let socket_path = unique_mock_gateway_socket_path(); + context.register_gateway_socket_path(gateway.url(), socket_path.clone()); + + context.start().await; + context + .wait_for_handler_spawn_attempt_count(gateway.id, 1) + .await; + context + .wait_for_handler_connection_attempt_count(gateway.id, 2) + .await; + + assert_eq!( + context.gateway_notification_count(gateway.id), + 0, + "manager reconnect retries should not depend on gateway table notifications" + ); + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + 1, + "manager reconnect retries should reuse the 
existing handler task" + ); + + let mut mock_gateway = MockGatewayHarness::start_at(socket_path).await; + mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + 1, + "reconnect success should not create a second concurrent handler" + ); + + context.finish().await; +} + +#[sqlx::test] +async fn test_retries_after_stream_close_with_single_handler_supervisor( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + context.set_retry_delay(FAST_RETRY_DELAY); + + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_connection_attempts = context.handler_connection_attempt_count(gateway.id); + let reconnect_socket_path = mock_gateway.socket_path(); + + mock_gateway.close_stream(); + + let gateway_after_disconnect = + wait_for_gateway_connection_state(&context.pool, gateway.id, false).await; + assert!(!gateway_after_disconnect.is_connected()); + assert!(gateway_after_disconnect.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; + + context + .wait_for_handler_connection_attempt_count(gateway.id, initial_connection_attempts + 1) + .await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "stream closure retries should keep a single handler supervisor" + ); + + let mut replacement_mock_gateway = MockGatewayHarness::start_at(reconnect_socket_path).await; + replacement_mock_gateway.wait_for_connection_count(1).await; + 
complete_manager_handshake(&context, &gateway, &mut replacement_mock_gateway).await; + + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "successful reconnect after stream closure should not create a duplicate handler" + ); + + context.finish().await; +} + +#[sqlx::test] +async fn test_retries_after_stream_error_with_single_handler_supervisor( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let mut context = ManagerTestContext::new(options).await; + context.set_retry_delay(FAST_RETRY_DELAY); + + let network = create_network(&context.pool).await; + let gateway = create_gateway(&context.pool, network.id).await; + let mut mock_gateway = MockGatewayHarness::start().await; + context.register_gateway_mock(&gateway, &mock_gateway); + + context.start().await; + complete_manager_handshake(&context, &gateway, &mut mock_gateway).await; + + let initial_spawn_attempts = context.handler_spawn_attempt_count(gateway.id); + let initial_connection_attempts = context.handler_connection_attempt_count(gateway.id); + let reconnect_socket_path = mock_gateway.socket_path(); + + mock_gateway.send_stream_error(Status::internal("mock gateway stream failure")); + + let gateway_after_disconnect = + wait_for_gateway_connection_state(&context.pool, gateway.id, false).await; + assert!(!gateway_after_disconnect.is_connected()); + assert!(gateway_after_disconnect.disconnected_at.is_some()); + + mock_gateway.expect_server_finished().await; + + context + .wait_for_handler_connection_attempt_count(gateway.id, initial_connection_attempts + 1) + .await; + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "stream failure retries should keep a single handler supervisor" + ); + + let mut replacement_mock_gateway = MockGatewayHarness::start_at(reconnect_socket_path).await; + replacement_mock_gateway.wait_for_connection_count(1).await; + complete_manager_handshake(&context, &gateway, &mut 
replacement_mock_gateway).await; + + assert_eq!( + context.handler_spawn_attempt_count(gateway.id), + initial_spawn_attempts, + "successful reconnect after stream failure should not create a duplicate handler" + ); + + context.finish().await; +} diff --git a/crates/defguard_gateway_manager/src/tests/gateway_manager/mod.rs b/crates/defguard_gateway_manager/src/tests/gateway_manager/mod.rs new file mode 100644 index 0000000000..caf495ef2e --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/gateway_manager/mod.rs @@ -0,0 +1,2 @@ +mod handler; +mod manager; diff --git a/crates/defguard_gateway_manager/src/tests/mod.rs b/crates/defguard_gateway_manager/src/tests/mod.rs new file mode 100644 index 0000000000..ef5f24f2c4 --- /dev/null +++ b/crates/defguard_gateway_manager/src/tests/mod.rs @@ -0,0 +1,2 @@ +mod common; +mod gateway_manager; diff --git a/flake.lock b/flake.lock index 1b6bbdcda0..4db212dbcd 100644 --- a/flake.lock +++ b/flake.lock @@ -32,11 +32,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1773821835, - "narHash": "sha256-TJ3lSQtW0E2JrznGVm8hOQGVpXjJyXY2guAxku2O9A4=", + "lastModified": 1774106199, + "narHash": "sha256-US5Tda2sKmjrg2lNHQL3jRQ6p96cgfWh3J1QBliQ8Ws=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "b40629efe5d6ec48dd1efba650c797ddbd39ace0", + "rev": "6c9a78c09ff4d6c21d0319114873508a6ec01655", "type": "github" }, "original": { @@ -74,11 +74,11 @@ ] }, "locked": { - "lastModified": 1773889863, - "narHash": "sha256-tSsmZOHBgq4qfu5MNCAEsKZL1cI4avNLw2oUTXWeb74=", + "lastModified": 1774235565, + "narHash": "sha256-D8OOwvq3zDDCtIhMcNueb9tGSZaZUanKpWDleRgQ80U=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "dbfd51be2692cb7022e301d14c139accb4ee63f0", + "rev": "dc00324a2438762582b49954373112b8eab29cab", "type": "github" }, "original": {