From b80daff3fe9c97553e1741a13f058085f6141eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 17 Mar 2026 11:12:27 +0100 Subject: [PATCH 01/10] WireguardNetwork address getter/setter --- ...3e4a83ac3623590674e47bb1a1dbf9c25d77f.json | 152 ------------------ .../defguard_common/src/db/models/device.rs | 106 ++++-------- .../src/db/models/wireguard.rs | 126 +++++++++++---- .../src/enterprise/db/models/acl.rs | 25 +-- .../src/enterprise/db/models/acl/tests.rs | 5 +- .../src/enterprise/firewall/mod.rs | 4 +- .../firewall/tests/all_locations.rs | 60 +++---- .../enterprise/firewall/tests/destination.rs | 48 ++---- .../firewall/tests/disabled_rules.rs | 30 ++-- .../firewall/tests/expired_rules.rs | 28 ++-- .../src/enterprise/firewall/tests/gh1868.rs | 72 ++++----- .../src/enterprise/firewall/tests/mod.rs | 104 +++++------- .../firewall/tests/unapplied_rules.rs | 30 ++-- .../src/handlers/network_devices.rs | 33 ++-- .../defguard_core/src/handlers/wireguard.rs | 3 +- crates/defguard_core/src/lib.rs | 4 +- .../src/location_management/allowed_peers.rs | 30 ++-- .../src/location_management/mod.rs | 2 +- .../src/location_management/tests.rs | 16 +- crates/defguard_core/src/wg_config.rs | 6 +- .../api/wireguard_network_devices.rs | 16 +- .../api/wireguard_network_import.rs | 4 +- .../defguard_gateway_manager/src/handler.rs | 4 +- .../src/servers/enrollment.rs | 20 +-- .../tests/common/mod.rs | 2 +- crates/defguard_setup/src/auto_adoption.rs | 2 +- .../src/handlers/auto_wizard.rs | 2 +- .../tests/auto_adoption_wizard.rs | 2 +- crates/defguard_setup/tests/wizard_state.rs | 2 +- crates/defguard_static_ip/src/lib.rs | 40 ++--- 30 files changed, 368 insertions(+), 610 deletions(-) delete mode 100644 .sqlx/query-c58c7b4dc7463a93895b17d591e3e4a83ac3623590674e47bb1a1dbf9c25d77f.json diff --git a/.sqlx/query-c58c7b4dc7463a93895b17d591e3e4a83ac3623590674e47bb1a1dbf9c25d77f.json b/.sqlx/query-c58c7b4dc7463a93895b17d591e3e4a83ac3623590674e47bb1a1dbf9c25d77f.json deleted file mode 100644 index a38c619ecf..0000000000 --- a/.sqlx/query-c58c7b4dc7463a93895b17d591e3e4a83ac3623590674e47bb1a1dbf9c25d77f.json +++ /dev/null @@ -1,152 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, allowed_ips, allow_all_groups, connected_at, keepalive_interval, peer_disconnect_threshold, acl_enabled, acl_default_allow, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM wireguard_network WHERE id = $1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "name", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "address", - "type_info": "InetArray" - }, - { - "ordinal": 3, - "name": "port", - "type_info": "Int4" - }, - { - "ordinal": 4, - "name": "pubkey", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "prvkey", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "endpoint", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "dns", - "type_info": "Text" - }, - { - "ordinal": 8, - "name": "mtu", - "type_info": "Int4" - }, - { - "ordinal": 9, - "name": "fwmark", - "type_info": "Int8" - }, - { - "ordinal": 10, - "name": "allowed_ips", - "type_info": "InetArray" - }, - { - "ordinal": 11, - "name": "allow_all_groups", - "type_info": "Bool" - }, - { - "ordinal": 12, - "name": "connected_at", - "type_info": "Timestamp" - }, - { - "ordinal": 13, - "name": "keepalive_interval", - "type_info": "Int4" - }, - { - "ordinal": 14, - "name": "peer_disconnect_threshold", - "type_info": "Int4" - }, - { - "ordinal": 15, - "name": "acl_enabled", - "type_info": "Bool" - }, - { - "ordinal": 16, - "name": "acl_default_allow", - "type_info": "Bool" - }, - { - "ordinal": 17, - "name": "location_mfa_mode: LocationMfaMode", - "type_info": { - "Custom": { - "name": "location_mfa_mode", - "kind": { - "Enum": [ - "disabled", - "internal", - "external" - ] - } - } - } - }, - { - "ordinal": 18, - "name": "service_location_mode: ServiceLocationMode", - "type_info": { - "Custom": { - "name": "service_location_mode", - "kind": { - "Enum": [ - "disabled", - "prelogon", - "alwayson" - ] - } - } - } - } - ], - "parameters": { - "Left": [ - "Int8" - ] - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - true, - false, - false, - false, - false, - true, - false, - false, - false, - false, - false, - false - ] - }, - "hash": "c58c7b4dc7463a93895b17d591e3e4a83ac3623590674e47bb1a1dbf9c25d77f" -} diff --git a/crates/defguard_common/src/db/models/device.rs b/crates/defguard_common/src/db/models/device.rs index 42956fd515..b92382a369 100644 --- a/crates/defguard_common/src/db/models/device.rs +++ b/crates/defguard_common/src/db/models/device.rs @@ -201,7 +201,7 @@ pub struct UserDeviceNetworkInfo { } impl UserDevice { - pub async fn from_device(pool: &PgPool, device: Device) -> Result, SqlxError> { + pub async fn from_device(pool: &PgPool, device: Device) -> sqlx::Result> { // fetch device config and connection info for all allowed networks let result = query!( "SELECT n.id network_id, n.name network_name, n.endpoint gateway_endpoint, \ @@ -319,7 +319,7 @@ impl WireguardNetworkDevice { .collect() } - pub async fn insert<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn insert<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -343,7 +343,7 @@ impl WireguardNetworkDevice { Ok(()) } - pub async fn update<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn update<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -364,7 +364,7 @@ impl WireguardNetworkDevice { Ok(()) } - pub async fn delete<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn delete<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -384,7 +384,7 @@ impl WireguardNetworkDevice { executor: E, device_id: Id, network_id: Id, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -406,7 +406,7 @@ impl WireguardNetworkDevice { /// Get a first network the device was added to. Useful for network devices to /// make sure they always pull only one network's config. - pub async fn find_first<'e, E>(executor: E, device_id: Id) -> Result, SqlxError> + pub async fn find_first<'e, E>(executor: E, device_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -450,7 +450,7 @@ impl WireguardNetworkDevice { }) } - pub async fn all_for_network<'e, E>(executor: E, network_id: Id) -> Result, SqlxError> + pub async fn all_for_network<'e, E>(executor: E, network_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -476,7 +476,7 @@ impl WireguardNetworkDevice { executor: E, network_id: Id, user_id: Id, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -497,29 +497,17 @@ impl WireguardNetworkDevice { Ok(res) } - pub async fn network<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn network<'e, E>(&self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { - query_as!( - WireguardNetwork, - "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ - allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ - peer_disconnect_threshold, acl_enabled, acl_default_allow, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ - service_location_mode \"service_location_mode: ServiceLocationMode\" \ - FROM wireguard_network WHERE id = $1", - self.wireguard_network_id - ) - .fetch_one(executor) - .await + WireguardNetwork::find_by_id(executor, self.wireguard_network_id) + .await? + .ok_or(sqlx::Error::RowNotFound) } /// Check if any device is assigned to a given network. - pub async fn has_devices_in_network<'e, E>( - executor: E, - network_id: Id, - ) -> Result + pub async fn has_devices_in_network<'e, E>(executor: E, network_id: Id) -> sqlx::Result where E: PgExecutor<'e>, { @@ -881,7 +869,7 @@ impl Device { let reserved = reserved_ips.unwrap_or_default(); // Iterate over all network addresses and assign new IP for the device in each of them - for address in &network.address { + for address in network.address() { debug!( "Assigning address to device {} in network {} {address}", self.name, network.name, @@ -892,7 +880,9 @@ impl Device { { debug!( "Skipping reassignment of already assigned valid IP {ip} for device {} in network {} with addresses {:?}", - self.name, network.name, network.address + self.name, + network.name, + network.address() ); ips.push(*ip); continue; @@ -972,31 +962,6 @@ impl Device { Ok(wireguard_network_device) } - /// Gets the first network of the network device - /// FIXME: Return only one network, not a Vec - pub async fn find_network_device_networks<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - WireguardNetwork, - "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ - allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ - peer_disconnect_threshold, acl_enabled, acl_default_allow, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ - service_location_mode \"service_location_mode: ServiceLocationMode\" \ - FROM wireguard_network WHERE id IN \ - (SELECT wireguard_network_id FROM wireguard_network_device \ - WHERE device_id = $1 ORDER BY id LIMIT 1)", - self.id - ) - .fetch_all(executor) - .await - } - pub fn validate_pubkey(pubkey: &str) -> Result<(), String> { if let Ok(key) = BASE64_STANDARD.decode(pubkey) { if key.len() == KEY_LENGTH { @@ -1093,7 +1058,7 @@ mod test { pubkey: String, network: &WireguardNetwork, ) -> Result<(Self, WireguardNetworkDevice), ModelError> { - if let Some(address) = network.address.first() { + if let Some(address) = network.address().first() { let net_ip = address.ip(); let net_network = address.network(); let net_broadcast = address.broadcast(); @@ -1217,11 +1182,11 @@ mod test { .unwrap(); let mut updated_network = network.clone(); - updated_network.address = vec![ + updated_network.set_address([ "10.0.0.0/16".parse::().unwrap(), "123.12.0.0/16".parse::().unwrap(), "123.123.0.0/16".parse::().unwrap(), - ]; + ]); updated_network.save(&mut *conn).await.unwrap(); let used_ips = updated_network @@ -1310,7 +1275,7 @@ mod test { .unwrap(); let mut updated_network = network.clone(); - updated_network.address = vec!["10.0.0.0/16".parse::().unwrap()]; + updated_network.set_address(["10.0.0.0/16".parse::().unwrap()]); updated_network.save(&mut *conn).await.unwrap(); let used_ips = updated_network @@ -1390,7 +1355,7 @@ mod test { .unwrap(); let mut updated_network = network.clone(); - updated_network.address = vec!["123.123.0.0/16".parse::().unwrap()]; + updated_network.set_address(["123.123.0.0/16".parse::().unwrap()]); updated_network.save(&mut *conn).await.unwrap(); let used_ips = updated_network @@ -1455,17 +1420,13 @@ mod test { .await .unwrap(); - let mut network = WireguardNetwork:: { - allow_all_groups: true, - ..Default::default() - }; + let mut network = WireguardNetwork::default(); + network.allow_all_groups = true; network.try_set_address("10.1.1.1/24").unwrap(); let network = network.save(&pool).await.unwrap(); - let mut network_2 = WireguardNetwork:: { - name: "testnetwork2".into(), - allow_all_groups: true, - ..Default::default() - }; + let mut network_2 = WireguardNetwork::default(); + network_2.name = "testnetwork2".into(); + network_2.allow_all_groups = true; network_2.try_set_address("10.1.2.1/24").unwrap(); let network2 = network_2.save(&pool).await.unwrap(); @@ -1564,14 +1525,11 @@ mod test { .await .unwrap(); - let network = WireguardNetwork:: { - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 4)), 29).unwrap()], - allow_all_groups: true, - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut network = WireguardNetwork::default(); + network + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 4)), 29).unwrap()]); + network.allow_all_groups = true; + let network = network.save(&pool).await.unwrap(); let mut conn = pool.begin().await.unwrap(); diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index b78bd38aaf..28d3ea6112 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -11,10 +11,7 @@ use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; use model_derive::Model; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; -use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, - query_scalar, -}; +use sqlx::{FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, query_scalar}; use thiserror::Error; use tracing::{debug, info}; use utoipa::ToSchema; @@ -109,7 +106,7 @@ pub struct WireguardNetwork { pub name: String, #[model(ref)] #[schema(value_type = Vec)] - pub address: Vec, + address: Vec, pub port: i32, // Should be u16 pub pubkey: String, #[serde(default, skip_serializing)] @@ -215,9 +212,9 @@ pub enum NetworkAddressError { impl WireguardNetwork { #[allow(clippy::too_many_arguments)] #[must_use] - pub fn new( + pub fn new( name: String, - address: Vec, + address: V, port: i32, endpoint: String, dns: Option, @@ -231,13 +228,16 @@ impl WireguardNetwork { acl_default_allow: bool, location_mfa_mode: LocationMfaMode, service_location_mode: ServiceLocationMode, - ) -> Self { + ) -> Self + where + V: Into>, + { let prvkey = StaticSecret::random_from_rng(OsRng); let pubkey = PublicKey::from(&prvkey); Self { id: NoId, name, - address, + address: address.into(), port, pubkey: BASE64_STANDARD.encode(pubkey.to_bytes()), prvkey: BASE64_STANDARD.encode(prvkey.to_bytes()), @@ -261,24 +261,53 @@ impl WireguardNetwork { pub fn try_set_address(&mut self, address: &str) -> Result<(), IpNetworkError> { let address = parse_address_list(address); if address.is_empty() { - return Err(IpNetworkError::InvalidAddr("invalid address".into())); + Err(IpNetworkError::InvalidAddr("invalid address".into())) + } else { + self.address = address; + Ok(()) } - self.address = address; + } +} - Ok(()) +impl WireguardNetwork { + /// Address list getter. + pub fn address(&self) -> &[IpNetwork] { + self.address.as_slice() + } + + /// Address list setter. + pub fn set_address(&mut self, address: V) + where + V: Into>, + { + self.address = address.into(); + } + + /// Validate addresses. + pub fn address_is_valid(&self) -> bool { + for addr in &self.address { + let ip = addr.ip(); + if ip == addr.network() || ip == addr.broadcast() { + return false; + } + } + + true } } impl WireguardNetwork { + /// Try to find `WireguardNetwork` with the given name. pub async fn find_by_name<'e, E>(executor: E, name: &str) -> sqlx::Result>> where E: PgExecutor<'e>, { let networks = query_as!( - WireguardNetwork, + Self, "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ - allowed_ips, allow_all_groups, connected_at, keepalive_interval, peer_disconnect_threshold, \ - acl_enabled, acl_default_allow, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ + allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ + peer_disconnect_threshold, acl_enabled, acl_default_allow, \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM wireguard_network WHERE name = $1", name @@ -293,6 +322,52 @@ impl WireguardNetwork { Ok(Some(networks)) } + /// Gets the first network of the network device. + /// FIXME: Return only one network, not a Vec. + pub async fn find_network_device_networks<'e, E>( + executor: E, + device_id: Id, + ) -> sqlx::Result> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ + allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ + peer_disconnect_threshold, acl_enabled, acl_default_allow, \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ + service_location_mode \"service_location_mode: ServiceLocationMode\" \ + FROM wireguard_network WHERE id IN \ + (SELECT wireguard_network_id FROM wireguard_network_device \ + WHERE device_id = $1 ORDER BY id LIMIT 1)", + device_id + ) + .fetch_all(executor) + .await + } + + /// Find all for a given rule `Id`. + pub async fn all_for_rule<'e, E>(executor: E, rule_id: Id) -> sqlx::Result> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT n.id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ + allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ + peer_disconnect_threshold, acl_enabled, acl_default_allow, \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ + service_location_mode \"service_location_mode: ServiceLocationMode\" \ + FROM aclrulenetwork r \ + JOIN wireguard_network n ON n.id = r.network_id \ + WHERE r.rule_id = $1", + rule_id, + ) + .fetch_all(executor) + .await + } + /// Check if given number of devices can fit in networks used by this location. /// Note: `device_count` should include network and broadcast addresses. pub fn validate_network_size(&self, device_count: usize) -> Result<(), WireguardNetworkError> { @@ -576,7 +651,7 @@ impl WireguardNetwork { from: &NaiveDateTime, aggregation: &DateTimeAggregation, device_type: DeviceType, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { // Retrieve currently connected devices from database let devices = query_as!( Device, @@ -961,7 +1036,7 @@ impl WireguardNetwork { &self, pool: &PgPool, from: &NaiveDateTime, - ) -> Result { + ) -> sqlx::Result { let total_activity = query_as!( WireguardNetworkActivityStats, "SELECT \ @@ -981,10 +1056,7 @@ impl WireguardNetwork { } /// Retrieves currently connected sessions stats - async fn current_activity( - &self, - pool: &PgPool, - ) -> Result { + async fn current_activity(&self, pool: &PgPool) -> sqlx::Result { let current_activity = query_as!( WireguardNetworkActivityStats, "SELECT \ @@ -1038,7 +1110,7 @@ impl WireguardNetwork { pool: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result { + ) -> sqlx::Result { let total_activity = self.total_activity(pool, from).await?; let current_activity = self.current_activity(pool).await?; let transfer_series = self.transfer_series(pool, from, aggregation).await?; @@ -1318,7 +1390,7 @@ impl WireguardNetwork { pub async fn get_active_vpn_sessions<'e, E: sqlx::PgExecutor<'e>>( &self, executor: E, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { query_as!( VpnClientSession, "SELECT id, location_id, user_id, device_id, \ @@ -1336,7 +1408,7 @@ impl WireguardNetwork { pub async fn all_used_ips_for_network( &self, transaction: &mut PgConnection, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { let all_devices = WireguardNetworkDevice::all_for_network(&mut *transaction, self.id).await?; let used_ips: HashSet = all_devices @@ -1479,7 +1551,7 @@ pub async fn networks_stats( pool: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, -) -> Result { +) -> sqlx::Result { // get all active users/devices within specified time window let total_activity = query_as!( WireguardNetworkActivityStats, @@ -1861,7 +1933,7 @@ mod test { let network = WireguardNetwork::new( "network".to_string(), - vec![IpNetwork::from_str("10.1.1.1/24").unwrap()], + [IpNetwork::from_str("10.1.1.1/24").unwrap()], 50051, String::new(), None, @@ -1993,7 +2065,7 @@ mod test { let network = WireguardNetwork::new( "network".to_string(), - vec![ + [ IpNetwork::from_str("10.1.1.1/24").unwrap(), IpNetwork::from_str("fc00::1/112").unwrap(), ], diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index f1268aa0f3..472029f5ac 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -8,12 +8,7 @@ use std::{ use chrono::NaiveDateTime; use defguard_common::db::{ Id, NoId, - models::{ - Device, DeviceType, WireguardNetwork, - group::Group, - user::User, - wireguard::{LocationMfaMode, ServiceLocationMode}, - }, + models::{Device, DeviceType, WireguardNetwork, group::Group, user::User}, }; use ipnetwork::IpNetwork; use model_derive::Model; @@ -965,28 +960,14 @@ impl AclRule { pub(crate) async fn get_networks<'e, E>( &self, executor: E, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { if self.all_locations { WireguardNetwork::all(executor).await } else { - query_as!( - WireguardNetwork, - "SELECT n.id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ - allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ - peer_disconnect_threshold, acl_enabled, acl_default_allow, \ - location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ - service_location_mode \"service_location_mode: ServiceLocationMode\" \ - FROM aclrulenetwork r \ - JOIN wireguard_network n \ - ON n.id = r.network_id \ - WHERE r.rule_id = $1", - self.id, - ) - .fetch_all(executor) - .await + WireguardNetwork::all_for_rule(executor, self.id).await } } diff --git a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs index 68aaef108c..82843f3f90 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs @@ -1,7 +1,10 @@ use std::ops::Bound; use defguard_common::{ - db::{models::wireguard::DEFAULT_WIREGUARD_MTU, setup_pool}, + db::{ + models::wireguard::{DEFAULT_WIREGUARD_MTU, LocationMfaMode, ServiceLocationMode}, + setup_pool, + }, utils::parse_address_list, }; use rand::{Rng, thread_rng}; diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index c56e012ccf..7a8b61cb3b 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -55,8 +55,8 @@ pub async fn generate_firewall_rules_from_acls( let location = WireguardNetwork::find_by_id(&mut *conn, location_id) .await? .ok_or(ModelError::NotFound)?; - let has_ipv4_addresses = location.address.iter().any(IpNetwork::is_ipv4); - let has_ipv6_addresses = location.address.iter().any(IpNetwork::is_ipv6); + let has_ipv4_addresses = location.address().iter().any(IpNetwork::is_ipv4); + let has_ipv6_addresses = location.address().iter().any(IpNetwork::is_ipv6); // convert each ACL into a corresponding `FirewallRule`s for acl in acl_rules { diff --git a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs index fd6c780191..9ecbe07180 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs @@ -20,19 +20,15 @@ async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectO set_test_license_business(); // Create test location - let location_1 = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location_1 = WireguardNetwork::default(); + location_1.acl_enabled = true; + location_1.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location_1 = location_1.save(&pool).await.unwrap(); // Create another test location - let location_2 = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location_2 = WireguardNetwork::default(); + location_2.acl_enabled = true; + location_2.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location_2 = location_2.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -121,19 +117,15 @@ async fn test_acl_rules_all_locations_ipv6(_: PgPoolOptions, options: PgConnectO let mut rng = thread_rng(); // Create test location - let location_1 = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location_1 = WireguardNetwork::default(); + location_1.acl_enabled = true; + location_1.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location_1 = location_1.save(&pool).await.unwrap(); // Create another test location - let location_2 = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location_2 = WireguardNetwork::default(); + location_2.acl_enabled = true; + location_2.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location_2 = location_2.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -222,25 +214,21 @@ async fn test_acl_rules_all_locations_ipv4_and_ipv6(_: PgPoolOptions, options: P let mut rng = thread_rng(); // Create test location - let location_1 = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - }; + let mut location_1 = WireguardNetwork::default(); + location_1.acl_enabled = true; + location_1.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); let location_1 = location_1.save(&pool).await.unwrap(); // Create another test location - let location_2 = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - }; + let mut location_2 = WireguardNetwork::default(); + location_2.acl_enabled = true; + location_2.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); let location_2 = location_2.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index 538510e999..1ed4cfe919 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -150,14 +150,10 @@ async fn test_any_address_overwrites_manual_destination( let mut rng = thread_rng(); - let location = WireguardNetwork { - acl_enabled: true, - address: vec!["10.0.0.0/16".parse().unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address(["10.0.0.0/16".parse().unwrap()]); + let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -234,14 +230,10 @@ async fn test_any_address_overwrites_destination_alias_addrs( let mut rng = thread_rng(); - let location = WireguardNetwork { - acl_enabled: true, - address: vec!["10.0.0.0/16".parse().unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address(["10.0.0.0/16".parse().unwrap()]); + let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -336,14 +328,10 @@ async fn test_manual_destination_includes_component_alias_address_range( let mut rng = thread_rng(); - let location = WireguardNetwork { - acl_enabled: true, - address: vec!["10.0.0.0/16".parse().unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address(["10.0.0.0/16".parse().unwrap()]); + let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -441,14 +429,10 @@ async fn test_manual_destination_merges_rule_and_component_alias_address_ranges( let mut rng = thread_rng(); - let location = WireguardNetwork { - acl_enabled: true, - address: vec!["10.0.0.0/16".parse().unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address(["10.0.0.0/16".parse().unwrap()]); + let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; diff --git a/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs b/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs index 1643e79dd6..74595dd06e 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs @@ -20,11 +20,9 @@ async fn test_disabled_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOption let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -94,11 +92,9 @@ async fn test_disabled_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOption let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -168,14 +164,12 @@ async fn test_disabled_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConn let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs b/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs index 2137b529cd..53f9fac5ae 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs @@ -15,10 +15,8 @@ async fn test_expired_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptions set_test_license_business(); let pool = setup_pool(options).await; // Create test location - let location = WireguardNetwork { - acl_enabled: true, - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; let location = location.save(&pool).await.unwrap(); // create expired ACL rules @@ -79,11 +77,9 @@ async fn test_expired_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptions set_test_license_business(); let pool = setup_pool(options).await; // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // create expired ACL rules @@ -144,14 +140,12 @@ async fn test_expired_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConne set_test_license_business(); let pool = setup_pool(options).await; // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); let location = location.save(&pool).await.unwrap(); // create expired ACL rules diff --git a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs index 24ee604d2c..62d5704029 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs @@ -40,7 +40,7 @@ async fn setup_user_and_device( let device = device.save(pool).await.unwrap(); let wireguard_ips = location - .address + .address() .iter() .map(|subnet| match subnet { IpNetwork::V4(ipv4_network) => { @@ -83,19 +83,17 @@ async fn test_gh1868_ipv6_rule_is_not_created_with_v4_only_destination( let pool = setup_pool(options).await; let mut rng = thread_rng(); - // Create test location with both IPv4 and IPv6 subnet - let location = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), - IpNetwork::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), - 64, - ) - .unwrap(), - ], - ..Default::default() - }; + // Create test location with both IPv4 and IPv6 subnet. + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + 64, + ) + .unwrap(), + ]); let location = location.save(&pool).await.unwrap(); // setup user & device @@ -146,19 +144,17 @@ async fn test_gh1868_ipv4_rule_is_not_created_with_v6_only_destination( let mut rng = thread_rng(); - // Create test location with both IPv4 and IPv6 subnet - let location = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), - IpNetwork::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), - 64, - ) - .unwrap(), - ], - ..Default::default() - }; + // Create test location with both IPv4 and IPv6 subnet. + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + 64, + ) + .unwrap(), + ]); let location = location.save(&pool).await.unwrap(); // setup user & device @@ -208,18 +204,16 @@ async fn test_gh1868_ipv4_and_ipv6_rules_are_created_with_any_destination( let mut rng = thread_rng(); // Create test location with both IPv4 and IPv6 subnet - let location = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), - IpNetwork::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), - 64, - ) - .unwrap(), - ], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + 64, + ) + .unwrap(), + ]); let location = location.save(&pool).await.unwrap(); // setup user & device diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index c8117ac25a..dbe2d005db 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -105,7 +105,7 @@ async fn create_test_users_and_devices( // Add device to locations' VPN network for location in &test_locations { let wireguard_ips = location - .address + .address() .iter() .map(|subnet| match subnet { IpNetwork::V4(ipv4_network) => { @@ -249,10 +249,8 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: false, - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = false; let mut location = location.save(&pool).await.unwrap(); // Setup test users and their devices @@ -672,11 +670,9 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: false, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = false; + location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let mut location = location.save(&pool).await.unwrap(); // Setup test users and their devices @@ -1125,14 +1121,12 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: false, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = false; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); let mut location = location.save(&pool).await.unwrap(); // Setup test users and their devices @@ -1764,14 +1758,10 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec!["10.0.0.0/16".parse().unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address(["10.0.0.0/16".parse().unwrap()]); + let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -1918,14 +1908,10 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec!["10.0.0.0/16".parse().unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address(["10.0.0.0/16".parse().unwrap()]); + let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -2084,11 +2070,9 @@ async fn test_no_allowed_users_ipv4(_: PgPoolOptions, options: PgConnectOptions) let pool = setup_pool(options).await; // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // create ACL rules @@ -2145,33 +2129,21 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon let mut rng = thread_rng(); // Create test locations with IPv4 and IPv6 addresses - let location_ipv4 = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); - let location_ipv6 = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); - let location_ipv4_and_ipv6 = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut location_ipv4 = WireguardNetwork::default(); + location_ipv4.acl_enabled = true; + location_ipv4.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); + let location_ipv4 = location_ipv4.save(&pool).await.unwrap(); + let mut location_ipv6 = WireguardNetwork::default(); + location_ipv6.acl_enabled = true; + location_ipv6.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); + let location_ipv6 = location_ipv6.save(&pool).await.unwrap(); + let mut location_ipv4_and_ipv6 = WireguardNetwork::default(); + location_ipv4_and_ipv6.acl_enabled = true; + location_ipv4_and_ipv6.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); + let location_ipv4_and_ipv6 = location_ipv4_and_ipv6.save(&pool).await.unwrap(); // Setup some test users and their devices let user_1: User = rng.r#gen(); diff --git a/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs b/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs index 39c39b7d5f..78a6e13b31 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs @@ -20,11 +20,9 @@ async fn test_unapplied_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptio let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -94,11 +92,9 @@ async fn test_unapplied_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptio let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -168,14 +164,12 @@ async fn test_unapplied_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgCon let mut rng = thread_rng(); // Create test location - let location = WireguardNetwork { - acl_enabled: true, - address: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ], - ..Default::default() - }; + let mut location = WireguardNetwork::default(); + location.acl_enabled = true; + location.set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), + IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), + ]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index d32bd05829..0e5fc23dca 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -59,8 +59,7 @@ impl NetworkDeviceInfo { device: Device, transaction: &mut PgConnection, ) -> Result { - let network = device - .find_network_device_networks(&mut *transaction) + let network = WireguardNetwork::find_network_device_networks(&mut *transaction, device.id) .await? .pop() .ok_or(WebError::ObjectNotFound(format!( @@ -121,8 +120,7 @@ pub async fn download_network_device_config( .ok_or(WebError::ObjectNotFound(format!( "Network device with ID {device_id} not found" )))?; - let network = device - .find_network_device_networks(&appstate.pool) + let network = WireguardNetwork::find_network_device_networks(&appstate.pool, device_id) .await? .pop() .ok_or(WebError::ObjectNotFound(format!( @@ -328,7 +326,7 @@ pub(crate) async fn find_available_ips( let mut transaction = appstate.pool.begin().await?; let mut split_ips = Vec::new(); - for network_address in &network.address { + for network_address in network.address() { let net_ip = network_address.ip(); let net_network = network_address.network(); let net_broadcast = network_address.broadcast(); @@ -349,10 +347,11 @@ pub(crate) async fn find_available_ips( } transaction.commit().await?; - if split_ips.len() != network.address.len() { + if split_ips.len() != network.address().len() { warn!( "Failed to find available IPs for new device in network {} ({:?})", - network.name, network.address + network.name, + network.address() ); return Err(WebError::NetworkFull(format!( "Network {} is full, no IP addresses available", @@ -361,7 +360,9 @@ pub(crate) async fn find_available_ips( } debug!( "Found addresses {:?} for new device in network {} ({:?})", - split_ips, network.name, network.address + split_ips, + network.name, + network.address() ); Ok(ApiResponse::json(split_ips, StatusCode::OK)) } @@ -693,14 +694,14 @@ pub async fn modify_network_device( })?; // store device before modifications let before = device.clone(); - let device_network = device - .find_network_device_networks(&mut *transaction) - .await? - .pop() - .ok_or_else(|| { - error!("Failed to update device {device_id}, device not found in any network"); - WebError::ObjectNotFound(format!("Device {device_id} not found in any network")) - })?; + let device_network = + WireguardNetwork::find_network_device_networks(&mut *transaction, device_id) + .await? + .pop() + .ok_or_else(|| { + error!("Failed to update device {device_id}, device not found in any network"); + WebError::ObjectNotFound(format!("Device {device_id} not found in any network")) + })?; let mut wireguard_network_device = WireguardNetworkDevice::find(&mut *transaction, device.id, device_network.id) .await? diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index d3e02bdf27..70bd9a2ba6 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -348,8 +348,7 @@ pub(crate) async fn modify_network( // store network before mods let before = network.clone(); let new_addresses = data.parse_addresses()?; - - network.address = new_addresses; + network.set_address(new_addresses); network.allowed_ips = data.parse_allowed_ips(); network.name = data.name; diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index ba963d2067..3392481ac7 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -761,7 +761,7 @@ pub async fn init_dev_env(config: &DefGuardConfig) { info!("Creating test network"); let mut network = WireguardNetwork::new( "TestNet".to_string(), - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], + [IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], 50051, "0.0.0.0".to_string(), None, @@ -852,7 +852,7 @@ pub async fn init_vpn_location( WireguardNetwork::find_by_id(&mut *transaction, location_id).await? { network.name.clone_from(&args.name); - network.address = vec![args.address]; + network.set_address([args.address]); network.port = args.port; network.endpoint.clone_from(&args.endpoint); network.dns.clone_from(&args.dns); diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index c5848d43b8..22dce2c43b 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -127,12 +127,10 @@ mod test { .unwrap(); // Normal location (service_location_mode = Disabled) should return peers - let mut network_normal = WireguardNetwork { - name: "normal-location".to_string(), - service_location_mode: ServiceLocationMode::Disabled, - location_mfa_mode: LocationMfaMode::Disabled, - ..Default::default() - }; + let mut network_normal = WireguardNetwork::default(); + network_normal.name = "normal-location".to_string(); + network_normal.service_location_mode = ServiceLocationMode::Disabled; + network_normal.location_mfa_mode = LocationMfaMode::Disabled; network_normal.try_set_address("10.1.1.1/24").unwrap(); let network_normal = network_normal.save(&pool).await.unwrap(); @@ -152,12 +150,10 @@ mod test { assert_eq!(peers_normal[0].pubkey, "pubkey1"); // Service location with PreLogon mode returns peers when enterprise is enabled (test env default) - let mut network_prelogon = WireguardNetwork { - name: "prelogon-service-location".to_string(), - service_location_mode: ServiceLocationMode::PreLogon, - location_mfa_mode: LocationMfaMode::Disabled, - ..Default::default() - }; + let mut network_prelogon = WireguardNetwork::default(); + network_prelogon.name = "prelogon-service-location".to_string(); + network_prelogon.service_location_mode = ServiceLocationMode::PreLogon; + network_prelogon.location_mfa_mode = LocationMfaMode::Disabled; network_prelogon.try_set_address("10.2.1.1/24").unwrap(); let network_prelogon = network_prelogon.save(&pool).await.unwrap(); @@ -182,12 +178,10 @@ mod test { assert_eq!(peers_prelogon[0].pubkey, "pubkey2"); // Service location with AlwaysOn mode also returns peers when enterprise is enabled - let mut network_alwayson = WireguardNetwork { - name: "alwayson-service-location".to_string(), - service_location_mode: ServiceLocationMode::AlwaysOn, - location_mfa_mode: LocationMfaMode::Disabled, - ..Default::default() - }; + let mut network_alwayson = WireguardNetwork::default(); + network_alwayson.name = "alwayson-service-location".to_string(); + network_alwayson.service_location_mode = ServiceLocationMode::AlwaysOn; + network_alwayson.location_mfa_mode = LocationMfaMode::Disabled; network_alwayson.try_set_address("10.3.1.1/24").unwrap(); let network_alwayson = network_alwayson.save(&pool).await.unwrap(); diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index 18e5fc1d9b..eef808319e 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -174,7 +174,7 @@ pub async fn process_device_access_changes( if let Some(device) = allowed_devices.remove(&device_network_config.device_id) { // Network address has changed and IP addresses need to be updated if !location.contains_all(&device_network_config.wireguard_ips) - || location.address.len() != device_network_config.wireguard_ips.len() + || location.address().len() != device_network_config.wireguard_ips.len() { let wireguard_network_device = device .assign_next_network_ip( diff --git a/crates/defguard_core/src/location_management/tests.rs b/crates/defguard_core/src/location_management/tests.rs index a61dc0660d..c640900b7b 100644 --- a/crates/defguard_core/src/location_management/tests.rs +++ b/crates/defguard_core/src/location_management/tests.rs @@ -1,7 +1,6 @@ use std::net::{IpAddr, Ipv4Addr}; use defguard_common::db::{ - NoId, models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, setup_pool, }; @@ -23,14 +22,10 @@ fn test_network_readdress(_: PgPoolOptions, options: PgConnectOptions) { // 192.168.42.45: device // 192.168.42.46: gateway // 192.168.42.47: broadcast - let mut network = WireguardNetwork:: { - address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 46)), 30).unwrap()], - allow_all_groups: true, - ..Default::default() - } - .save(&pool) - .await - .unwrap(); + let mut network = WireguardNetwork::default(); + network.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 46)), 30).unwrap()]); + network.allow_all_groups = true; + let mut network = network.save(&pool).await.unwrap(); let mut conn = pool.begin().await.unwrap(); @@ -73,8 +68,7 @@ fn test_network_readdress(_: PgPoolOptions, options: PgConnectOptions) { // 192.168.42.77: gateway // 192.168.42.78: device // 192.168.42.79: broadcast - network.address = - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 77)), 30).unwrap()]; + network.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 77)), 30).unwrap()]); network.save(&pool).await.unwrap(); // Re-address the network. diff --git a/crates/defguard_core/src/wg_config.rs b/crates/defguard_core/src/wg_config.rs index 82b055d0bc..8c48afd5f1 100644 --- a/crates/defguard_core/src/wg_config.rs +++ b/crates/defguard_core/src/wg_config.rs @@ -214,7 +214,7 @@ mod test { ); assert_eq!(network.id, NoId); assert_eq!(network.name, "Y5ewP5RXstQd71gkmS/M0xL8wi0yVbbVY/ocLM4cQ1Y="); - assert_eq!(network.address, vec!["10.0.0.1/24".parse().unwrap()]); + assert_eq!(network.address(), ["10.0.0.1/24".parse().unwrap()]); assert_eq!(network.port, 55055); assert_eq!( network.pubkey, @@ -281,8 +281,8 @@ mod test { assert_eq!(network.id, NoId); assert_eq!(network.name, "Y5ewP5RXstQd71gkmS/M0xL8wi0yVbbVY/ocLM4cQ1Y="); assert_eq!( - network.address, - vec![ + network.address(), + [ "10.0.0.1/24".parse().unwrap(), "fc00::/112".parse().unwrap() ] diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs index d0cc837ff7..73de4e4c66 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs @@ -232,10 +232,10 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { assert_matches!(event, GatewayEvent::DeviceModified(..)); // Make sure the device is only in the selected network - let device_networks = device - .find_network_device_networks(&client_state.pool) - .await - .unwrap(); + let device_networks = + WireguardNetwork::find_network_device_networks(&client_state.pool, device_id) + .await + .unwrap(); assert_eq!(device_networks.len(), 1); assert_eq!(network_1.id, device_networks[0].id); @@ -277,10 +277,10 @@ async fn test_network_devices(_: PgPoolOptions, options: PgConnectOptions) { .unwrap(); assert!(!device.configured); assert_eq!(device.name, "device-2"); - let device_network = device - .find_network_device_networks(&client_state.pool) - .await - .unwrap(); + let device_network = + WireguardNetwork::find_network_device_networks(&client_state.pool, device_id) + .await + .unwrap(); assert_eq!(device_network.len(), 1); assert_eq!(device_network[0].id, network_1.id); diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index 904747b174..0925b3e23e 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -52,7 +52,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { // setup initial network let initial_network = WireguardNetwork::new( "initial".into(), - vec!["10.1.9.0/24".parse().unwrap()], + ["10.1.9.0/24".parse().unwrap()], 51515, String::new(), None, @@ -130,7 +130,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { let network = response.network; assert_eq!(network.id, 2); assert_eq!(network.name, "network"); - assert_eq!(network.address, vec!["10.0.0.1/24".parse().unwrap()]); + assert_eq!(network.address(), ["10.0.0.1/24".parse().unwrap()]); assert_eq!(network.port, 55055); assert_eq!( network.pubkey, diff --git a/crates/defguard_gateway_manager/src/handler.rs b/crates/defguard_gateway_manager/src/handler.rs index e10a012959..f086ee35a5 100644 --- a/crates/defguard_gateway_manager/src/handler.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -616,7 +616,7 @@ impl GatewayUpdatesHandler { update: Some(update::Update::Network(Configuration { name: network.name.clone(), prvkey: network.prvkey.clone(), - addresses: network.address.iter().map(ToString::to_string).collect(), + addresses: network.address().iter().map(ToString::to_string).collect(), port: network.port.cast_unsigned(), peers, firewall_config, @@ -803,7 +803,7 @@ fn gen_config( name: network.name.clone(), port: network.port.cast_unsigned(), prvkey: network.prvkey.clone(), - addresses: network.address.iter().map(ToString::to_string).collect(), + addresses: network.address().iter().map(ToString::to_string).collect(), peers, firewall_config: maybe_firewall_config, mtu: network.mtu.cast_unsigned(), diff --git a/crates/defguard_proxy_manager/src/servers/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs index a57a1bb17e..fcf44e7979 100644 --- a/crates/defguard_proxy_manager/src/servers/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -627,16 +627,16 @@ impl EnrollmentServer { Status::internal("unexpected error") })?; - let mut networks = device - .find_network_device_networks(&mut *transaction) - .await - .map_err(|err| { - error!( - "Failed to find networks for device {} for user {}({:?}): {err}", - device.name, user.username, user.id - ); - Status::internal("unexpected error") - })?; + let mut networks = + WireguardNetwork::find_network_device_networks(&mut *transaction, device_id) + .await + .map_err(|err| { + error!( + "Failed to find networks for device {} for user {}({:?}): {err}", + device.name, user.username, user.id + ); + Status::internal("unexpected error") + })?; let Some(network) = networks.pop() else { error!( diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index ddd8181288..84325c8d57 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -118,7 +118,7 @@ pub(crate) async fn create_location_with_mfa_mode( ) -> WireguardNetwork { WireguardNetwork::new( "TestNet".to_string(), - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], + [IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], 51820, "10.0.0.1".to_string(), None, diff --git a/crates/defguard_setup/src/auto_adoption.rs b/crates/defguard_setup/src/auto_adoption.rs index 8e2e6c1c29..6434226bc3 100644 --- a/crates/defguard_setup/src/auto_adoption.rs +++ b/crates/defguard_setup/src/auto_adoption.rs @@ -738,7 +738,7 @@ id={} for new gateway", let mut transaction = pool.begin().await.context("Failed to begin transaction")?; let network = WireguardNetwork::new( common_name.to_string(), - vec![network_address], + [network_address], DEFAULT_AUTO_ADOPTION_WIREGUARD_PORT, host.to_string(), None, diff --git a/crates/defguard_setup/src/handlers/auto_wizard.rs b/crates/defguard_setup/src/handlers/auto_wizard.rs index 73e47d1bb4..0c542a27ef 100644 --- a/crates/defguard_setup/src/handlers/auto_wizard.rs +++ b/crates/defguard_setup/src/handlers/auto_wizard.rs @@ -137,7 +137,7 @@ pub async fn set_vpn_settings( network.endpoint = vpn_settings.public_ip; network.port = vpn_settings.wireguard_port; - network.address = addresses; + network.set_address(addresses); network.allowed_ips = allowed_ips; network.dns = { let dns = vpn_settings.dns_server_ip.trim(); diff --git a/crates/defguard_setup/tests/auto_adoption_wizard.rs b/crates/defguard_setup/tests/auto_adoption_wizard.rs index c60e98a361..1c4dbbada7 100644 --- a/crates/defguard_setup/tests/auto_adoption_wizard.rs +++ b/crates/defguard_setup/tests/auto_adoption_wizard.rs @@ -41,7 +41,7 @@ async fn assert_auto_adoption_step(pool: &sqlx::PgPool, expected: AutoAdoptionWi async fn seed_wireguard_network(pool: &sqlx::PgPool) -> WireguardNetwork { WireguardNetwork::new( "auto-net".to_string(), - vec!["10.0.0.0/24".parse::().unwrap()], + ["10.0.0.0/24".parse::().unwrap()], 51820, "1.2.3.4".to_string(), None, diff --git a/crates/defguard_setup/tests/wizard_state.rs b/crates/defguard_setup/tests/wizard_state.rs index 362a8d545a..bc572c7e70 100644 --- a/crates/defguard_setup/tests/wizard_state.rs +++ b/crates/defguard_setup/tests/wizard_state.rs @@ -140,7 +140,7 @@ async fn test_wizard_state_auto_adoption(_: PgPoolOptions, options: PgConnectOpt WireguardNetwork::new( "auto-net".to_string(), - vec!["10.0.0.0/24".parse().unwrap()], + ["10.0.0.0/24".parse().unwrap()], 51820, "1.2.3.4".to_string(), None, diff --git a/crates/defguard_static_ip/src/lib.rs b/crates/defguard_static_ip/src/lib.rs index 346ebcaf0a..c96c999a30 100644 --- a/crates/defguard_static_ip/src/lib.rs +++ b/crates/defguard_static_ip/src/lib.rs @@ -265,20 +265,16 @@ mod tests { .expect("Failed to create user"); // Create test locations - let mut location_a = WireguardNetwork { - name: "Location A".into(), - ..Default::default() - }; + let mut location_a = WireguardNetwork::default(); + location_a.name = "Location A".into(); location_a.try_set_address("10.0.1.1/24").unwrap(); let location_a = location_a .save(&pool) .await .expect("Failed to create Location A"); - let mut location_b = WireguardNetwork { - name: "Location B".into(), - ..Default::default() - }; + let mut location_b = WireguardNetwork::default(); + location_b.name = "Location B".into(); location_b.try_set_address("10.0.2.1/24").unwrap(); let location_b = location_b .save(&pool) @@ -384,8 +380,8 @@ mod tests { let locations = result.unwrap(); assert_eq!(locations.len(), 2); - let net_a = location_a.address[0]; - let net_b = location_b.address[0]; + let net_a = location_a.address()[0]; + let net_b = location_b.address()[0]; // Verify Location A assert_eq!(locations[0].location_name, "Location A"); @@ -438,10 +434,8 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork { - name: "Assign Network".into(), - ..Default::default() - }; + let mut network = WireguardNetwork::default(); + network.name = "Assign Network".into(); network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); @@ -514,10 +508,8 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork { - name: "NoDevice Network".into(), - ..Default::default() - }; + let mut network = WireguardNetwork::default(); + network.name = "NoDevice Network".into(); network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); @@ -564,10 +556,8 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork { - name: "Range Network".into(), - ..Default::default() - }; + let mut network = WireguardNetwork::default(); + network.name = "Range Network".into(); network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); @@ -621,10 +611,8 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork { - name: "Conflict Network".into(), - ..Default::default() - }; + let mut network = WireguardNetwork::default(); + network.name = "Conflict; Network".into(); network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); From ddfef85411beb2daddbceb4de6bd22f42e726128 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 17 Mar 2026 11:34:25 +0100 Subject: [PATCH 02/10] Reduce WireguardNetwork::new --- .../src/db/models/wireguard.rs | 21 +++-------- .../src/enterprise/db/models/acl/tests.rs | 10 +----- .../src/enterprise/directory_sync/tests.rs | 6 +--- .../defguard_core/src/handlers/wireguard.rs | 10 +++--- crates/defguard_core/src/lib.rs | 35 ++++++------------- crates/defguard_core/src/wg_config.rs | 11 ++---- .../tests/integration/api/acl/mod.rs | 2 +- .../tests/integration/api/acl/rules.rs | 8 ----- .../api/wireguard_network_import.rs | 9 +---- .../tests/common/mod.rs | 4 --- crates/defguard_setup/src/auto_adoption.rs | 9 +---- .../tests/auto_adoption_wizard.rs | 21 ++++++----- crates/defguard_setup/tests/wizard_state.rs | 16 ++++----- 13 files changed, 46 insertions(+), 116 deletions(-) diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 28d3ea6112..cd761586c8 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -39,6 +39,7 @@ pub const DEFAULT_KEEPALIVE_INTERVAL: i32 = 25; pub const DEFAULT_DISCONNECT_THRESHOLD: i32 = 300; /// Default MTU for WireGuard interfaces. pub const DEFAULT_WIREGUARD_MTU: i32 = 1420; // TODO: use u32 once sqlx supports unsigned integers. +const DEFAULT_FWMARK: i64 = 0; // Zero means: don't use firewall mark. // Used in process of importing network from WireGuard config. #[derive(Clone, Debug, Deserialize, Serialize)] @@ -218,12 +219,8 @@ impl WireguardNetwork { port: i32, endpoint: String, dns: Option, - mtu: i32, - fwmark: i64, allowed_ips: Vec, allow_all_groups: bool, - keepalive_interval: i32, - peer_disconnect_threshold: i32, acl_enabled: bool, acl_default_allow: bool, location_mfa_mode: LocationMfaMode, @@ -243,13 +240,13 @@ impl WireguardNetwork { prvkey: BASE64_STANDARD.encode(prvkey.to_bytes()), endpoint, dns, - mtu, - fwmark, + mtu: DEFAULT_WIREGUARD_MTU, + fwmark: DEFAULT_FWMARK, allowed_ips, allow_all_groups, connected_at: None, - keepalive_interval, - peer_disconnect_threshold, + keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL, + peer_disconnect_threshold: DEFAULT_DISCONNECT_THRESHOLD, acl_enabled, acl_default_allow, location_mfa_mode, @@ -1937,12 +1934,8 @@ mod test { 50051, String::new(), None, - DEFAULT_WIREGUARD_MTU, - 0, vec![IpNetwork::from_str("10.1.1.0/24").unwrap()], false, - 300, - 300, false, false, LocationMfaMode::Disabled, @@ -2072,12 +2065,8 @@ mod test { 50051, String::new(), None, - DEFAULT_WIREGUARD_MTU, - 0, vec![IpNetwork::from_str("10.1.1.0/24").unwrap()], false, - 300, - 300, false, false, LocationMfaMode::Disabled, diff --git a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs index 82843f3f90..db5a82ad05 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs @@ -2,7 +2,7 @@ use std::ops::Bound; use defguard_common::{ db::{ - models::wireguard::{DEFAULT_WIREGUARD_MTU, LocationMfaMode, ServiceLocationMode}, + models::wireguard::{LocationMfaMode, ServiceLocationMode}, setup_pool, }, utils::parse_address_list, @@ -157,12 +157,8 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) { 1000, "endpoint1".to_string(), None, - DEFAULT_WIREGUARD_MTU, - 0, Vec::new(), true, - 100, - 100, false, false, LocationMfaMode::Disabled, @@ -177,12 +173,8 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) { 2000, "endpoint2".to_string(), None, - DEFAULT_WIREGUARD_MTU, - 0, Vec::new(), true, - 200, - 200, false, false, LocationMfaMode::Disabled, diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 6b33822550..ccdd5c0881 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -54,16 +54,12 @@ mod test { WireguardNetwork::new( "test".to_string(), - vec![IpNetwork::from_str("10.10.10.1/24").unwrap()], + [IpNetwork::from_str("10.10.10.1/24").unwrap()], 1234, "123.123.123.123".to_string(), None, - 1420, - 0, Vec::new(), true, - 32, - 32, false, false, LocationMfaMode::Disabled, diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 70bd9a2ba6..9ef305b37a 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -236,23 +236,23 @@ pub(crate) async fn create_network( data.validate_location_mfa_mode(&appstate.pool).await?; let allowed_ips = data.parse_allowed_ips(); - let network = WireguardNetwork::new( + let mut network = WireguardNetwork::new( data.name, parse_address_list(&data.address), data.port, data.endpoint, data.dns, - data.mtu, - data.fwmark, allowed_ips, data.allow_all_groups, - data.keepalive_interval, - data.peer_disconnect_threshold, data.acl_enabled, data.acl_default_allow, data.location_mfa_mode, data.service_location_mode, ); + network.mtu = data.mtu; + network.fwmark = data.fwmark; + network.keepalive_interval = data.keepalive_interval; + network.peer_disconnect_threshold = data.peer_disconnect_threshold; let mut transaction = appstate.pool.begin().await?; let network = network.save(&mut *transaction).await?; diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 3392481ac7..a37f2091d6 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -23,10 +23,7 @@ use defguard_common::{ Device, DeviceType, Settings, User, WireguardNetwork, oauth2client::OAuth2Client, settings::{initialize_current_settings, update_current_settings}, - wireguard::{ - DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_WIREGUARD_MTU, - LocationMfaMode, ServiceLocationMode, - }, + wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, types::proxy::ProxyControlMessage, @@ -765,12 +762,8 @@ pub async fn init_dev_env(config: &DefGuardConfig) { 50051, "0.0.0.0".to_string(), None, - DEFAULT_WIREGUARD_MTU, - 0, vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], true, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, false, false, LocationMfaMode::Disabled, @@ -863,25 +856,22 @@ pub async fn init_vpn_location( } // Otherwise create it with the predefined ID else { - let network = WireguardNetwork::new( + let mut network = WireguardNetwork::new( args.name.clone(), vec![args.address], args.port, args.endpoint.clone(), args.dns.clone(), - args.mtu as i32, - i64::from(args.fwmark), args.allowed_ips.clone(), true, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&mut *transaction) - .await?; + ); + network.mtu = args.mtu as i32; + network.fwmark = i64::from(args.fwmark); + let network = network.save(&mut *transaction).await?; if network.id != location_id { return Err(anyhow!( "Failed to initialize VPN location. The ID of the newly created network ({}) does not match \ @@ -906,25 +896,22 @@ pub async fn init_vpn_location( } // create a new network - WireguardNetwork::new( + let mut location = WireguardNetwork::new( args.name.clone(), vec![args.address], args.port, args.endpoint.clone(), args.dns.clone(), - args.mtu as i32, - i64::from(args.fwmark), args.allowed_ips.clone(), true, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(pool) - .await? + ); + location.mtu = args.mtu as i32; + location.fwmark = i64::from(args.fwmark); + location.save(pool).await? }; // generate gateway token diff --git a/crates/defguard_core/src/wg_config.rs b/crates/defguard_core/src/wg_config.rs index 8c48afd5f1..e80f22fa08 100644 --- a/crates/defguard_core/src/wg_config.rs +++ b/crates/defguard_core/src/wg_config.rs @@ -5,10 +5,7 @@ use defguard_common::{ KEY_LENGTH, db::models::{ Device, WireguardNetwork, - wireguard::{ - DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_WIREGUARD_MTU, - LocationMfaMode, ServiceLocationMode, - }, + wireguard::{DEFAULT_WIREGUARD_MTU, LocationMfaMode, ServiceLocationMode}, }, }; use ipnetwork::{IpNetwork, IpNetworkError}; @@ -114,17 +111,15 @@ pub(crate) fn parse_wireguard_config( port, String::new(), dns, - mtu, - fwmark, allowed_ips, true, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + network.mtu = mtu; + network.fwmark = fwmark; network.pubkey = pubkey; network.prvkey = prvkey.to_string(); diff --git a/crates/defguard_core/tests/integration/api/acl/mod.rs b/crates/defguard_core/tests/integration/api/acl/mod.rs index 952b762302..6f43cc58c5 100644 --- a/crates/defguard_core/tests/integration/api/acl/mod.rs +++ b/crates/defguard_core/tests/integration/api/acl/mod.rs @@ -6,7 +6,7 @@ use defguard_common::{ Device, DeviceType, User, WireguardNetwork, group::Group, settings::initialize_current_settings, - wireguard::{DEFAULT_WIREGUARD_MTU, LocationMfaMode, ServiceLocationMode}, + wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, }; diff --git a/crates/defguard_core/tests/integration/api/acl/rules.rs b/crates/defguard_core/tests/integration/api/acl/rules.rs index 6baa6bbd77..9e5eb5baac 100644 --- a/crates/defguard_core/tests/integration/api/acl/rules.rs +++ b/crates/defguard_core/tests/integration/api/acl/rules.rs @@ -341,12 +341,8 @@ async fn test_related_objects(_: PgPoolOptions, options: PgConnectOptions) { 1000, "endpoint1".to_string(), None, - DEFAULT_WIREGUARD_MTU, - 0, Vec::new(), true, - 100, - 100, false, false, LocationMfaMode::Disabled, @@ -862,12 +858,8 @@ async fn test_rule_delete_state_applied(_: PgPoolOptions, options: PgConnectOpti 1000, "endpoint1".to_string(), None, - DEFAULT_WIREGUARD_MTU, - 0, Vec::new(), true, - 100, - 100, false, false, LocationMfaMode::Disabled, diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index 0925b3e23e..d01abed0f2 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -3,10 +3,7 @@ use std::net::IpAddr; use defguard_common::db::models::{ Device, DeviceType, WireguardNetwork, device::UserDevice, - wireguard::{ - DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_WIREGUARD_MTU, - LocationMfaMode, ServiceLocationMode, - }, + wireguard::{LocationMfaMode, ServiceLocationMode}, }; use defguard_core::{ grpc::GatewayEvent, @@ -56,12 +53,8 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { 51515, String::new(), None, - DEFAULT_WIREGUARD_MTU, - 0, Vec::new(), false, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, false, false, LocationMfaMode::Disabled, diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 84325c8d57..3ec17c2300 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -122,12 +122,8 @@ pub(crate) async fn create_location_with_mfa_mode( 51820, "10.0.0.1".to_string(), None, - 1420, - 0, vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], true, - 25, - 300, false, false, location_mfa_mode, diff --git a/crates/defguard_setup/src/auto_adoption.rs b/crates/defguard_setup/src/auto_adoption.rs index 6434226bc3..f766e59c86 100644 --- a/crates/defguard_setup/src/auto_adoption.rs +++ b/crates/defguard_setup/src/auto_adoption.rs @@ -14,10 +14,7 @@ use defguard_common::{ setup_auto_adoption::{ AutoAdoptionComponentResult, AutoAdoptionWizardState, SetupAutoAdoptionComponent, }, - wireguard::{ - DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_WIREGUARD_MTU, - LocationMfaMode, ServiceLocationMode, - }, + wireguard::{LocationMfaMode, ServiceLocationMode}, }, }; use defguard_core::version::{MIN_GATEWAY_VERSION, MIN_PROXY_VERSION}; @@ -742,12 +739,8 @@ id={} for new gateway", DEFAULT_AUTO_ADOPTION_WIREGUARD_PORT, host.to_string(), None, - DEFAULT_WIREGUARD_MTU, - 0, Vec::new(), true, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, false, false, LocationMfaMode::Disabled, diff --git a/crates/defguard_setup/tests/auto_adoption_wizard.rs b/crates/defguard_setup/tests/auto_adoption_wizard.rs index 1c4dbbada7..cf72b02ecb 100644 --- a/crates/defguard_setup/tests/auto_adoption_wizard.rs +++ b/crates/defguard_setup/tests/auto_adoption_wizard.rs @@ -1,6 +1,7 @@ use defguard_common::{ config::DefGuardConfig, db::{ + Id, models::{ Settings, WireguardNetwork, settings::initialize_current_settings, @@ -38,27 +39,25 @@ async fn assert_auto_adoption_step(pool: &sqlx::PgPool, expected: AutoAdoptionWi } /// Seed a minimal WireguardNetwork row required by the auto-adoption VPN/MFA steps. -async fn seed_wireguard_network(pool: &sqlx::PgPool) -> WireguardNetwork { - WireguardNetwork::new( +async fn seed_wireguard_network(pool: &sqlx::PgPool) -> WireguardNetwork { + let mut location = WireguardNetwork::new( "auto-net".to_string(), ["10.0.0.0/24".parse::().unwrap()], 51820, "1.2.3.4".to_string(), None, - 1280, - 0, - vec!["0.0.0.0/0".parse::().unwrap()], + vec!["0.0.0.0/0".parse().unwrap()], false, - 180, - 25, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(pool) - .await - .expect("Failed to save wireguard network") + ); + location.mtu = 1280; + location + .save(pool) + .await + .expect("Failed to save wireguard network") } #[sqlx::test] diff --git a/crates/defguard_setup/tests/wizard_state.rs b/crates/defguard_setup/tests/wizard_state.rs index bc572c7e70..556201a300 100644 --- a/crates/defguard_setup/tests/wizard_state.rs +++ b/crates/defguard_setup/tests/wizard_state.rs @@ -138,26 +138,24 @@ async fn test_wizard_state_auto_adoption(_: PgPoolOptions, options: PgConnectOpt .await .expect("Failed to initialize settings"); - WireguardNetwork::new( + let mut location = WireguardNetwork::new( "auto-net".to_string(), ["10.0.0.0/24".parse().unwrap()], 51820, "1.2.3.4".to_string(), None, - 1280, - 0, vec!["0.0.0.0/0".parse().unwrap()], false, - 180, - 25, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .expect("Failed to seed wireguard network"); + ); + location.mtu = 1280; + location + .save(&pool) + .await + .expect("Failed to seed wireguard network"); Wizard::init(&pool, true) .await From 4677c70d23943f21015b6f49aa488e600f378853 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 17 Mar 2026 12:07:18 +0100 Subject: [PATCH 03/10] Rewrite WireguardNetwork::try_set_address --- .../defguard_common/src/db/models/device.rs | 44 +++++++----- .../src/db/models/wireguard.rs | 71 +++++++++---------- .../src/enterprise/db/models/acl/tests.rs | 2 - .../src/enterprise/directory_sync/tests.rs | 10 ++- .../defguard_core/src/handlers/wireguard.rs | 2 +- crates/defguard_core/src/lib.rs | 6 +- .../src/location_management/allowed_peers.rs | 15 ++-- .../src/location_management/mod.rs | 18 +++-- crates/defguard_core/src/wg_config.rs | 4 +- .../tests/integration/api/acl/rules.rs | 2 - .../api/wireguard_network_import.rs | 4 +- .../tests/common/mod.rs | 13 ++-- crates/defguard_setup/src/auto_adoption.rs | 13 ++-- .../tests/auto_adoption_wizard.rs | 4 +- crates/defguard_setup/tests/wizard_state.rs | 4 +- crates/defguard_static_ip/src/lib.rs | 30 ++++---- 16 files changed, 132 insertions(+), 110 deletions(-) diff --git a/crates/defguard_common/src/db/models/device.rs b/crates/defguard_common/src/db/models/device.rs index b92382a369..5679bf6983 100644 --- a/crates/defguard_common/src/db/models/device.rs +++ b/crates/defguard_common/src/db/models/device.rs @@ -1099,9 +1099,12 @@ mod test { async fn test_assign_device_ip(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.1.1.1/30").unwrap(); - let network = network.save(&pool).await.unwrap(); + let network = WireguardNetwork::default() + .try_set_address("10.1.1.1/30") + .unwrap() + .save(&pool) + .await + .unwrap(); let user = User::new( "testuser", @@ -1140,11 +1143,12 @@ mod test { ) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network + let network = WireguardNetwork::default() .try_set_address("10.0.0.1/8,123.10.0.1/16,123.123.123.1/24") + .unwrap() + .save(&pool) + .await .unwrap(); - let network = network.save(&pool).await.unwrap(); let user = User::new( "testuser", @@ -1237,9 +1241,12 @@ mod test { ) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.0.0.1/8").unwrap(); - let network = network.save(&pool).await.unwrap(); + let network = WireguardNetwork::default() + .try_set_address("10.0.0.1/8") + .unwrap() + .save(&pool) + .await + .unwrap(); let user = User::new( "testuser", @@ -1317,9 +1324,12 @@ mod test { ) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("123.123.123.1/24").unwrap(); - let network = network.save(&pool).await.unwrap(); + let network = WireguardNetwork::default() + .try_set_address("123.123.123.1/24") + .unwrap() + .save(&pool) + .await + .unwrap(); let user = User::new( "testuser", @@ -1420,14 +1430,16 @@ mod test { .await .unwrap(); - let mut network = WireguardNetwork::default(); + let mut network = WireguardNetwork::default() + .try_set_address("10.1.1.1/24") + .unwrap(); network.allow_all_groups = true; - network.try_set_address("10.1.1.1/24").unwrap(); let network = network.save(&pool).await.unwrap(); - let mut network_2 = WireguardNetwork::default(); + let mut network_2 = WireguardNetwork::default() + .try_set_address("10.1.2.1/24") + .unwrap(); network_2.name = "testnetwork2".into(); network_2.allow_all_groups = true; - network_2.try_set_address("10.1.2.1/24").unwrap(); let network2 = network_2.save(&pool).await.unwrap(); let device = Device::new( diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index cd761586c8..683143458f 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -32,7 +32,6 @@ use crate::{ }, }, types::user_info::UserInfo, - utils::parse_address_list, }; pub const DEFAULT_KEEPALIVE_INTERVAL: i32 = 25; @@ -215,11 +214,10 @@ impl WireguardNetwork { #[must_use] pub fn new( name: String, - address: V, port: i32, endpoint: String, dns: Option, - allowed_ips: Vec, + allowed_ips: V, allow_all_groups: bool, acl_enabled: bool, acl_default_allow: bool, @@ -234,7 +232,7 @@ impl WireguardNetwork { Self { id: NoId, name, - address: address.into(), + address: Vec::new(), port, pubkey: BASE64_STANDARD.encode(pubkey.to_bytes()), prvkey: BASE64_STANDARD.encode(prvkey.to_bytes()), @@ -242,7 +240,7 @@ impl WireguardNetwork { dns, mtu: DEFAULT_WIREGUARD_MTU, fwmark: DEFAULT_FWMARK, - allowed_ips, + allowed_ips: allowed_ips.into(), allow_all_groups, connected_at: None, keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL, @@ -254,14 +252,17 @@ impl WireguardNetwork { } } - /// Try to set `address` from `&str`. - pub fn try_set_address(&mut self, address: &str) -> Result<(), IpNetworkError> { - let address = parse_address_list(address); + /// Try to set `address` from comma-separated string of addresses. + /// If there is an error parsing the address list, `address` will be partially set. + pub fn try_set_address(mut self, address: &str) -> Result { + self.address = Vec::new(); + for addr in address.split(',') { + self.address.push(addr.trim().parse()?); + } if address.is_empty() { - Err(IpNetworkError::InvalidAddr("invalid address".into())) + Err(IpNetworkError::InvalidAddr("empty address".into())) } else { - self.address = address; - Ok(()) + Ok(self) } } } @@ -1750,11 +1751,10 @@ mod test { #[sqlx::test] async fn test_get_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork:: { - allow_all_groups: true, - ..Default::default() - }; - network.try_set_address("10.1.1.1/29").unwrap(); + let mut network = WireguardNetwork::default() + .try_set_address("10.1.1.1/29") + .unwrap(); + network.allow_all_groups = true; let network = network.save(&pool).await.unwrap(); let user1 = User::new( @@ -1845,9 +1845,12 @@ mod test { options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.1.1.1/29").unwrap(); - let network = network.save(&pool).await.unwrap(); + let network = WireguardNetwork::default() + .try_set_address("10.1.1.1/29") + .unwrap() + .save(&pool) + .await + .unwrap(); let user1 = User::new( "user1", @@ -1928,22 +1931,20 @@ mod test { async fn test_can_assign_ips(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let network = WireguardNetwork::new( + let mut network = WireguardNetwork::new( "network".to_string(), - [IpNetwork::from_str("10.1.1.1/24").unwrap()], 50051, String::new(), None, - vec![IpNetwork::from_str("10.1.1.0/24").unwrap()], + [IpNetwork::from_str("10.1.1.0/24").unwrap()], false, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); + ); + network.set_address([IpNetwork::from_str("10.1.1.1/24").unwrap()]); + let network = network.save(&pool).await.unwrap(); // assign free address let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; @@ -2056,25 +2057,23 @@ mod test { async fn test_can_assign_ips_multiple_addresses(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let network = WireguardNetwork::new( + let mut network = WireguardNetwork::new( "network".to_string(), - [ - IpNetwork::from_str("10.1.1.1/24").unwrap(), - IpNetwork::from_str("fc00::1/112").unwrap(), - ], 50051, String::new(), None, - vec![IpNetwork::from_str("10.1.1.0/24").unwrap()], + [IpNetwork::from_str("10.1.1.0/24").unwrap()], false, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&pool) - .await - .unwrap(); + ); + network.set_address([ + IpNetwork::from_str("10.1.1.1/24").unwrap(), + IpNetwork::from_str("fc00::1/112").unwrap(), + ]); + let network = network.save(&pool).await.unwrap(); // assign free addresses let addrs = vec![ diff --git a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs index db5a82ad05..34776cc3a6 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs @@ -153,7 +153,6 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) { // create 2 networks let network1 = WireguardNetwork::new( "network1".to_string(), - Vec::new(), 1000, "endpoint1".to_string(), None, @@ -169,7 +168,6 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) { .unwrap(); let _network2 = WireguardNetwork::new( "network2".to_string(), - Vec::new(), 2000, "endpoint2".to_string(), None, diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index ccdd5c0881..3c5849f0f6 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -52,9 +52,8 @@ mod test { provider.delete(pool).await.unwrap(); } - WireguardNetwork::new( + let mut location = WireguardNetwork::new( "test".to_string(), - [IpNetwork::from_str("10.10.10.1/24").unwrap()], 1234, "123.123.123.123".to_string(), None, @@ -64,10 +63,9 @@ mod test { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(pool) - .await - .unwrap(); + ); + location.set_address([IpNetwork::from_str("10.10.10.1/24").unwrap()]); + location.save(pool).await.unwrap(); OpenIdProvider::new( "Test".to_string(), diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 9ef305b37a..bbbc3ec06d 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -238,7 +238,6 @@ pub(crate) async fn create_network( let allowed_ips = data.parse_allowed_ips(); let mut network = WireguardNetwork::new( data.name, - parse_address_list(&data.address), data.port, data.endpoint, data.dns, @@ -249,6 +248,7 @@ pub(crate) async fn create_network( data.location_mfa_mode, data.service_location_mode, ); + network.set_address(parse_address_list(&data.address)); network.mtu = data.mtu; network.fwmark = data.fwmark; network.keepalive_interval = data.keepalive_interval; diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index a37f2091d6..41cbf0a813 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -758,7 +758,6 @@ pub async fn init_dev_env(config: &DefGuardConfig) { info!("Creating test network"); let mut network = WireguardNetwork::new( "TestNet".to_string(), - [IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], 50051, "0.0.0.0".to_string(), None, @@ -769,6 +768,7 @@ pub async fn init_dev_env(config: &DefGuardConfig) { LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + network.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()]); network.pubkey = "zGMeVGm9HV9I4wSKF9AXmYnnAIhDySyqLMuKpcfIaQo=".to_string(); network.prvkey = "MAk3d5KuB167G88HM7nGYR6ksnPMAOguAg2s5EcPp1M=".to_string(); network @@ -858,7 +858,6 @@ pub async fn init_vpn_location( else { let mut network = WireguardNetwork::new( args.name.clone(), - vec![args.address], args.port, args.endpoint.clone(), args.dns.clone(), @@ -869,6 +868,7 @@ pub async fn init_vpn_location( LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + network.set_address([args.address]); network.mtu = args.mtu as i32; network.fwmark = i64::from(args.fwmark); let network = network.save(&mut *transaction).await?; @@ -898,7 +898,6 @@ pub async fn init_vpn_location( // create a new network let mut location = WireguardNetwork::new( args.name.clone(), - vec![args.address], args.port, args.endpoint.clone(), args.dns.clone(), @@ -909,6 +908,7 @@ pub async fn init_vpn_location( LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + location.set_address([args.address]); location.mtu = args.mtu as i32; location.fwmark = i64::from(args.fwmark); location.save(pool).await? diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 22dce2c43b..49827834c4 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -127,11 +127,12 @@ mod test { .unwrap(); // Normal location (service_location_mode = Disabled) should return peers - let mut network_normal = WireguardNetwork::default(); + let mut network_normal = WireguardNetwork::default() + .try_set_address("10.1.1.1/24") + .unwrap(); network_normal.name = "normal-location".to_string(); network_normal.service_location_mode = ServiceLocationMode::Disabled; network_normal.location_mfa_mode = LocationMfaMode::Disabled; - network_normal.try_set_address("10.1.1.1/24").unwrap(); let network_normal = network_normal.save(&pool).await.unwrap(); WireguardNetworkDevice::new( @@ -150,11 +151,12 @@ mod test { assert_eq!(peers_normal[0].pubkey, "pubkey1"); // Service location with PreLogon mode returns peers when enterprise is enabled (test env default) - let mut network_prelogon = WireguardNetwork::default(); + let mut network_prelogon = WireguardNetwork::default() + .try_set_address("10.2.1.1/24") + .unwrap(); network_prelogon.name = "prelogon-service-location".to_string(); network_prelogon.service_location_mode = ServiceLocationMode::PreLogon; network_prelogon.location_mfa_mode = LocationMfaMode::Disabled; - network_prelogon.try_set_address("10.2.1.1/24").unwrap(); let network_prelogon = network_prelogon.save(&pool).await.unwrap(); WireguardNetworkDevice::new( @@ -178,11 +180,12 @@ mod test { assert_eq!(peers_prelogon[0].pubkey, "pubkey2"); // Service location with AlwaysOn mode also returns peers when enterprise is enabled - let mut network_alwayson = WireguardNetwork::default(); + let mut network_alwayson = WireguardNetwork::default() + .try_set_address("10.3.1.1/24") + .unwrap(); network_alwayson.name = "alwayson-service-location".to_string(); network_alwayson.service_location_mode = ServiceLocationMode::AlwaysOn; network_alwayson.location_mfa_mode = LocationMfaMode::Disabled; - network_alwayson.try_set_address("10.3.1.1/24").unwrap(); let network_alwayson = network_alwayson.save(&pool).await.unwrap(); let device3 = Device::new( diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index eef808319e..1355cce3c5 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -400,9 +400,12 @@ mod test { #[sqlx::test] async fn test_sync_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.1.1.1/29").unwrap(); - let network = network.save(&pool).await.unwrap(); + let network = WireguardNetwork::default() + .try_set_address("10.1.1.1/29") + .unwrap() + .save(&pool) + .await + .unwrap(); let user1 = User::new( "testuser1", @@ -517,9 +520,12 @@ mod test { options: PgConnectOptions, ) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.1.1.1/29").unwrap(); - let network = network.save(&pool).await.unwrap(); + let network = WireguardNetwork::default() + .try_set_address("10.1.1.1/29") + .unwrap() + .save(&pool) + .await + .unwrap(); let user1 = User::new( "testuser1", diff --git a/crates/defguard_core/src/wg_config.rs b/crates/defguard_core/src/wg_config.rs index e80f22fa08..b1c0ead6d8 100644 --- a/crates/defguard_core/src/wg_config.rs +++ b/crates/defguard_core/src/wg_config.rs @@ -90,7 +90,7 @@ pub(crate) fn parse_wireguard_config( .map_err(|_| WireguardConfigParseError::InvalidFwMark(value.to_string()))?, None => 0, }; - let mut addresses: Vec = Vec::new(); + let mut addresses = Vec::::new(); for addr in address.split(',') { match addr.trim().parse() { Ok(ip) => addresses.push(ip), @@ -107,7 +107,6 @@ pub(crate) fn parse_wireguard_config( .collect::, _>>()?; let mut network = WireguardNetwork::new( pubkey.clone(), - addresses.clone(), port, String::new(), dns, @@ -118,6 +117,7 @@ pub(crate) fn parse_wireguard_config( LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + network.set_address(addresses.clone()); network.mtu = mtu; network.fwmark = fwmark; network.pubkey = pubkey; diff --git a/crates/defguard_core/tests/integration/api/acl/rules.rs b/crates/defguard_core/tests/integration/api/acl/rules.rs index 9e5eb5baac..40e7764e33 100644 --- a/crates/defguard_core/tests/integration/api/acl/rules.rs +++ b/crates/defguard_core/tests/integration/api/acl/rules.rs @@ -337,7 +337,6 @@ async fn test_related_objects(_: PgPoolOptions, options: PgConnectOptions) { for net in ["net 1", "net 2"] { WireguardNetwork::new( net.to_string(), - Vec::new(), 1000, "endpoint1".to_string(), None, @@ -854,7 +853,6 @@ async fn test_rule_delete_state_applied(_: PgPoolOptions, options: PgConnectOpti // create a location WireguardNetwork::new( "test location".to_string(), - Vec::new(), 1000, "endpoint1".to_string(), None, diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index d01abed0f2..599ce237ca 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -47,9 +47,8 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { let pool = client_state.pool; // setup initial network - let initial_network = WireguardNetwork::new( + let mut initial_network = WireguardNetwork::new( "initial".into(), - ["10.1.9.0/24".parse().unwrap()], 51515, String::new(), None, @@ -60,6 +59,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + initial_network.set_address(["10.1.9.0/24".parse().unwrap()]); initial_network.save(&pool).await.unwrap(); // add existing devices diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 3ec17c2300..84aa0b8219 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -116,9 +116,8 @@ pub(crate) async fn create_location_with_mfa_mode( pool: &sqlx::PgPool, location_mfa_mode: LocationMfaMode, ) -> WireguardNetwork { - WireguardNetwork::new( + let mut location = WireguardNetwork::new( "TestNet".to_string(), - [IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], 51820, "10.0.0.1".to_string(), None, @@ -128,10 +127,12 @@ pub(crate) async fn create_location_with_mfa_mode( false, location_mfa_mode, ServiceLocationMode::Disabled, - ) - .save(pool) - .await - .expect("failed to create Wireguard location") + ); + location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()]); + location + .save(pool) + .await + .expect("failed to create WireGuard location") } pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { diff --git a/crates/defguard_setup/src/auto_adoption.rs b/crates/defguard_setup/src/auto_adoption.rs index f766e59c86..a02904dd47 100644 --- a/crates/defguard_setup/src/auto_adoption.rs +++ b/crates/defguard_setup/src/auto_adoption.rs @@ -733,9 +733,8 @@ id={} for new gateway", .context("Failed to parse default auto-adoption network address")?; let mut transaction = pool.begin().await.context("Failed to begin transaction")?; - let network = WireguardNetwork::new( + let mut network = WireguardNetwork::new( common_name.to_string(), - [network_address], DEFAULT_AUTO_ADOPTION_WIREGUARD_PORT, host.to_string(), None, @@ -745,10 +744,12 @@ id={} for new gateway", false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ) - .save(&mut *transaction) - .await - .context("Failed to save auto-adopted WireguardNetwork")?; + ); + network.set_address([network_address]); + let network = network + .save(&mut *transaction) + .await + .context("Failed to save auto-adopted WireguardNetwork")?; network .add_all_allowed_devices(&mut transaction) diff --git a/crates/defguard_setup/tests/auto_adoption_wizard.rs b/crates/defguard_setup/tests/auto_adoption_wizard.rs index cf72b02ecb..2bd5562a82 100644 --- a/crates/defguard_setup/tests/auto_adoption_wizard.rs +++ b/crates/defguard_setup/tests/auto_adoption_wizard.rs @@ -42,17 +42,17 @@ async fn assert_auto_adoption_step(pool: &sqlx::PgPool, expected: AutoAdoptionWi async fn seed_wireguard_network(pool: &sqlx::PgPool) -> WireguardNetwork { let mut location = WireguardNetwork::new( "auto-net".to_string(), - ["10.0.0.0/24".parse::().unwrap()], 51820, "1.2.3.4".to_string(), None, - vec!["0.0.0.0/0".parse().unwrap()], + ["0.0.0.0/0".parse().unwrap()], false, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + location.set_address(["10.0.0.0/24".parse::().unwrap()]); location.mtu = 1280; location .save(pool) diff --git a/crates/defguard_setup/tests/wizard_state.rs b/crates/defguard_setup/tests/wizard_state.rs index 556201a300..eeaa8c5546 100644 --- a/crates/defguard_setup/tests/wizard_state.rs +++ b/crates/defguard_setup/tests/wizard_state.rs @@ -140,17 +140,17 @@ async fn test_wizard_state_auto_adoption(_: PgPoolOptions, options: PgConnectOpt let mut location = WireguardNetwork::new( "auto-net".to_string(), - ["10.0.0.0/24".parse().unwrap()], 51820, "1.2.3.4".to_string(), None, - vec!["0.0.0.0/0".parse().unwrap()], + ["0.0.0.0/0".parse().unwrap()], false, false, false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, ); + location.set_address(["10.0.0.0/24".parse().unwrap()]); location.mtu = 1280; location .save(&pool) diff --git a/crates/defguard_static_ip/src/lib.rs b/crates/defguard_static_ip/src/lib.rs index c96c999a30..e59f7b4905 100644 --- a/crates/defguard_static_ip/src/lib.rs +++ b/crates/defguard_static_ip/src/lib.rs @@ -265,17 +265,19 @@ mod tests { .expect("Failed to create user"); // Create test locations - let mut location_a = WireguardNetwork::default(); + let mut location_a = WireguardNetwork::default() + .try_set_address("10.0.1.1/24") + .unwrap(); location_a.name = "Location A".into(); - location_a.try_set_address("10.0.1.1/24").unwrap(); let location_a = location_a .save(&pool) .await .expect("Failed to create Location A"); - let mut location_b = WireguardNetwork::default(); + let mut location_b = WireguardNetwork::default() + .try_set_address("10.0.2.1/24") + .unwrap(); location_b.name = "Location B".into(); - location_b.try_set_address("10.0.2.1/24").unwrap(); let location_b = location_b .save(&pool) .await @@ -434,9 +436,10 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork::default(); + let mut network = WireguardNetwork::default() + .try_set_address("10.0.0.1/24") + .unwrap(); network.name = "Assign Network".into(); - network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); let device = Device::new( @@ -508,9 +511,10 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork::default(); + let mut network = WireguardNetwork::default() + .try_set_address("10.0.0.1/24") + .unwrap(); network.name = "NoDevice Network".into(); - network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); let device = Device::new( @@ -556,9 +560,10 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork::default(); + let mut network = WireguardNetwork::default() + .try_set_address("10.0.0.1/24") + .unwrap(); network.name = "Range Network".into(); - network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); let device = Device::new( @@ -611,9 +616,10 @@ mod tests { .await .expect("Failed to create user"); - let mut network = WireguardNetwork::default(); + let mut network = WireguardNetwork::default() + .try_set_address("10.0.0.1/24") + .unwrap(); network.name = "Conflict; Network".into(); - network.try_set_address("10.0.0.1/24").unwrap(); let network = network.save(&pool).await.expect("Failed to create network"); let device1 = Device::new( From 7b3b7f7a664cf0a55e1a9841fb3cf953f904544b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 17 Mar 2026 13:46:51 +0100 Subject: [PATCH 04/10] Fix tests --- .../src/db/models/auth_code.rs | 5 +- .../src/db/models/authentication_key.rs | 4 +- .../src/db/models/biometric_auth.rs | 13 +-- .../defguard_common/src/db/models/device.rs | 55 +++++------ .../src/db/models/device_login.rs | 4 +- .../defguard_common/src/db/models/gateway.rs | 5 +- crates/defguard_common/src/db/models/group.rs | 16 ++-- .../src/db/models/initial_setup_wizard.rs | 8 +- .../defguard_common/src/db/models/mfa_info.rs | 4 +- .../src/db/models/migration_wizard.rs | 6 +- .../src/db/models/oauth2authorizedapp.rs | 4 +- .../src/db/models/oauth2client.rs | 13 +-- .../src/db/models/oauth2token.rs | 14 +-- .../src/db/models/polling_token.rs | 4 +- .../defguard_common/src/db/models/session.rs | 18 ++-- .../src/db/models/setup_auto_adoption.rs | 6 +- crates/defguard_common/src/db/models/user.rs | 92 +++++++------------ .../src/db/models/vpn_client_session.rs | 12 +-- .../src/db/models/vpn_session_stats.rs | 2 +- .../defguard_common/src/db/models/webauthn.rs | 8 +- .../src/db/models/wireguard.rs | 56 ++++++----- crates/defguard_common/src/types/user_info.rs | 12 +-- .../defguard_core/src/db/models/enrollment.rs | 4 +- crates/defguard_core/src/db/models/webhook.rs | 6 +- .../src/enterprise/db/models/acl.rs | 62 +++++-------- .../db/models/activity_log_stream.rs | 4 +- .../src/enterprise/db/models/api_tokens.rs | 6 +- .../enterprise/db/models/openid_provider.rs | 8 +- .../src/enterprise/directory_sync/mod.rs | 4 +- .../src/enterprise/directory_sync/tests.rs | 11 ++- .../src/enterprise/firewall/mod.rs | 10 +- .../firewall/tests/all_locations.rs | 46 +++++----- .../enterprise/firewall/tests/destination.rs | 20 ++-- .../firewall/tests/disabled_rules.rs | 24 ++--- .../firewall/tests/expired_rules.rs | 19 ++-- .../src/enterprise/firewall/tests/gh1868.rs | 57 ++++++------ .../src/enterprise/firewall/tests/mod.rs | 54 ++++++----- .../firewall/tests/unapplied_rules.rs | 24 ++--- .../src/enterprise/ldap/error.rs | 3 +- .../src/enterprise/ldap/model.rs | 8 +- .../defguard_core/src/enterprise/license.rs | 4 +- crates/defguard_core/src/enterprise/limits.rs | 4 +- crates/defguard_core/src/error.rs | 3 + crates/defguard_core/src/handlers/mod.rs | 11 +++ .../src/handlers/ssh_authorized_keys.rs | 4 +- crates/defguard_core/src/handlers/user.rs | 4 +- .../defguard_core/src/handlers/wireguard.rs | 8 +- crates/defguard_core/src/lib.rs | 15 +-- .../src/location_management/allowed_peers.rs | 4 +- .../src/location_management/tests.rs | 11 ++- crates/defguard_core/src/wg_config.rs | 8 +- .../tests/integration/api/wireguard.rs | 10 +- .../api/wireguard_network_import.rs | 11 ++- .../tests/common/mod.rs | 16 ++-- crates/defguard_setup/src/auto_adoption.rs | 13 ++- .../src/handlers/auto_wizard.rs | 2 +- .../tests/auto_adoption_wizard.rs | 5 +- crates/defguard_setup/tests/wizard_state.rs | 5 +- 58 files changed, 434 insertions(+), 435 deletions(-) diff --git a/crates/defguard_common/src/db/models/auth_code.rs b/crates/defguard_common/src/db/models/auth_code.rs index b57774c708..283bc99a80 100644 --- a/crates/defguard_common/src/db/models/auth_code.rs +++ b/crates/defguard_common/src/db/models/auth_code.rs @@ -66,10 +66,7 @@ impl From> for AuthCode { impl AuthCode { /// Find by code. /// If found, delete `AuthCode` from the database right away, so it can't be reused. - pub async fn find_code<'e, E>( - executor: E, - code: &str, - ) -> Result>, sqlx::Error> + pub async fn find_code<'e, E>(executor: E, code: &str) -> sqlx::Result>> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/authentication_key.rs b/crates/defguard_common/src/db/models/authentication_key.rs index 78b44b3c58..6def68ff93 100644 --- a/crates/defguard_common/src/db/models/authentication_key.rs +++ b/crates/defguard_common/src/db/models/authentication_key.rs @@ -2,7 +2,7 @@ use std::fmt::Display; use model_derive::Model; use serde::{Deserialize, Serialize}; -use sqlx::{Error as SqlxError, PgExecutor, Type, query_as}; +use sqlx::{PgExecutor, Type, query_as}; use crate::db::{Id, NoId}; @@ -60,7 +60,7 @@ impl AuthenticationKey { executor: E, user_id: Id, key_type: Option, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/biometric_auth.rs b/crates/defguard_common/src/db/models/biometric_auth.rs index 8f477b5fc0..5564068003 100644 --- a/crates/defguard_common/src/db/models/biometric_auth.rs +++ b/crates/defguard_common/src/db/models/biometric_auth.rs @@ -57,10 +57,7 @@ impl BiometricAuth { } impl BiometricAuth { - pub async fn find_by_device_id<'e, E>( - executor: E, - device_id: Id, - ) -> Result, sqlx::Error> + pub async fn find_by_device_id<'e, E>(executor: E, device_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -73,11 +70,7 @@ impl BiometricAuth { .await } - pub async fn verify_owner<'e, E>( - executor: E, - user_id: Id, - pub_key: &str, - ) -> Result + pub async fn verify_owner<'e, E>(executor: E, user_id: Id, pub_key: &str) -> sqlx::Result where E: PgExecutor<'e>, { @@ -91,7 +84,7 @@ impl BiometricAuth { Ok(q_result.is_some()) } - pub async fn find_by_user_id<'e, E>(executor: E, user_id: Id) -> Result, sqlx::Error> + pub async fn find_by_user_id<'e, E>(executor: E, user_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/device.rs b/crates/defguard_common/src/db/models/device.rs index 5679bf6983..a3ccb525ab 100644 --- a/crates/defguard_common/src/db/models/device.rs +++ b/crates/defguard_common/src/db/models/device.rs @@ -10,10 +10,7 @@ use rand::{ prelude::Distribution, }; use serde::{Deserialize, Serialize}; -use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, - query_scalar, -}; +use sqlx::{FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, query_scalar}; use thiserror::Error; use tracing::{debug, error, info, warn}; use utoipa::ToSchema; @@ -428,7 +425,7 @@ impl WireguardNetworkDevice { pub async fn find_by_device<'e, E>( executor: E, device_id: Id, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -616,7 +613,7 @@ impl Device { executor: E, ip: IpAddr, network_id: Id, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -634,7 +631,7 @@ impl Device { .await } - pub async fn find_by_pubkey<'e, E>(executor: E, pubkey: &str) -> Result, SqlxError> + pub async fn find_by_pubkey<'e, E>(executor: E, pubkey: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -653,7 +650,7 @@ impl Device { executor: E, id: Id, username: &str, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT device.id, name, wireguard_pubkey, user_id, created, description, \ @@ -667,7 +664,7 @@ impl Device { .await } - pub async fn all_for_username(pool: &PgPool, username: &str) -> Result, SqlxError> { + pub async fn all_for_username(pool: &PgPool, username: &str) -> sqlx::Result> { query_as!( Self, "SELECT device.id, name, wireguard_pubkey, user_id, created, description, \ @@ -975,7 +972,7 @@ impl Device { pub async fn find_by_type<'e, E>( executor: E, device_type: DeviceType, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -991,7 +988,7 @@ impl Device { executor: E, device_type: DeviceType, network_id: Id, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -1006,7 +1003,7 @@ impl Device { ).fetch_all(executor).await } - pub async fn get_owner<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn get_owner<'e, E>(&self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -1025,7 +1022,7 @@ impl Device { &self, executor: E, location_id: Id, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_scalar!( "SELECT connected_at \"connected_at!\" FROM vpn_client_session \ WHERE location_id = $1 AND device_id = $2 AND connected_at IS NOT NULL \ @@ -1185,12 +1182,14 @@ mod test { .await .unwrap(); - let mut updated_network = network.clone(); - updated_network.set_address([ - "10.0.0.0/16".parse::().unwrap(), - "123.12.0.0/16".parse::().unwrap(), - "123.123.0.0/16".parse::().unwrap(), - ]); + let updated_network = network + .clone() + .set_address([ + "10.0.0.1/16".parse().unwrap(), + "123.12.0.1/16".parse().unwrap(), + "123.123.0.1/16".parse().unwrap(), + ]) + .unwrap(); updated_network.save(&mut *conn).await.unwrap(); let used_ips = updated_network @@ -1281,8 +1280,10 @@ mod test { .await .unwrap(); - let mut updated_network = network.clone(); - updated_network.set_address(["10.0.0.0/16".parse::().unwrap()]); + let updated_network = network + .clone() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); updated_network.save(&mut *conn).await.unwrap(); let used_ips = updated_network @@ -1364,8 +1365,10 @@ mod test { .await .unwrap(); - let mut updated_network = network.clone(); - updated_network.set_address(["123.123.0.0/16".parse::().unwrap()]); + let updated_network = network + .clone() + .set_address(["123.123.0.1/16".parse().unwrap()]) + .unwrap(); updated_network.save(&mut *conn).await.unwrap(); let used_ips = updated_network @@ -1537,9 +1540,9 @@ mod test { .await .unwrap(); - let mut network = WireguardNetwork::default(); - network - .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 4)), 29).unwrap()]); + let mut network = WireguardNetwork::default() + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 4)), 29).unwrap()]) + .unwrap(); network.allow_all_groups = true; let network = network.save(&pool).await.unwrap(); diff --git a/crates/defguard_common/src/db/models/device_login.rs b/crates/defguard_common/src/db/models/device_login.rs index 09b7d52eeb..073e083f3e 100644 --- a/crates/defguard_common/src/db/models/device_login.rs +++ b/crates/defguard_common/src/db/models/device_login.rs @@ -3,7 +3,7 @@ use std::fmt; use chrono::{NaiveDateTime, Utc}; use model_derive::Model; use serde::{Deserialize, Serialize}; -use sqlx::{Error as SqlxError, PgPool, query_as}; +use sqlx::{PgPool, query_as}; use crate::db::{Id, NoId}; @@ -77,7 +77,7 @@ impl DeviceLoginEvent { pub async fn find_device_login_event( &self, pool: &PgPool, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { query_as!( DeviceLoginEvent::, "SELECT id, user_id, ip_address, model, family, brand, os_family, browser, event_type, created \ diff --git a/crates/defguard_common/src/db/models/gateway.rs b/crates/defguard_common/src/db/models/gateway.rs index e76b1f443b..cb32ed6864 100644 --- a/crates/defguard_common/src/db/models/gateway.rs +++ b/crates/defguard_common/src/db/models/gateway.rs @@ -89,10 +89,7 @@ impl Gateway { Ok(()) } - pub async fn find_by_location_id<'e, E>( - executor: E, - location_id: Id, - ) -> Result, sqlx::Error> + pub async fn find_by_location_id<'e, E>(executor: E, location_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/group.rs b/crates/defguard_common/src/db/models/group.rs index e875f14c1f..762efbe70f 100644 --- a/crates/defguard_common/src/db/models/group.rs +++ b/crates/defguard_common/src/db/models/group.rs @@ -2,7 +2,7 @@ use std::fmt; use model_derive::Model; use serde::Serialize; -use sqlx::{Error as SqlxError, FromRow, PgExecutor, query, query_as, query_scalar}; +use sqlx::{FromRow, PgExecutor, query, query_as, query_scalar}; use utoipa::ToSchema; use crate::db::{Id, NoId, models::user::User}; @@ -48,7 +48,7 @@ impl Group { } impl Group { - pub async fn find_by_name<'e, E>(executor: E, name: &str) -> Result, SqlxError> + pub async fn find_by_name<'e, E>(executor: E, name: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -61,7 +61,7 @@ impl Group { .await } - pub async fn member_usernames<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn member_usernames<'e, E>(&self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -74,7 +74,7 @@ impl Group { .await } - pub async fn members<'e, E>(&self, executor: E) -> Result>, SqlxError> + pub async fn members<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -94,7 +94,7 @@ impl Group { } /// Fetches a list of VPN locations where a given group is allowed. - pub async fn allowed_vpn_locations<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn allowed_vpn_locations<'e, E>(&self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -111,7 +111,7 @@ impl Group { pub async fn find_by_permission<'e, E>( executor: E, permission: Permission, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -124,7 +124,7 @@ impl Group { &self, executor: E, permission: Permission, - ) -> Result + ) -> sqlx::Result where E: PgExecutor<'e>, { @@ -141,7 +141,7 @@ impl Group { executor: E, permission: Permission, value: bool, - ) -> Result<(), SqlxError> + ) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/initial_setup_wizard.rs b/crates/defguard_common/src/db/models/initial_setup_wizard.rs index b4f9b8c410..39c0027e09 100644 --- a/crates/defguard_common/src/db/models/initial_setup_wizard.rs +++ b/crates/defguard_common/src/db/models/initial_setup_wizard.rs @@ -24,7 +24,7 @@ pub struct InitialSetupState { } impl InitialSetupState { - pub async fn set_step<'e, E>(executor: E, step: InitialSetupStep) -> Result<(), sqlx::Error> + pub async fn set_step<'e, E>(executor: E, step: InitialSetupStep) -> sqlx::Result<()> where E: PgExecutor<'e> + Copy, { @@ -37,7 +37,7 @@ impl InitialSetupState { Ok(()) } - pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -55,7 +55,7 @@ impl InitialSetupState { Ok(()) } - pub async fn get<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn get<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -71,7 +71,7 @@ impl InitialSetupState { Ok(state.map(|j| j.0)) } - pub async fn clear<'e, E>(executor: E) -> Result<(), sqlx::Error> + pub async fn clear<'e, E>(executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/mfa_info.rs b/crates/defguard_common/src/db/models/mfa_info.rs index 07eda69483..d18543b2c5 100644 --- a/crates/defguard_common/src/db/models/mfa_info.rs +++ b/crates/defguard_common/src/db/models/mfa_info.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use sqlx::{Error as SqlxError, PgPool, query_as}; +use sqlx::{PgPool, query_as}; use crate::db::{ Id, @@ -15,7 +15,7 @@ pub struct MFAInfo { } impl MFAInfo { - pub async fn for_user(pool: &PgPool, user: &User) -> Result, SqlxError> { + pub async fn for_user(pool: &PgPool, user: &User) -> sqlx::Result> { query_as!( Self, "SELECT mfa_method \"mfa_method: _\", totp_enabled totp_available, \ diff --git a/crates/defguard_common/src/db/models/migration_wizard.rs b/crates/defguard_common/src/db/models/migration_wizard.rs index 3415a2f58d..ef2fd2084c 100644 --- a/crates/defguard_common/src/db/models/migration_wizard.rs +++ b/crates/defguard_common/src/db/models/migration_wizard.rs @@ -35,7 +35,7 @@ pub struct MigrationWizardState { } impl MigrationWizardState { - pub async fn get<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn get<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -54,7 +54,7 @@ impl MigrationWizardState { .map_err(|error| sqlx::Error::Decode(Box::new(error))) } - pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -73,7 +73,7 @@ impl MigrationWizardState { Ok(()) } - pub async fn clear<'e, E>(executor: E) -> Result<(), sqlx::Error> + pub async fn clear<'e, E>(executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/oauth2authorizedapp.rs b/crates/defguard_common/src/db/models/oauth2authorizedapp.rs index 421a93437b..ffd7731493 100644 --- a/crates/defguard_common/src/db/models/oauth2authorizedapp.rs +++ b/crates/defguard_common/src/db/models/oauth2authorizedapp.rs @@ -1,5 +1,5 @@ use model_derive::Model; -use sqlx::{Error as SqlxError, PgPool, query_as}; +use sqlx::{PgPool, query_as}; use crate::db::{Id, NoId}; @@ -26,7 +26,7 @@ impl OAuth2AuthorizedApp { pool: &PgPool, user_id: Id, oauth2client_id: Id, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT id, user_id, oauth2client_id \ diff --git a/crates/defguard_common/src/db/models/oauth2client.rs b/crates/defguard_common/src/db/models/oauth2client.rs index bec7be9f44..120915f0f8 100644 --- a/crates/defguard_common/src/db/models/oauth2client.rs +++ b/crates/defguard_common/src/db/models/oauth2client.rs @@ -1,6 +1,6 @@ use model_derive::Model; use serde::{Deserialize, Serialize}; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as}; +use sqlx::{PgExecutor, PgPool, query_as}; use crate::{ db::{Id, NoId, models::OAuth2Token}, @@ -43,7 +43,7 @@ impl OAuth2Client { pub async fn find_by_client_id<'e, E>( executor: E, client_id: &str, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -57,7 +57,7 @@ impl OAuth2Client { .await } - pub async fn clear_authorizations<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn clear_authorizations<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -75,7 +75,7 @@ impl OAuth2Client { pool: &PgPool, client_id: &str, client_secret: &str, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT id, client_id, client_secret, redirect_uri, scope, name, enabled \ @@ -87,10 +87,7 @@ impl OAuth2Client { .await } - pub async fn find_by_token( - pool: &PgPool, - token: &OAuth2Token, - ) -> Result, SqlxError> { + pub async fn find_by_token(pool: &PgPool, token: &OAuth2Token) -> sqlx::Result> { query_as!( Self, "SELECT c.id, c.client_id, c.client_secret, c.redirect_uri, c.scope, c.name, c.enabled \ diff --git a/crates/defguard_common/src/db/models/oauth2token.rs b/crates/defguard_common/src/db/models/oauth2token.rs index e550882f84..3af29fbf22 100644 --- a/crates/defguard_common/src/db/models/oauth2token.rs +++ b/crates/defguard_common/src/db/models/oauth2token.rs @@ -1,5 +1,5 @@ use chrono::{TimeDelta, Utc}; -use sqlx::{Error as SqlxError, PgPool, query, query_as}; +use sqlx::{PgPool, query, query_as}; use crate::{ db::{Id, models::Settings}, @@ -32,7 +32,7 @@ impl OAuth2Token { } /// Generate new access token, scratching the old one. Changes are reflected in the database. - pub async fn refresh_and_save(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn refresh_and_save(&mut self, pool: &PgPool) -> sqlx::Result<()> { let settings = Settings::get_current_settings(); let timeout = settings.authentication_timeout(); let new_access_token = gen_alphanumeric(24); @@ -60,7 +60,7 @@ impl OAuth2Token { } /// Store data in the database. - pub async fn save(&self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn save(&self, pool: &PgPool) -> sqlx::Result<()> { query!( "INSERT INTO oauth2token (oauth2authorizedapp_id, access_token, refresh_token, redirect_uri, scope, expires_in) \ VALUES ($1, $2, $3, $4, $5, $6)", @@ -76,7 +76,7 @@ impl OAuth2Token { } /// Delete token from the database. - pub async fn delete(self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn delete(self, pool: &PgPool) -> sqlx::Result<()> { query!( "DELETE FROM oauth2token WHERE access_token = $1 AND refresh_token = $2", self.access_token, @@ -91,7 +91,7 @@ impl OAuth2Token { pub async fn find_access_token( pool: &PgPool, access_token: &str, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { match query_as!( Self, "SELECT oauth2authorizedapp_id, access_token, refresh_token, redirect_uri, scope, expires_in \ @@ -118,7 +118,7 @@ impl OAuth2Token { pub async fn find_refresh_token( pool: &PgPool, refresh_token: &str, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { match query_as!( Self, "SELECT oauth2authorizedapp_id, access_token, refresh_token, redirect_uri, scope, expires_in \ @@ -145,7 +145,7 @@ impl OAuth2Token { pub async fn find_by_authorized_app_id( pool: &PgPool, oauth2authorizedapp_id: Id, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { match query_as!( Self, "SELECT oauth2authorizedapp_id, access_token, refresh_token, redirect_uri, scope, expires_in \ diff --git a/crates/defguard_common/src/db/models/polling_token.rs b/crates/defguard_common/src/db/models/polling_token.rs index 750ec80a80..69ff04e531 100644 --- a/crates/defguard_common/src/db/models/polling_token.rs +++ b/crates/defguard_common/src/db/models/polling_token.rs @@ -29,7 +29,7 @@ impl PollingToken { } impl PollingToken { - pub async fn find<'e, E>(executor: E, token: &str) -> Result, sqlx::Error> + pub async fn find<'e, E>(executor: E, token: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -43,7 +43,7 @@ impl PollingToken { .await } - pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), sqlx::Error> + pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/session.rs b/crates/defguard_common/src/db/models/session.rs index 7bc1510cc1..f52bf570da 100644 --- a/crates/defguard_common/src/db/models/session.rs +++ b/crates/defguard_common/src/db/models/session.rs @@ -1,5 +1,5 @@ use chrono::{NaiveDateTime, TimeDelta, Utc}; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, Type, query, query_as}; +use sqlx::{PgExecutor, PgPool, Type, query, query_as}; use webauthn_rs::prelude::{PasskeyAuthentication, PasskeyRegistration}; use crate::{ @@ -58,7 +58,7 @@ impl Session { self.expires < Utc::now().naive_utc() } - pub async fn find_by_id(pool: &PgPool, id: &str) -> Result, SqlxError> { + pub async fn find_by_id(pool: &PgPool, id: &str) -> sqlx::Result> { query_as!( Self, "SELECT id, user_id, state \"state: SessionState\", created, expires, webauthn_challenge, \ @@ -69,7 +69,7 @@ impl Session { .await } - pub async fn save(&self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn save(&self, pool: &PgPool) -> sqlx::Result<()> { query!( "INSERT INTO session (id, user_id, state, created, expires, webauthn_challenge, ip_address, device_info) \ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", @@ -88,7 +88,7 @@ impl Session { Ok(()) } - pub async fn set_state(&mut self, pool: &PgPool, state: SessionState) -> Result<(), SqlxError> { + pub async fn set_state(&mut self, pool: &PgPool, state: SessionState) -> sqlx::Result<()> { query!( "UPDATE session SET state = $1 WHERE id = $2", state.clone() as i16, @@ -119,7 +119,7 @@ impl Session { &mut self, executor: E, passkey_auth: &PasskeyAuthentication, - ) -> Result<(), SqlxError> + ) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -141,7 +141,7 @@ impl Session { &mut self, executor: E, passkey_reg: &PasskeyRegistration, - ) -> Result<(), SqlxError> + ) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -159,7 +159,7 @@ impl Session { Ok(()) } - pub async fn delete<'e, E>(self, executor: E) -> Result<(), SqlxError> + pub async fn delete<'e, E>(self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -170,7 +170,7 @@ impl Session { Ok(()) } - pub async fn delete_expired<'e, E>(executor: E) -> Result<(), SqlxError> + pub async fn delete_expired<'e, E>(executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -181,7 +181,7 @@ impl Session { Ok(()) } - pub async fn delete_all_for_user<'e, E>(executor: E, user_id: i64) -> Result<(), SqlxError> + pub async fn delete_all_for_user<'e, E>(executor: E, user_id: i64) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/setup_auto_adoption.rs b/crates/defguard_common/src/db/models/setup_auto_adoption.rs index 12a815557e..566c167dfa 100644 --- a/crates/defguard_common/src/db/models/setup_auto_adoption.rs +++ b/crates/defguard_common/src/db/models/setup_auto_adoption.rs @@ -40,7 +40,7 @@ pub struct AutoAdoptionWizardState { } impl AutoAdoptionWizardState { - pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -58,7 +58,7 @@ impl AutoAdoptionWizardState { Ok(()) } - pub async fn get<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn get<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -74,7 +74,7 @@ impl AutoAdoptionWizardState { Ok(state.map(|j| j.0)) } - pub async fn clear<'e, E>(executor: E) -> Result<(), sqlx::Error> + pub async fn clear<'e, E>(executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/user.rs b/crates/defguard_common/src/db/models/user.rs index 9238901fd9..1152450606 100644 --- a/crates/defguard_common/src/db/models/user.rs +++ b/crates/defguard_common/src/db/models/user.rs @@ -17,10 +17,7 @@ use rand::{ prelude::Distribution, }; use serde::{Deserialize, Serialize}; -use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, - query_scalar, -}; +use sqlx::{FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, query_scalar}; use thiserror::Error; use totp_lite::{Sha1, totp_custom}; use tracing::{debug, error, info, warn}; @@ -49,7 +46,7 @@ pub enum UserError { #[error("Invalid MFA state for user {username}")] InvalidMfaState { username: String }, #[error(transparent)] - DbError(#[from] SqlxError), + DbError(#[from] sqlx::Error), #[error("{0}")] EmailMfaError(String), } @@ -306,7 +303,7 @@ impl User { impl User { /// Generate new TOTP secret, save it, then return it as RFC 4648 base32-encoded string. - pub async fn new_totp_secret<'e, E>(&mut self, executor: E) -> Result + pub async fn new_totp_secret<'e, E>(&mut self, executor: E) -> sqlx::Result where E: PgExecutor<'e>, { @@ -325,7 +322,7 @@ impl User { } /// Generate new email secret, similar to TOTP secret above, but don't return generated value. - pub async fn new_email_secret<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + pub async fn new_email_secret<'e, E>(&mut self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -347,7 +344,7 @@ impl User { &mut self, executor: E, mfa_method: MFAMethod, - ) -> Result<(), SqlxError> + ) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -370,7 +367,7 @@ impl User { /// Check if any of the multi-factor authentication methods is on. /// - TOTP is enabled /// - a security key for Webauthn - async fn check_mfa_enabled<'e, E>(&self, executor: E) -> Result + async fn check_mfa_enabled<'e, E>(&self, executor: E) -> sqlx::Result where E: PgExecutor<'e>, { @@ -463,7 +460,7 @@ impl User { pub async fn get_recovery_codes<'e, E>( &mut self, executor: E, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -487,7 +484,7 @@ impl User { } /// Disable MFA; discard recovery codes, TOTP secret, and security keys. - pub async fn disable_mfa(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn disable_mfa(&mut self, pool: &PgPool) -> sqlx::Result<()> { query!( "UPDATE \"user\" SET mfa_enabled = FALSE, mfa_method = 'none', totp_enabled = FALSE, email_mfa_enabled = FALSE, \ totp_secret = NULL, email_mfa_secret = NULL, recovery_codes = '{}' WHERE id = $1", @@ -509,7 +506,7 @@ impl User { } /// Enable TOTP - pub async fn enable_totp<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + pub async fn enable_totp<'e, E>(&mut self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -527,7 +524,7 @@ impl User { } /// Disable TOTP; discard the secret. - pub async fn disable_totp(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn disable_totp(&mut self, pool: &PgPool) -> sqlx::Result<()> { if self.totp_enabled { // FIXME: check if this flag is set correctly when TOTP is the only method self.mfa_enabled = self.check_mfa_enabled(pool).await?; @@ -550,7 +547,7 @@ impl User { } /// Enable email MFA - pub async fn enable_email_mfa<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + pub async fn enable_email_mfa<'e, E>(&mut self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -569,7 +566,7 @@ impl User { } /// Disable email MFA; discard the secret. - pub async fn disable_email_mfa(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + pub async fn disable_email_mfa(&mut self, pool: &PgPool) -> sqlx::Result<()> { if self.email_mfa_enabled { self.mfa_enabled = self.check_mfa_enabled(pool).await?; self.email_mfa_enabled = false; @@ -592,9 +589,7 @@ impl User { /// Select all users without sensitive data. // FIXME: Remove it when Model macro will support SecretString - pub async fn all_without_sensitive_data( - pool: &PgPool, - ) -> Result, SqlxError> { + pub async fn all_without_sensitive_data(pool: &PgPool) -> sqlx::Result> { let users = query!( "SELECT id, mfa_enabled, totp_enabled, email_mfa_enabled, \ mfa_method \"mfa_method: MFAMethod\", password_hash, is_active, openid_sub, \ @@ -620,7 +615,7 @@ impl User { } /// Return all active users. - pub async fn all_active<'e, E>(executor: E) -> Result, SqlxError> + pub async fn all_active<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -641,7 +636,7 @@ impl User { pub async fn find_by_group_name( pool: &PgPool, group_name: &str, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { let users = query_as!( Self, "SELECT \"user\".id, username, password_hash, last_name, first_name, email, \ @@ -768,11 +763,7 @@ impl User { } /// Verify recovery code. If it is valid, consume it, so it can't be used again. - pub async fn verify_recovery_code( - &mut self, - pool: &PgPool, - code: &str, - ) -> Result { + pub async fn verify_recovery_code(&mut self, pool: &PgPool, code: &str) -> sqlx::Result { if let Some(index) = self.recovery_codes.iter().position(|c| c == code) { // Note: swap_remove() should be faster than remove(). self.recovery_codes.swap_remove(index); @@ -791,10 +782,7 @@ impl User { } } - pub async fn find_by_username<'e, E>( - executor: E, - username: &str, - ) -> Result, SqlxError> + pub async fn find_by_username<'e, E>(executor: E, username: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -811,7 +799,7 @@ impl User { .await } - pub async fn find_by_email<'e, E>(executor: E, email: &str) -> Result, SqlxError> + pub async fn find_by_email<'e, E>(executor: E, email: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -832,7 +820,7 @@ impl User { pub async fn find_by_username_or_email( conn: &mut PgConnection, username_or_email: &str, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { let maybe_user = Self::find_by_username(&mut *conn, username_or_email).await?; if let Some(user) = maybe_user { Ok(Some(user)) @@ -844,10 +832,7 @@ impl User { } } - pub async fn find_many_by_emails<'e, E>( - executor: E, - emails: &[&str], - ) -> Result, SqlxError> + pub async fn find_many_by_emails<'e, E>(executor: E, emails: &[&str]) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -863,7 +848,7 @@ impl User { .await } - pub async fn find_by_sub<'e, E>(executor: E, sub: &str) -> Result, SqlxError> + pub async fn find_by_sub<'e, E>(executor: E, sub: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -880,7 +865,7 @@ impl User { .await } - pub async fn member_of_names<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn member_of_names<'e, E>(&self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -893,7 +878,7 @@ impl User { .await } - pub async fn member_of<'e, E>(&self, executor: E) -> Result>, SqlxError> + pub async fn member_of<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -910,7 +895,7 @@ impl User { /// Returns a vector of [`UserDevice`]s (hence the name). /// [`UserDevice`] is a struct containing additional network info about a device. /// If you only need [`Device`]s, use [`User::devices()`] instead. - pub async fn user_devices(&self, pool: &PgPool) -> Result, SqlxError> { + pub async fn user_devices(&self, pool: &PgPool) -> sqlx::Result> { let devices = self.devices(pool).await?; let mut user_devices = Vec::new(); for device in devices { @@ -924,7 +909,7 @@ impl User { /// Returns a vector of [`Device`]s related to a user. If you want to get [`UserDevice`]s (which contain additional network info), /// use [`User::user_devices()`] instead. - pub async fn devices<'e, E>(&self, executor: E) -> Result>, SqlxError> + pub async fn devices<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -943,7 +928,7 @@ impl User { pub async fn oauth2authorizedapps<'e, E>( &self, executor: E, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -960,7 +945,7 @@ impl User { .await } - pub async fn security_keys(&self, pool: &PgPool) -> Result, SqlxError> { + pub async fn security_keys(&self, pool: &PgPool) -> sqlx::Result> { query_as!( SecurityKey, "SELECT id \"id!\", name FROM webauthn WHERE user_id = $1", @@ -970,7 +955,7 @@ impl User { .await } - pub async fn add_to_group<'e, E>(&self, executor: E, group: &Group) -> Result<(), SqlxError> + pub async fn add_to_group<'e, E>(&self, executor: E, group: &Group) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -985,11 +970,7 @@ impl User { Ok(()) } - pub async fn remove_from_group<'e, E>( - &self, - executor: E, - group: &Group, - ) -> Result<(), SqlxError> + pub async fn remove_from_group<'e, E>(&self, executor: E, group: &Group) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -1008,7 +989,7 @@ impl User { &self, executor: E, app_client_ids: &[i64], - ) -> Result<(), SqlxError> + ) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -1082,7 +1063,7 @@ impl User { Ok(()) } - pub async fn logout_all_sessions<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn logout_all_sessions<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -1090,10 +1071,7 @@ impl User { Ok(()) } - pub async fn find_by_device_id<'e, E>( - executor: E, - device_id: Id, - ) -> Result, SqlxError> + pub async fn find_by_device_id<'e, E>(executor: E, device_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -1114,7 +1092,7 @@ impl User { } /// Find users which emails are NOT in `user_emails`. - pub async fn exclude<'e, E>(executor: E, user_emails: &[&str]) -> Result, SqlxError> + pub async fn exclude<'e, E>(executor: E, user_emails: &[&str]) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -1131,7 +1109,7 @@ impl User { .await } - pub async fn is_admin<'e, E>(&self, executor: E) -> Result + pub async fn is_admin<'e, E>(&self, executor: E) -> sqlx::Result where E: PgExecutor<'e>, { @@ -1142,7 +1120,7 @@ impl User { } /// Find all users that are admins and are active. - pub async fn find_admins<'e, E>(executor: E) -> Result, SqlxError> + pub async fn find_admins<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index f31ddca8b1..bfb95095d2 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -1,6 +1,6 @@ use chrono::{NaiveDateTime, Utc}; use model_derive::Model; -use sqlx::{Error as SqlxError, Type, query_as}; +use sqlx::{Type, query_as}; use crate::db::{ Id, NoId, @@ -81,7 +81,7 @@ impl VpnClientSession { executor: E, location_id: Id, device_id: Id, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ @@ -99,7 +99,7 @@ impl VpnClientSession { pub async fn get_latest_stats_for_all_gateways<'e, E: sqlx::PgExecutor<'e>>( &self, executor: E, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { query_as!( VpnSessionStats, "SELECT DISTINCT ON (gateway_id) id, session_id, gateway_id, collected_at, latest_handshake, endpoint, \ @@ -117,7 +117,7 @@ impl VpnClientSession { pub async fn get_all_inactive_for_location<'e, E: sqlx::PgExecutor<'e>>( executor: E, location: &WireguardNetwork, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT s.id, location_id, user_id, device_id, created_at, s.connected_at, disconnected_at, \ @@ -141,7 +141,7 @@ impl VpnClientSession { pub async fn get_never_connected<'e, E: sqlx::PgExecutor<'e>>( executor: E, location: &WireguardNetwork, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ @@ -159,7 +159,7 @@ impl VpnClientSession { executor: E, location_id: Id, device_id: Id, - ) -> Result, SqlxError> { + ) -> sqlx::Result> { query_as!( Self, "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ diff --git a/crates/defguard_common/src/db/models/vpn_session_stats.rs b/crates/defguard_common/src/db/models/vpn_session_stats.rs index 3225780393..543548bc7b 100644 --- a/crates/defguard_common/src/db/models/vpn_session_stats.rs +++ b/crates/defguard_common/src/db/models/vpn_session_stats.rs @@ -59,7 +59,7 @@ impl VpnSessionStats { executor: E, device_id: Id, location_id: Id, - ) -> Result, sqlx::Error> { + ) -> sqlx::Result> { let maybe_stats = query_as!( Self, "SELECT st.id, session_id, gateway_id, collected_at, latest_handshake, endpoint, \ diff --git a/crates/defguard_common/src/db/models/webauthn.rs b/crates/defguard_common/src/db/models/webauthn.rs index 2fc9730f6a..ca17beb767 100644 --- a/crates/defguard_common/src/db/models/webauthn.rs +++ b/crates/defguard_common/src/db/models/webauthn.rs @@ -1,5 +1,5 @@ use model_derive::Model; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query, query_as, query_scalar}; +use sqlx::{PgExecutor, PgPool, query, query_as, query_scalar}; use webauthn_rs::prelude::Passkey; use crate::db::{Id, NoId, models::ModelError}; @@ -37,7 +37,7 @@ impl WebAuthn { impl WebAuthn { /// Fetch all [`Passkey`]s for a given user. - pub async fn passkeys_for_user(pool: &PgPool, user_id: Id) -> Result, SqlxError> { + pub async fn passkeys_for_user(pool: &PgPool, user_id: Id) -> sqlx::Result> { query_scalar!("SELECT passkey FROM webauthn WHERE user_id = $1", user_id) .fetch_all(pool) .await @@ -50,7 +50,7 @@ impl WebAuthn { } /// Fetch all for a given user. - pub async fn all_for_user(pool: &PgPool, user_id: Id) -> Result, SqlxError> { + pub async fn all_for_user(pool: &PgPool, user_id: Id) -> sqlx::Result> { query_as!( Self, "SELECT id, user_id, name, passkey FROM webauthn WHERE user_id = $1", @@ -61,7 +61,7 @@ impl WebAuthn { } /// Delete all for a given user. - pub async fn delete_all_for_user<'e, E>(executor: E, user_id: Id) -> Result<(), SqlxError> + pub async fn delete_all_for_user<'e, E>(executor: E, user_id: Id) -> sqlx::Result<()> where E: PgExecutor<'e>, { diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 683143458f..3fde5cf255 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -274,23 +274,23 @@ impl WireguardNetwork { } /// Address list setter. - pub fn set_address(&mut self, address: V) + pub fn set_address(mut self, address: V) -> Result where V: Into>, { - self.address = address.into(); - } - - /// Validate addresses. - pub fn address_is_valid(&self) -> bool { - for addr in &self.address { + let address = address.into(); + for addr in &address { let ip = addr.ip(); - if ip == addr.network() || ip == addr.broadcast() { - return false; + if ip == addr.network() { + return Err(IpNetworkError::InvalidAddr("address is network".into())); + } + if ip == addr.broadcast() { + return Err(IpNetworkError::InvalidAddr("address is broadcast".into())); } } + self.address = address; - true + Ok(self) } } @@ -674,7 +674,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, sqlx::Error> { + ) -> sqlx::Result> { let mut user_map: HashMap> = HashMap::new(); // Retrieve data series for all active devices and assign them to users @@ -708,7 +708,7 @@ impl WireguardNetwork { aggregation: &DateTimeAggregation, page: u32, page_size: u32, - ) -> Result<(Vec, u32), sqlx::Error> { + ) -> sqlx::Result<(Vec, u32)> { // helper struct used to fetch connected users from the DB struct ConnectedUserRow { user_id: Id, @@ -830,7 +830,7 @@ impl WireguardNetwork { aggregation: &DateTimeAggregation, page: u32, page_size: u32, - ) -> Result<(Vec, u32), sqlx::Error> { + ) -> sqlx::Result<(Vec, u32)> { // helper struct used to fetch connected network devices from the DB struct ConnectedNetworkDeviceRow { device_id: Id, @@ -940,7 +940,7 @@ impl WireguardNetwork { user_id: Id, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, sqlx::Error> { + ) -> sqlx::Result> { // helper struct used to fetch connected user devices from the DB struct ConnectedUserDeviceRow { device_id: Id, @@ -1079,7 +1079,7 @@ impl WireguardNetwork { pool: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, sqlx::Error> { + ) -> sqlx::Result> { let stats = query_as!( WireguardStatsRow, "SELECT \ @@ -1129,7 +1129,7 @@ impl WireguardNetwork { &self, executor: E, device_type: DeviceType, - ) -> Result>, sqlx::Error> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1931,7 +1931,7 @@ mod test { async fn test_can_assign_ips(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::new( + let network = WireguardNetwork::new( "network".to_string(), 50051, String::new(), @@ -1942,9 +1942,12 @@ mod test { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - network.set_address([IpNetwork::from_str("10.1.1.1/24").unwrap()]); - let network = network.save(&pool).await.unwrap(); + ) + .set_address([IpNetwork::from_str("10.1.1.1/24").unwrap()]) + .unwrap() + .save(&pool) + .await + .unwrap(); // assign free address let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; @@ -2057,7 +2060,7 @@ mod test { async fn test_can_assign_ips_multiple_addresses(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let mut network = WireguardNetwork::new( + let network = WireguardNetwork::new( "network".to_string(), 50051, String::new(), @@ -2068,12 +2071,15 @@ mod test { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - network.set_address([ + ) + .set_address([ IpNetwork::from_str("10.1.1.1/24").unwrap(), IpNetwork::from_str("fc00::1/112").unwrap(), - ]); - let network = network.save(&pool).await.unwrap(); + ]) + .unwrap() + .save(&pool) + .await + .unwrap(); // assign free addresses let addrs = vec![ diff --git a/crates/defguard_common/src/types/user_info.rs b/crates/defguard_common/src/types/user_info.rs index 6716d877a0..9dab2c1a19 100644 --- a/crates/defguard_common/src/types/user_info.rs +++ b/crates/defguard_common/src/types/user_info.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use sqlx::{Error as SqlxError, PgConnection, PgPool}; +use sqlx::{PgConnection, PgPool}; use utoipa::ToSchema; use crate::{ @@ -39,7 +39,7 @@ pub struct UserInfo { } impl UserInfo { - pub async fn from_user(pool: &PgPool, user: &User) -> Result { + pub async fn from_user(pool: &PgPool, user: &User) -> sqlx::Result { let groups = user.member_of_names(pool).await?; let authorized_apps = user.oauth2authorizedapps(pool).await?; @@ -71,7 +71,7 @@ impl UserInfo { &self, transaction: &mut PgConnection, user: &mut User, - ) -> Result { + ) -> sqlx::Result { if self.is_active == user.is_active { Ok(false) } else { @@ -91,7 +91,7 @@ impl UserInfo { &self, transaction: &mut PgConnection, user: &mut User, - ) -> Result { + ) -> sqlx::Result { // initialize return value let mut group_diff = GroupDiff::default(); @@ -126,7 +126,7 @@ impl UserInfo { } /// Copy fields to [`User`]. This function is safe to call by a non-admin user. - pub fn into_user_safe_fields(self, user: &mut User) -> Result<(), SqlxError> { + pub fn into_user_safe_fields(self, user: &mut User) -> sqlx::Result<()> { user.phone = self.phone; user.mfa_method = self.mfa_method; @@ -134,7 +134,7 @@ impl UserInfo { } /// Copy fields to [`User`]. This function should be used by administrators. - pub fn into_user_all_fields(self, user: &mut User) -> Result<(), SqlxError> { + pub fn into_user_all_fields(self, user: &mut User) -> sqlx::Result<()> { user.phone = self.phone; user.username = self.username; user.last_name = self.last_name; diff --git a/crates/defguard_core/src/db/models/enrollment.rs b/crates/defguard_core/src/db/models/enrollment.rs index 2c1862df89..76b5b1c0f0 100644 --- a/crates/defguard_core/src/db/models/enrollment.rs +++ b/crates/defguard_core/src/db/models/enrollment.rs @@ -12,7 +12,7 @@ use defguard_mail::{ Mail, templates::{self, TemplateError, safe_tera}, }; -use sqlx::{Error as SqlxError, PgConnection, PgExecutor, PgPool, Transaction, query, query_as}; +use sqlx::{PgConnection, PgExecutor, PgPool, Transaction, query, query_as}; use tera::Context; use thiserror::Error; use tonic::{Code, Status}; @@ -23,7 +23,7 @@ pub static PASSWORD_RESET_TOKEN_TYPE: &str = "PASSWORD_RESET"; #[derive(Error, Debug)] pub enum TokenError { #[error(transparent)] - DbError(#[from] SqlxError), + DbError(#[from] sqlx::Error), #[error("Enrollment token not found")] NotFound, #[error("Enrollment token expired")] diff --git a/crates/defguard_core/src/db/models/webhook.rs b/crates/defguard_core/src/db/models/webhook.rs index 91086edbce..ff4040b006 100644 --- a/crates/defguard_core/src/db/models/webhook.rs +++ b/crates/defguard_core/src/db/models/webhook.rs @@ -3,7 +3,7 @@ use defguard_common::{ types::user_info::UserInfo, }; use model_derive::Model; -use sqlx::{Error as SqlxError, FromRow, PgPool, query_as}; +use sqlx::{FromRow, PgPool, query_as}; /// App events which triggers webhook action #[derive(Debug)] @@ -63,7 +63,7 @@ pub struct WebHook { impl WebHook { /// Fetch all enabled webhooks. - pub async fn all_enabled(pool: &PgPool, trigger: &AppEvent) -> Result, SqlxError> { + pub async fn all_enabled(pool: &PgPool, trigger: &AppEvent) -> sqlx::Result> { let column_name = trigger.column_name(); let query = format!( "SELECT id, url, description, token, enabled, on_user_created, \ @@ -74,7 +74,7 @@ impl WebHook { } /// Find [`WebHook`] by URL. - pub async fn find_by_url(pool: &PgPool, url: &str) -> Result, SqlxError> { + pub async fn find_by_url(pool: &PgPool, url: &str) -> sqlx::Result> { query_as!( Self, "SELECT id, url, description, token, enabled, on_user_created, \ diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index 472029f5ac..21279899d1 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -13,8 +13,8 @@ use defguard_common::db::{ use ipnetwork::IpNetwork; use model_derive::Model; use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, error::ErrorKind, - postgres::types::PgRange, query, query_as, query_scalar, + FromRow, PgConnection, PgExecutor, PgPool, Type, error::ErrorKind, postgres::types::PgRange, + query, query_as, query_scalar, }; use thiserror::Error; use utoipa::ToSchema; @@ -41,7 +41,7 @@ pub enum AclError { #[error(transparent)] AddrParseError(#[from] std::net::AddrParseError), #[error(transparent)] - DbError(#[from] SqlxError), + DbError(#[from] sqlx::Error), #[error("InvalidRelationError: {0}")] InvalidRelationError(String), #[error("RuleNotFoundError: {0}")] @@ -664,8 +664,8 @@ pub fn parse_ports(ports: &str) -> Result, AclError> { } /// Maps [`sqlx::Error`] to [`AclError`] while checking for [`ErrorKind::ForeignKeyViolation`]. -fn map_relation_error(err: SqlxError, class: &str, id: Id) -> AclError { - if let SqlxError::Database(dberror) = &err { +fn map_relation_error(err: sqlx::Error, class: &str, id: Id) -> AclError { + if let sqlx::Error::Database(dberror) = &err { if dberror.kind() == ErrorKind::ForeignKeyViolation { error!( "Failed to create ACL related object, foreign key violation: {class}({id}): {dberror}" @@ -797,10 +797,7 @@ impl AclRule { } /// Deletes relation objects for given [`AclRule`] - async fn delete_related_objects( - &self, - transaction: &mut PgConnection, - ) -> Result<(), SqlxError> { + async fn delete_related_objects(&self, transaction: &mut PgConnection) -> sqlx::Result<()> { let rule_id = self.id; debug!("Deleting related objects for ACL rule {rule_id}"); // networks @@ -972,10 +969,7 @@ impl AclRule { } /// Returns all [`AclAlias`]es the rule applies to - pub(crate) async fn get_aliases<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> + pub(crate) async fn get_aliases<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -998,7 +992,7 @@ impl AclRule { &self, executor: E, allowed: bool, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1010,10 +1004,7 @@ impl AclRule { } /// Returns **active** [`User`]s that are allowed by the rule - pub(crate) async fn get_allowed_users<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> + pub(crate) async fn get_allowed_users<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1036,10 +1027,7 @@ impl AclRule { } /// Returns **active** [`User`]s that are denied by the rule - pub(crate) async fn get_denied_users<'e, E>( - &self, - executor: E, - ) -> Result>, SqlxError> + pub(crate) async fn get_denied_users<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1066,7 +1054,7 @@ impl AclRule { &self, executor: E, allowed: bool, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1088,7 +1076,7 @@ impl AclRule { &self, executor: E, allowed: bool, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1102,7 +1090,7 @@ impl AclRule { pub(crate) async fn get_allowed_network_devices<'e, E>( &self, executor: E, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1122,7 +1110,7 @@ impl AclRule { pub(crate) async fn get_denied_network_devices<'e, E>( &self, executor: E, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1143,7 +1131,7 @@ impl AclRule { pub(crate) async fn get_destination_address_ranges<'e, E>( &self, executor: E, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1160,7 +1148,7 @@ impl AclRule { /// Retrieves all related objects from the db and converts [`AclRule`] /// instance to [`AclRuleInfo`]. - pub async fn to_info(&self, conn: &mut PgConnection) -> Result, SqlxError> { + pub async fn to_info(&self, conn: &mut PgConnection) -> sqlx::Result> { let locations = self.get_networks(&mut *conn).await?; let allowed_users = self.get_users(&mut *conn, true).await?; let denied_users = self.get_users(&mut *conn, false).await?; @@ -1218,7 +1206,7 @@ impl AclRuleInfo { pub(crate) async fn get_all_allowed_users( &self, conn: &mut PgConnection, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { debug!( "Preparing list of all allowed users for ACL rule {}", self.id @@ -1274,7 +1262,7 @@ impl AclRuleInfo { pub(crate) async fn get_all_denied_users( &self, conn: &mut PgConnection, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { debug!( "Preparing list of all denied users for ACL rule {}", self.id @@ -1332,7 +1320,7 @@ impl AclRuleInfo { &self, executor: E, location_id: Id, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { debug!( "Preparing list of all allowed network devices for ACL rule {}", self.id @@ -1364,7 +1352,7 @@ impl AclRuleInfo { &self, executor: E, location_id: Id, - ) -> Result>, SqlxError> { + ) -> sqlx::Result>> { debug!( "Preparing list of all denied network devices for ACL rule {}", self.id @@ -1784,7 +1772,7 @@ impl AclAlias { pub(crate) async fn get_destination_ranges<'e, E>( &self, executor: E, - ) -> Result>, SqlxError> + ) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1800,7 +1788,7 @@ impl AclAlias { } /// Returns all [`AclRule`]s which use this alias - pub(crate) async fn get_rules<'e, E>(&self, executor: E) -> Result>, SqlxError> + pub(crate) async fn get_rules<'e, E>(&self, executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { @@ -1821,7 +1809,7 @@ impl AclAlias { /// Retrieves all related objects from the db and converts [`AclAlias`] /// instance to [`AclAliasInfo`]. - pub(crate) async fn to_info(&self, pool: &PgPool) -> Result { + pub(crate) async fn to_info(&self, pool: &PgPool) -> sqlx::Result { let destination_ranges = self.get_destination_ranges(pool).await?; let rules = self.get_rules(pool).await?; @@ -2007,7 +1995,7 @@ pub struct AclRuleDestinationRange { } impl AclRuleDestinationRange { - pub async fn save<'e, E>(self, executor: E) -> Result, SqlxError> + pub async fn save<'e, E>(self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -2047,7 +2035,7 @@ pub(crate) struct AclAliasDestinationRange { } impl AclAliasDestinationRange { - pub async fn save<'e, E>(self, executor: E) -> Result, SqlxError> + pub async fn save<'e, E>(self, executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs b/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs index 3548de1883..6a4cbc1366 100644 --- a/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs +++ b/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs @@ -4,7 +4,7 @@ use defguard_common::{ }; use model_derive::Model; use serde::Serialize; -use sqlx::{Error as SqlxError, FromRow, PgExecutor, Type, query_as}; +use sqlx::{FromRow, PgExecutor, Type, query_as}; use strum_macros::{Display, EnumString}; use crate::enterprise::activity_log_stream::error::ActivityLogStreamError; @@ -89,7 +89,7 @@ impl ActivityLogStream { pub async fn find_by_stream_type<'e, E>( executor: E, stream_type: &ActivityLogStreamType, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/db/models/api_tokens.rs b/crates/defguard_core/src/enterprise/db/models/api_tokens.rs index 474282177a..9c0af50b74 100644 --- a/crates/defguard_core/src/enterprise/db/models/api_tokens.rs +++ b/crates/defguard_core/src/enterprise/db/models/api_tokens.rs @@ -1,7 +1,7 @@ use chrono::NaiveDateTime; use defguard_common::db::{Id, NoId}; use model_derive::Model; -use sqlx::{Error as SqlxError, PgExecutor, query_as}; +use sqlx::{PgExecutor, query_as}; #[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] #[table(api_token)] @@ -33,7 +33,7 @@ impl ApiToken { } impl ApiToken { - pub async fn find_by_user_id<'e, E>(executor: E, user_id: Id) -> Result, SqlxError> + pub async fn find_by_user_id<'e, E>(executor: E, user_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -50,7 +50,7 @@ impl ApiToken { pub async fn try_find_by_auth_token<'e, E>( executor: E, auth_token: &str, - ) -> Result, SqlxError> + ) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs index 16f15d51ff..e0687d1634 100644 --- a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs +++ b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs @@ -2,7 +2,7 @@ use std::fmt; use defguard_common::db::{Id, NoId}; use model_derive::Model; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, Type, query, query_as}; +use sqlx::{PgExecutor, PgPool, Type, query, query_as}; use utoipa::ToSchema; // The behavior when a user is deleted from the directory @@ -181,7 +181,7 @@ impl OpenIdProvider { } } - pub(crate) async fn upsert(self, pool: &PgPool) -> Result, SqlxError> { + pub(crate) async fn upsert(self, pool: &PgPool) -> sqlx::Result> { if let Some(provider) = OpenIdProvider::::get_current(pool).await? { query!( "UPDATE openidprovider SET name = $1, base_url = $2, kind = $3, client_id = $4, \ @@ -224,7 +224,7 @@ impl OpenIdProvider { } impl OpenIdProvider { - pub async fn find_by_name<'e, E>(executor: E, name: &str) -> Result, SqlxError> + pub async fn find_by_name<'e, E>(executor: E, name: &str) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -243,7 +243,7 @@ impl OpenIdProvider { .await } - pub async fn get_current<'e, E>(executor: E) -> Result, SqlxError> + pub async fn get_current<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index 79c3a87837..91b80fd99b 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -10,7 +10,7 @@ use defguard_common::db::{ }; use paste::paste; use reqwest::header::AUTHORIZATION; -use sqlx::{PgConnection, PgPool, error::Error as SqlxError}; +use sqlx::{PgConnection, PgPool}; use thiserror::Error; use tokio::sync::broadcast::Sender; @@ -42,7 +42,7 @@ const REQUEST_PAGINATION_SLOWDOWN: Duration = Duration::from_millis(100); #[derive(Debug, Error)] pub enum DirectorySyncError { #[error("Database error: {0}")] - DbError(#[from] SqlxError), + DbError(#[from] sqlx::Error), #[error( "Access token has expired or is not present. An issue may have occured while trying to obtain a new one." )] diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 3c5849f0f6..6bed257fd3 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -52,7 +52,7 @@ mod test { provider.delete(pool).await.unwrap(); } - let mut location = WireguardNetwork::new( + WireguardNetwork::new( "test".to_string(), 1234, "123.123.123.123".to_string(), @@ -63,9 +63,12 @@ mod test { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - location.set_address([IpNetwork::from_str("10.10.10.1/24").unwrap()]); - location.save(pool).await.unwrap(); + ) + .set_address([IpNetwork::from_str("10.10.10.1/24").unwrap()]) + .unwrap() + .save(pool) + .await + .unwrap(); OpenIdProvider::new( "Test".to_string(), diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index 7a8b61cb3b..52feb7bea8 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -13,7 +13,7 @@ use defguard_proto::enterprise::firewall::{ port::Port as PortInner, }; use ipnetwork::IpNetwork; -use sqlx::{Error as SqlxError, PgConnection, query_as, query_scalar}; +use sqlx::{PgConnection, query_as, query_scalar}; use super::{ db::models::acl::{AclRule, AclRuleDestinationRange, AclRuleInfo, PortRange, Protocol}, @@ -480,7 +480,7 @@ async fn get_user_device_ips<'e, E: sqlx::PgExecutor<'e>>( user_ids: &[Id], location_id: Id, executor: E, -) -> Result>, SqlxError> { +) -> sqlx::Result>> { // fetch network IPs query_scalar!( "SELECT wireguard_ips \"wireguard_ips: Vec\" \ @@ -515,7 +515,7 @@ async fn get_network_device_ips( network_devices: &[Device], location_id: Id, conn: &mut PgConnection, -) -> Result>, SqlxError> { +) -> sqlx::Result>> { // prepare a list of IDs let network_device_ids: Vec = network_devices.iter().map(|device| device.id).collect(); @@ -898,7 +898,7 @@ fn merge_port_ranges(port_ranges: Vec) -> Vec { async fn generate_user_snat_bindings_for_location( location_id: Id, conn: &mut PgConnection, -) -> Result, SqlxError> { +) -> sqlx::Result> { debug!("Generating SNAT bindings for location {location_id}"); let user_snat_bindings = UserSnatBinding::all_for_location(&mut *conn, location_id).await?; @@ -978,7 +978,7 @@ async fn generate_user_snat_bindings_for_location( pub(crate) async fn get_location_active_acl_rules( location: &WireguardNetwork, conn: &mut PgConnection, -) -> Result>, SqlxError> { +) -> sqlx::Result>> { debug!("Fetching active ACL rules for location {location}"); let rules: Vec> = query_as( "SELECT DISTINCT ON (a.id) a.id, name, allow_all_users, deny_all_users, all_locations, \ diff --git a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs index 9ecbe07180..a5f2d1e3f9 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs @@ -1,7 +1,5 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - use defguard_common::db::{models::WireguardNetwork, setup_pool}; -use ipnetwork::IpNetwork; + use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; @@ -20,15 +18,17 @@ async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectO set_test_license_business(); // Create test location - let mut location_1 = WireguardNetwork::default(); + let mut location_1 = WireguardNetwork::default() + .set_address(["192.168.0.1/24".parse().unwrap()]) + .unwrap(); location_1.acl_enabled = true; - location_1.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location_1 = location_1.save(&pool).await.unwrap(); // Create another test location - let mut location_2 = WireguardNetwork::default(); + let mut location_2 = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location_2.acl_enabled = true; - location_2.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location_2 = location_2.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -117,15 +117,17 @@ async fn test_acl_rules_all_locations_ipv6(_: PgPoolOptions, options: PgConnectO let mut rng = thread_rng(); // Create test location - let mut location_1 = WireguardNetwork::default(); + let mut location_1 = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location_1.acl_enabled = true; - location_1.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location_1 = location_1.save(&pool).await.unwrap(); // Create another test location - let mut location_2 = WireguardNetwork::default(); + let mut location_2 = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location_2.acl_enabled = true; - location_2.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location_2 = location_2.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -214,21 +216,23 @@ async fn test_acl_rules_all_locations_ipv4_and_ipv6(_: PgPoolOptions, options: P let mut rng = thread_rng(); // Create test location - let mut location_1 = WireguardNetwork::default(); + let mut location_1 = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location_1.acl_enabled = true; - location_1.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let location_1 = location_1.save(&pool).await.unwrap(); // Create another test location - let mut location_2 = WireguardNetwork::default(); + let mut location_2 = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location_2.acl_enabled = true; - location_2.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let location_2 = location_2.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index 1ed4cfe919..213bbd1a18 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -150,9 +150,10 @@ async fn test_any_address_overwrites_manual_destination( let mut rng = thread_rng(); - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address(["10.0.0.0/16".parse().unwrap()]); let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -230,9 +231,10 @@ async fn test_any_address_overwrites_destination_alias_addrs( let mut rng = thread_rng(); - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address(["10.0.0.0/16".parse().unwrap()]); let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -328,9 +330,10 @@ async fn test_manual_destination_includes_component_alias_address_range( let mut rng = thread_rng(); - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address(["10.0.0.0/16".parse().unwrap()]); let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; @@ -429,9 +432,10 @@ async fn test_manual_destination_merges_rule_and_component_alias_address_ranges( let mut rng = thread_rng(); - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address(["10.0.0.0/16".parse().unwrap()]); let location = location.save(&pool).await.unwrap(); create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; diff --git a/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs b/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs index 74595dd06e..e1a930fafe 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/disabled_rules.rs @@ -1,7 +1,4 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - use defguard_common::db::{models::WireguardNetwork, setup_pool}; -use ipnetwork::IpNetwork; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; @@ -20,9 +17,10 @@ async fn test_disabled_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOption let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["192.168.0.1/24".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -92,9 +90,10 @@ async fn test_disabled_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOption let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -164,12 +163,13 @@ async fn test_disabled_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConn let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location.acl_enabled = true; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs b/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs index 53f9fac5ae..dfdb5b7f60 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/expired_rules.rs @@ -1,8 +1,5 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - use chrono::{DateTime, NaiveDateTime}; use defguard_common::db::{models::WireguardNetwork, setup_pool}; -use ipnetwork::IpNetwork; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use crate::enterprise::{ @@ -77,9 +74,10 @@ async fn test_expired_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptions set_test_license_business(); let pool = setup_pool(options).await; // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // create expired ACL rules @@ -140,12 +138,13 @@ async fn test_expired_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConne set_test_license_business(); let pool = setup_pool(options).await; // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location.acl_enabled = true; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let location = location.save(&pool).await.unwrap(); // create expired ACL rules diff --git a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs index 62d5704029..c7093101f3 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/gh1868.rs @@ -84,16 +84,17 @@ async fn test_gh1868_ipv6_rule_is_not_created_with_v4_only_destination( let mut rng = thread_rng(); // Create test location with both IPv4 and IPv6 subnet. - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + 64, + ) + .unwrap(), + ]) + .unwrap(); location.acl_enabled = true; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), - IpNetwork::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), - 64, - ) - .unwrap(), - ]); let location = location.save(&pool).await.unwrap(); // setup user & device @@ -145,16 +146,17 @@ async fn test_gh1868_ipv4_rule_is_not_created_with_v6_only_destination( let mut rng = thread_rng(); // Create test location with both IPv4 and IPv6 subnet. - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + 64, + ) + .unwrap(), + ]) + .unwrap(); location.acl_enabled = true; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), - IpNetwork::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), - 64, - ) - .unwrap(), - ]); let location = location.save(&pool).await.unwrap(); // setup user & device @@ -204,16 +206,17 @@ async fn test_gh1868_ipv4_and_ipv6_rules_are_created_with_any_destination( let mut rng = thread_rng(); // Create test location with both IPv4 and IPv6 subnet - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), + 64, + ) + .unwrap(), + ]) + .unwrap(); location.acl_enabled = true; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 80, 1)), 24).unwrap(), - IpNetwork::new( - IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)), - 64, - ) - .unwrap(), - ]); let location = location.save(&pool).await.unwrap(); // setup user & device diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index dbe2d005db..258c459552 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -613,7 +613,7 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO ] ); - let expected_destination_addrs = vec![ + let expected_destination_addrs = [ IpAddress { address: Some(Address::Ip("10.0.1.13".to_string())), }, @@ -670,9 +670,10 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location.acl_enabled = false; - location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let mut location = location.save(&pool).await.unwrap(); // Setup test users and their devices @@ -1121,12 +1122,13 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location.acl_enabled = false; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let mut location = location.save(&pool).await.unwrap(); // Setup test users and their devices @@ -1758,9 +1760,10 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address(["10.0.0.0/16".parse().unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -1908,9 +1911,10 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address(["10.0.0.0/16".parse().unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -2070,9 +2074,10 @@ async fn test_no_allowed_users_ipv4(_: PgPoolOptions, options: PgConnectOptions) let pool = setup_pool(options).await; // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["192.168.0.1/24".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // create ACL rules @@ -2129,20 +2134,23 @@ async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgCon let mut rng = thread_rng(); // Create test locations with IPv4 and IPv6 addresses - let mut location_ipv4 = WireguardNetwork::default(); + let mut location_ipv4 = WireguardNetwork::default() + .set_address(["192.168.0.1/24".parse().unwrap()]) + .unwrap(); location_ipv4.acl_enabled = true; - location_ipv4.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location_ipv4 = location_ipv4.save(&pool).await.unwrap(); - let mut location_ipv6 = WireguardNetwork::default(); + let mut location_ipv6 = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location_ipv6.acl_enabled = true; - location_ipv6.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location_ipv6 = location_ipv6.save(&pool).await.unwrap(); - let mut location_ipv4_and_ipv6 = WireguardNetwork::default(); + let mut location_ipv4_and_ipv6 = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location_ipv4_and_ipv6.acl_enabled = true; - location_ipv4_and_ipv6.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let location_ipv4_and_ipv6 = location_ipv4_and_ipv6.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs b/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs index 78a6e13b31..32be20ea7e 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/unapplied_rules.rs @@ -1,7 +1,4 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - use defguard_common::db::{models::WireguardNetwork, setup_pool}; -use ipnetwork::IpNetwork; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; @@ -20,9 +17,10 @@ async fn test_unapplied_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptio let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["192.168.0.1/24".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -92,9 +90,10 @@ async fn test_unapplied_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptio let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address(["fb00::1/112".parse().unwrap()]) + .unwrap(); location.acl_enabled = true; - location.set_address([IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap()]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices @@ -164,12 +163,13 @@ async fn test_unapplied_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgCon let mut rng = thread_rng(); // Create test location - let mut location = WireguardNetwork::default(); + let mut location = WireguardNetwork::default() + .set_address([ + "192.168.0.1/24".parse().unwrap(), + "fb00::1/112".parse().unwrap(), + ]) + .unwrap(); location.acl_enabled = true; - location.set_address([ - IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap(), - IpNetwork::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).unwrap(), - ]); let location = location.save(&pool).await.unwrap(); // Setup some test users and their devices diff --git a/crates/defguard_core/src/enterprise/ldap/error.rs b/crates/defguard_core/src/enterprise/ldap/error.rs index b2ebb2981a..3b32d92486 100644 --- a/crates/defguard_core/src/enterprise/ldap/error.rs +++ b/crates/defguard_core/src/enterprise/ldap/error.rs @@ -1,5 +1,4 @@ use defguard_common::db::models::settings::SettingsSaveError; -use sqlx::error::Error as SqlxError; use thiserror::Error; #[derive(Debug, Error)] @@ -13,7 +12,7 @@ pub enum LdapError { #[error("Found multiple objects, expected one")] TooManyObjects, #[error("Database error: {0}")] - Database(#[from] SqlxError), + Database(#[from] sqlx::Error), #[error(transparent)] SettingsSave(#[from] SettingsSaveError), #[error("Expected different DN: {0}")] diff --git a/crates/defguard_core/src/enterprise/ldap/model.rs b/crates/defguard_core/src/enterprise/ldap/model.rs index cbb0ee9e40..77a9cc61e4 100644 --- a/crates/defguard_core/src/enterprise/ldap/model.rs +++ b/crates/defguard_core/src/enterprise/ldap/model.rs @@ -5,7 +5,7 @@ use defguard_common::db::{ models::{Settings, User}, }; use ldap3::{Mod, SearchEntry}; -use sqlx::{Error as SqlxError, PgExecutor}; +use sqlx::PgExecutor; use super::{LDAPConfig, error::LdapError}; use crate::{handlers::user::check_username, hashset}; @@ -250,7 +250,7 @@ pub(crate) fn maybe_update_rdn(user: &mut User) { pub(crate) async fn ldap_sync_allowed_for_user<'e, E>( user: &User, executor: E, -) -> Result +) -> sqlx::Result where E: PgExecutor<'e>, { @@ -263,9 +263,7 @@ where ) } -pub(super) async fn get_users_without_ldap_path<'e, E>( - executor: E, -) -> Result>, SqlxError> +pub(super) async fn get_users_without_ldap_path<'e, E>(executor: E) -> sqlx::Result>> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/license.rs b/crates/defguard_core/src/enterprise/license.rs index 70f2562d29..5b04067785 100644 --- a/crates/defguard_core/src/enterprise/license.rs +++ b/crates/defguard_core/src/enterprise/license.rs @@ -21,7 +21,7 @@ use pgp::{ types::KeyDetails, }; use prost::Message; -use sqlx::{PgPool, error::Error as SqlxError}; +use sqlx::PgPool; use thiserror::Error; use tokio::time::sleep; @@ -50,7 +50,7 @@ pub enum LicenseError { #[error("Provided signature is invalid")] InvalidSignature, #[error("Database error")] - DbError(#[from] SqlxError), + DbError(#[from] sqlx::Error), #[error(transparent)] SettingsSave(#[from] SettingsSaveError), #[error("License decoding error: {0}")] diff --git a/crates/defguard_core/src/enterprise/limits.rs b/crates/defguard_core/src/enterprise/limits.rs index 46f3a43007..f692d4ccc7 100644 --- a/crates/defguard_core/src/enterprise/limits.rs +++ b/crates/defguard_core/src/enterprise/limits.rs @@ -1,6 +1,6 @@ use defguard_common::global_value; use serde::Serialize; -use sqlx::{error::Error as SqlxError, query}; +use sqlx::query; use super::license::License; #[cfg(test)] @@ -18,7 +18,7 @@ global_value!(COUNTS, Counts, Counts::default(), set_counts, get_counts); /// Update the counts of users, devices, and wireguard networks stored in the memory. // TODO: Use it with database triggers when they are implemented -pub async fn update_counts<'e, E: sqlx::PgExecutor<'e>>(executor: E) -> Result<(), SqlxError> { +pub async fn update_counts<'e, E: sqlx::PgExecutor<'e>>(executor: E) -> sqlx::Result<()> { debug!("Updating device, user, and wireguard network counts."); let result = query!( "SELECT \ diff --git a/crates/defguard_core/src/error.rs b/crates/defguard_core/src/error.rs index 754bc736c3..1067667439 100644 --- a/crates/defguard_core/src/error.rs +++ b/crates/defguard_core/src/error.rs @@ -92,6 +92,9 @@ pub enum WebError { StaticIpError(#[from] StaticIpError), #[error("Network full: {0}")] NetworkFull(String), + #[error(transparent)] + #[schema(value_type=Object)] + IpNetwork(#[from] ipnetwork::IpNetworkError), } impl From for WebError { diff --git a/crates/defguard_core/src/handlers/mod.rs b/crates/defguard_core/src/handlers/mod.rs index 7367f0a66e..f8a109c2ad 100644 --- a/crates/defguard_core/src/handlers/mod.rs +++ b/crates/defguard_core/src/handlers/mod.rs @@ -15,6 +15,7 @@ use defguard_common::{ types::user_info::UserInfo, }; use defguard_static_ip::error::StaticIpError; +use ipnetwork::IpNetworkError; use serde_json::{Value, json}; use sqlx::PgPool; use utoipa::ToSchema; @@ -285,6 +286,16 @@ impl From for ApiResponse { ) } }, + WebError::IpNetwork(err) => match err { + IpNetworkError::InvalidAddr(msg) | IpNetworkError::InvalidCidrFormat(msg) => { + warn!(msg); + ApiResponse::new(json!({"msg": msg}), StatusCode::BAD_REQUEST) + } + IpNetworkError::InvalidPrefix => { + warn!("Invalid prefix"); + ApiResponse::new(json!({"msg": "invalid prefix"}), StatusCode::BAD_REQUEST) + } + }, } } } diff --git a/crates/defguard_core/src/handlers/ssh_authorized_keys.rs b/crates/defguard_core/src/handlers/ssh_authorized_keys.rs index a88747eac7..162e543560 100644 --- a/crates/defguard_core/src/handlers/ssh_authorized_keys.rs +++ b/crates/defguard_core/src/handlers/ssh_authorized_keys.rs @@ -7,7 +7,7 @@ use defguard_common::db::{ Id, models::{AuthenticationKey, AuthenticationKeyType, User, group::Group}, }; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query}; +use sqlx::{PgExecutor, PgPool, query}; use ssh_key::PublicKey; use super::{ApiResponse, ApiResult, user_for_admin_or_self}; @@ -31,7 +31,7 @@ pub(crate) struct AuthenticationKeyInfo { } impl AuthenticationKeyInfo { - pub async fn find_by_user_id<'e, E>(executor: E, user_id: Id) -> Result, SqlxError> + pub async fn find_by_user_id<'e, E>(executor: E, user_id: Id) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index ddd8a1848d..d1216553b3 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -17,7 +17,7 @@ use defguard_common::{ use defguard_mail::{Mail, templates}; use humantime::parse_duration; use serde_json::json; -use sqlx::{Error as SqlxError, PgPool}; +use sqlx::PgPool; use utoipa::ToSchema; use super::{ @@ -134,7 +134,7 @@ pub struct UserDetails { } impl UserDetails { - pub async fn from_user(pool: &PgPool, user: &User) -> Result { + pub async fn from_user(pool: &PgPool, user: &User) -> sqlx::Result { let devices = user.user_devices(pool).await?; let security_keys = user.security_keys(pool).await?; let biometric_enabled_devices = BiometricAuth::find_by_user_id(pool, user.id) diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index bbbc3ec06d..44f50750f1 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -247,8 +247,8 @@ pub(crate) async fn create_network( data.acl_default_allow, data.location_mfa_mode, data.service_location_mode, - ); - network.set_address(parse_address_list(&data.address)); + ) + .try_set_address(&data.address)?; network.mtu = data.mtu; network.fwmark = data.fwmark; network.keepalive_interval = data.keepalive_interval; @@ -344,11 +344,11 @@ pub(crate) async fn modify_network( data.validate_peer_disconnect_threshold()?; data.validate_location_mfa_mode(&appstate.pool).await?; - let mut network = find_network(network_id, &appstate.pool).await?; + let network = find_network(network_id, &appstate.pool).await?; // store network before mods let before = network.clone(); let new_addresses = data.parse_addresses()?; - network.set_address(new_addresses); + let mut network = network.set_address(new_addresses)?; network.allowed_ips = data.parse_allowed_ips(); network.name = data.name; diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 41cbf0a813..96b74915ca 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -767,8 +767,9 @@ pub async fn init_dev_env(config: &DefGuardConfig) { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - network.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()]); + ) + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()]) + .unwrap(); network.pubkey = "zGMeVGm9HV9I4wSKF9AXmYnnAIhDySyqLMuKpcfIaQo=".to_string(); network.prvkey = "MAk3d5KuB167G88HM7nGYR6ksnPMAOguAg2s5EcPp1M=".to_string(); network @@ -845,7 +846,7 @@ pub async fn init_vpn_location( WireguardNetwork::find_by_id(&mut *transaction, location_id).await? { network.name.clone_from(&args.name); - network.set_address([args.address]); + let mut network = network.set_address([args.address])?; network.port = args.port; network.endpoint.clone_from(&args.endpoint); network.dns.clone_from(&args.dns); @@ -867,8 +868,8 @@ pub async fn init_vpn_location( false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - network.set_address([args.address]); + ) + .set_address([args.address])?; network.mtu = args.mtu as i32; network.fwmark = i64::from(args.fwmark); let network = network.save(&mut *transaction).await?; @@ -907,8 +908,8 @@ pub async fn init_vpn_location( false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - location.set_address([args.address]); + ) + .set_address([args.address])?; location.mtu = args.mtu as i32; location.fwmark = i64::from(args.fwmark); location.save(pool).await? diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 49827834c4..bce9fa761e 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -1,6 +1,6 @@ use defguard_common::db::{Id, models::WireguardNetwork}; use defguard_proto::gateway::Peer; -use sqlx::{Error as SqlxError, PgExecutor, query}; +use sqlx::{PgExecutor, query}; use crate::grpc::should_prevent_service_location_usage; @@ -13,7 +13,7 @@ use crate::grpc::should_prevent_service_location_usage; pub async fn get_location_allowed_peers<'e, E>( location: &WireguardNetwork, executor: E, -) -> Result, SqlxError> +) -> sqlx::Result> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/location_management/tests.rs b/crates/defguard_core/src/location_management/tests.rs index c640900b7b..a3adb23327 100644 --- a/crates/defguard_core/src/location_management/tests.rs +++ b/crates/defguard_core/src/location_management/tests.rs @@ -22,10 +22,11 @@ fn test_network_readdress(_: PgPoolOptions, options: PgConnectOptions) { // 192.168.42.45: device // 192.168.42.46: gateway // 192.168.42.47: broadcast - let mut network = WireguardNetwork::default(); - network.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 46)), 30).unwrap()]); + let mut network = WireguardNetwork::default() + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 46)), 30).unwrap()]) + .unwrap(); network.allow_all_groups = true; - let mut network = network.save(&pool).await.unwrap(); + let network = network.save(&pool).await.unwrap(); let mut conn = pool.begin().await.unwrap(); @@ -68,7 +69,9 @@ fn test_network_readdress(_: PgPoolOptions, options: PgConnectOptions) { // 192.168.42.77: gateway // 192.168.42.78: device // 192.168.42.79: broadcast - network.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 77)), 30).unwrap()]); + let network = network + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 77)), 30).unwrap()]) + .unwrap(); network.save(&pool).await.unwrap(); // Re-address the network. diff --git a/crates/defguard_core/src/wg_config.rs b/crates/defguard_core/src/wg_config.rs index b1c0ead6d8..c063939ca4 100644 --- a/crates/defguard_core/src/wg_config.rs +++ b/crates/defguard_core/src/wg_config.rs @@ -116,8 +116,8 @@ pub(crate) fn parse_wireguard_config( false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - network.set_address(addresses.clone()); + ) + .set_address(addresses.clone())?; network.mtu = mtu; network.fwmark = fwmark; network.pubkey = pubkey; @@ -254,7 +254,7 @@ mod test { let config = " [Interface] PrivateKey = GAA2X3DW0WakGVx+DsGjhDpTgg50s1MlmrLf24Psrlg= - Address = 10.0.0.1/24,fc00::/112 + Address = 10.0.0.1/24,fc00::1/112 ListenPort = 55055 DNS = 10.0.0.2 @@ -279,7 +279,7 @@ mod test { network.address(), [ "10.0.0.1/24".parse().unwrap(), - "fc00::/112".parse().unwrap() + "fc00::1/112".parse().unwrap() ] ); assert_eq!(network.port, 55055); diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index c1e852df6a..dfe7bb537b 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -63,7 +63,7 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { // modify network let network_data = WireguardNetworkData { name: "my network".into(), - address: "10.1.1.0/24".into(), + address: "10.1.1.1/24".into(), endpoint: "10.1.1.1".parse().unwrap(), port: 55555, allowed_ips: Some("10.1.1.0/24, 10.2.0.1/16, 10.10.10.54/32".into()), @@ -146,7 +146,7 @@ async fn test_location_mfa_mode_validation_create(_: PgPoolOptions, options: PgC let location_data = WireguardNetworkData { name: "test_location".into(), - address: "10.1.1.0/24".into(), + address: "10.1.1.1/24".into(), endpoint: "10.1.1.1".parse().unwrap(), port: 55555, allowed_ips: Some("10.1.1.0/24, 10.2.0.1/16, 10.10.10.54/32".into()), @@ -231,7 +231,7 @@ async fn test_location_mfa_mode_validation_modify(_: PgPoolOptions, options: PgC let mut location_data = WireguardNetworkData { name: "test_location".into(), - address: "10.1.1.0/24".into(), + address: "10.1.1.254/24".into(), endpoint: "10.1.1.1".parse().unwrap(), port: 55555, allowed_ips: Some("10.1.1.0/24, 10.2.0.1/16, 10.10.10.54/32".into()), @@ -334,7 +334,7 @@ async fn test_peer_disconnect_threshold_validation_create( let mut location_data = WireguardNetworkData { name: "test_location_disabled".into(), - address: "10.1.1.0/24".into(), + address: "10.1.1.1/24".into(), endpoint: "10.1.1.1".parse().unwrap(), port: 55555, allowed_ips: Some("10.1.1.0/24, 10.2.0.1/16, 10.10.10.54/32".into()), @@ -389,7 +389,7 @@ async fn test_peer_disconnect_threshold_validation_modify( let mut location_data = WireguardNetworkData { name: "test_location".into(), - address: "10.1.1.0/24".into(), + address: "10.1.1.1/24".into(), endpoint: "10.1.1.1".parse().unwrap(), port: 55555, allowed_ips: Some("10.1.1.0/24, 10.2.0.1/16, 10.10.10.54/32".into()), diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index 599ce237ca..eb169bd7c3 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -47,7 +47,7 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { let pool = client_state.pool; // setup initial network - let mut initial_network = WireguardNetwork::new( + WireguardNetwork::new( "initial".into(), 51515, String::new(), @@ -58,9 +58,12 @@ async fn test_config_import(_: PgPoolOptions, options: PgConnectOptions) { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - initial_network.set_address(["10.1.9.0/24".parse().unwrap()]); - initial_network.save(&pool).await.unwrap(); + ) + .set_address(["10.1.9.1/24".parse().unwrap()]) + .unwrap() + .save(&pool) + .await + .unwrap(); // add existing devices let mut transaction = pool.begin().await.unwrap(); diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 84aa0b8219..d1572c1791 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -116,7 +116,7 @@ pub(crate) async fn create_location_with_mfa_mode( pool: &sqlx::PgPool, location_mfa_mode: LocationMfaMode, ) -> WireguardNetwork { - let mut location = WireguardNetwork::new( + WireguardNetwork::new( "TestNet".to_string(), 51820, "10.0.0.1".to_string(), @@ -127,12 +127,12 @@ pub(crate) async fn create_location_with_mfa_mode( false, location_mfa_mode, ServiceLocationMode::Disabled, - ); - location.set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()]); - location - .save(pool) - .await - .expect("failed to create WireGuard location") + ) + .set_address([IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 24).unwrap()]) + .unwrap() + .save(pool) + .await + .expect("failed to create WireGuard location") } pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { @@ -175,7 +175,7 @@ pub(crate) async fn attach_device_to_location(pool: &sqlx::PgPool, location_id: let network_device = WireguardNetworkDevice::new( location_id, device_id, - vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 10))], + [IpAddr::V4(Ipv4Addr::new(10, 0, 0, 10))], ); network_device .insert(pool) diff --git a/crates/defguard_setup/src/auto_adoption.rs b/crates/defguard_setup/src/auto_adoption.rs index a02904dd47..e004bb2a6b 100644 --- a/crates/defguard_setup/src/auto_adoption.rs +++ b/crates/defguard_setup/src/auto_adoption.rs @@ -733,7 +733,7 @@ id={} for new gateway", .context("Failed to parse default auto-adoption network address")?; let mut transaction = pool.begin().await.context("Failed to begin transaction")?; - let mut network = WireguardNetwork::new( + let network = WireguardNetwork::new( common_name.to_string(), DEFAULT_AUTO_ADOPTION_WIREGUARD_PORT, host.to_string(), @@ -744,12 +744,11 @@ id={} for new gateway", false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - network.set_address([network_address]); - let network = network - .save(&mut *transaction) - .await - .context("Failed to save auto-adopted WireguardNetwork")?; + ) + .set_address([network_address])? + .save(&mut *transaction) + .await + .context("Failed to save auto-adopted WireguardNetwork")?; network .add_all_allowed_devices(&mut transaction) diff --git a/crates/defguard_setup/src/handlers/auto_wizard.rs b/crates/defguard_setup/src/handlers/auto_wizard.rs index 0c542a27ef..84b7bcd6f0 100644 --- a/crates/defguard_setup/src/handlers/auto_wizard.rs +++ b/crates/defguard_setup/src/handlers/auto_wizard.rs @@ -137,7 +137,7 @@ pub async fn set_vpn_settings( network.endpoint = vpn_settings.public_ip; network.port = vpn_settings.wireguard_port; - network.set_address(addresses); + let mut network = network.set_address(addresses)?; network.allowed_ips = allowed_ips; network.dns = { let dns = vpn_settings.dns_server_ip.trim(); diff --git a/crates/defguard_setup/tests/auto_adoption_wizard.rs b/crates/defguard_setup/tests/auto_adoption_wizard.rs index 2bd5562a82..9dedb9c1ec 100644 --- a/crates/defguard_setup/tests/auto_adoption_wizard.rs +++ b/crates/defguard_setup/tests/auto_adoption_wizard.rs @@ -51,8 +51,9 @@ async fn seed_wireguard_network(pool: &sqlx::PgPool) -> WireguardNetwork { false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - location.set_address(["10.0.0.0/24".parse::().unwrap()]); + ) + .set_address(["10.0.0.1/24".parse::().unwrap()]) + .unwrap(); location.mtu = 1280; location .save(pool) diff --git a/crates/defguard_setup/tests/wizard_state.rs b/crates/defguard_setup/tests/wizard_state.rs index eeaa8c5546..1640fa448b 100644 --- a/crates/defguard_setup/tests/wizard_state.rs +++ b/crates/defguard_setup/tests/wizard_state.rs @@ -149,8 +149,9 @@ async fn test_wizard_state_auto_adoption(_: PgPoolOptions, options: PgConnectOpt false, LocationMfaMode::Disabled, ServiceLocationMode::Disabled, - ); - location.set_address(["10.0.0.0/24".parse().unwrap()]); + ) + .set_address(["10.0.0.1/24".parse().unwrap()]) + .unwrap(); location.mtu = 1280; location .save(&pool) From 870d7929f203a527a96a2e97ff0704b60e9dffda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 17 Mar 2026 14:01:40 +0100 Subject: [PATCH 05/10] Move wireguard tests to a separate file --- .../src/db/models/wireguard.rs | 648 +----------------- .../src/db/models/wireguard/tests.rs | 583 ++++++++++++++++ 2 files changed, 616 insertions(+), 615 deletions(-) create mode 100644 crates/defguard_common/src/db/models/wireguard/tests.rs diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 3fde5cf255..65afd3f5ca 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -587,7 +587,8 @@ impl WireguardNetwork { let stats = query_as!( WireguardDeviceTransferRow, "SELECT s.device_id, date_trunc($1, collected_at) \"collected_at!: NaiveDateTime\", \ - CAST(sum(download_diff) AS bigint) \"download!\", CAST(sum(upload_diff) AS bigint) \"upload!\" \ + CAST(sum(download_diff) AS bigint) \"download!\", \ + CAST(sum(upload_diff) AS bigint) \"upload!\" \ FROM vpn_session_stats \ INNER JOIN vpn_client_session s ON session_id = s.id \ WHERE s.device_id = ANY($2) AND collected_at >= $3 AND s.location_id = $4 \ @@ -1135,10 +1136,10 @@ impl WireguardNetwork { { query_as!( Device, - "SELECT \ - id, name, wireguard_pubkey, user_id, created, description, device_type \"device_type: DeviceType\", \ - configured \ - FROM device WHERE id in (SELECT device_id FROM wireguard_network_device WHERE wireguard_network_id = $1) \ + "SELECT id, name, wireguard_pubkey, user_id, created, description, \ + device_type \"device_type: DeviceType\", configured \ + FROM device WHERE id in (SELECT device_id \ + FROM wireguard_network_device WHERE wireguard_network_id = $1) \ AND device_type = $2", self.id, device_type as DeviceType @@ -1240,8 +1241,9 @@ impl WireguardNetwork { let locations = query_as!( WireguardNetwork, "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ - allowed_ips, allow_all_groups, connected_at, keepalive_interval, peer_disconnect_threshold, acl_enabled, \ - acl_default_allow, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ + allowed_ips, allow_all_groups, connected_at, keepalive_interval, \ + peer_disconnect_threshold, acl_enabled, acl_default_allow, \ + location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ service_location_mode \"service_location_mode: ServiceLocationMode\" \ FROM wireguard_network WHERE location_mfa_mode = 'external'::location_mfa_mode", ) @@ -1391,8 +1393,8 @@ impl WireguardNetwork { ) -> sqlx::Result>> { query_as!( VpnClientSession, - "SELECT id, location_id, user_id, device_id, \ - created_at, connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", \ + "SELECT id, location_id, user_id, device_id, created_at, connected_at, \ + disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", \ state \"state: VpnClientSessionState\" \ FROM vpn_client_session \ WHERE location_id = $1 AND state = 'connected'::vpn_client_session_state", @@ -1554,12 +1556,12 @@ pub async fn networks_stats( let total_activity = query_as!( WireguardNetworkActivityStats, "SELECT \ - COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN s.user_id END), 0) \"active_users!\", \ - COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN d.id END), 0) \"active_user_devices!\", \ - COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'network' THEN d.id END), 0) \"active_network_devices!\" \ - FROM vpn_client_session s \ - LEFT JOIN device d ON d.id = s.device_id \ - WHERE s.state = 'connected' OR (s.state = 'disconnected' AND s.disconnected_at >= $1)", + COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN s.user_id END), 0) \"active_users!\", \ + COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN d.id END), 0) \"active_user_devices!\", \ + COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'network' THEN d.id END), 0) \"active_network_devices!\" \ + FROM vpn_client_session s \ + LEFT JOIN device d ON d.id = s.device_id \ + WHERE s.state = 'connected' OR (s.state = 'disconnected' AND s.disconnected_at >= $1)", from ) .fetch_one(pool) @@ -1569,12 +1571,12 @@ pub async fn networks_stats( let current_activity = query_as!( WireguardNetworkActivityStats, "SELECT \ - COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN s.user_id END), 0) \"active_users!\", \ - COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN d.id END), 0) \"active_user_devices!\", \ - COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'network' THEN d.id END), 0) \"active_network_devices!\" \ - FROM vpn_client_session s \ - LEFT JOIN device d ON d.id = s.device_id \ - WHERE s.state = 'connected'", + COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN s.user_id END), 0) \"active_users!\", \ + COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'user' THEN d.id END), 0) \"active_user_devices!\", \ + COALESCE(COUNT(DISTINCT CASE WHEN d.device_type = 'network' THEN d.id END), 0) \"active_network_devices!\" \ + FROM vpn_client_session s \ + LEFT JOIN device d ON d.id = s.device_id \ + WHERE s.state = 'connected'", ) .fetch_one(pool) .await?; @@ -1582,15 +1584,15 @@ pub async fn networks_stats( // get transfer series for specified time window let transfer_series = query_as!( WireguardStatsRow, - "SELECT \ - date_trunc($1, collected_at) \"collected_at: NaiveDateTime\", \ - cast(sum(upload_diff) AS bigint) upload, cast(sum(download_diff) AS bigint) download \ - FROM vpn_session_stats \ - JOIN vpn_client_session s ON session_id = s.id \ - WHERE collected_at >= $2 \ - GROUP BY 1 \ - ORDER BY 1 \ - LIMIT $3", + "SELECT \ + date_trunc($1, collected_at) \"collected_at: NaiveDateTime\", \ + cast(sum(upload_diff) AS bigint) upload, cast(sum(download_diff) AS bigint) download \ + FROM vpn_session_stats \ + JOIN vpn_client_session s ON session_id = s.id \ + WHERE collected_at >= $2 \ + GROUP BY 1 \ + ORDER BY 1 \ + LIMIT $3", aggregation.fstring(), from, PEER_STATS_LIMIT, @@ -1611,588 +1613,4 @@ pub async fn networks_stats( } #[cfg(test)] -mod test { - use std::str::FromStr; - - use matches::assert_matches; - use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; - - use super::*; - use crate::db::setup_pool; - - // FIXME(mwojcik): rewrite for new stats implementation - // #[sqlx::test] - // async fn test_connected_at_reconnection(_: PgPoolOptions, options: PgConnectOptions) { - // let pool = setup_pool(options).await; - // let mut location = WireguardNetwork::default(); - // location.try_set_address("10.1.1.1/29").unwrap(); - // let location = location.save(&pool).await.unwrap(); - - // let user = User::new( - // "testuser", - // Some("hunter2"), - // "Tester", - // "Test", - // "test@test.com", - // None, - // ) - // .save(&pool) - // .await - // .unwrap(); - // let device = Device::new( - // String::new(), - // String::new(), - // user.id, - // DeviceType::User, - // None, - // true, - // ) - // .save(&pool) - // .await - // .unwrap(); - - // // insert stats - // let samples = 60; // 1 hour of samples - // let now = Utc::now().naive_utc(); - // for i in 0..=samples { - // // simulate connection 30 minutes ago - // let handshake_minutes = i * if i < 31 { 1 } else { 10 }; - // WireguardPeerStats { - // id: NoId, - // device_id: device.id, - // collected_at: now - TimeDelta::minutes(i), - // network: location.id, - // endpoint: Some("11.22.33.44".into()), - // upload: (samples - i) * 10, - // download: (samples - i) * 20, - // latest_handshake: now - TimeDelta::minutes(handshake_minutes), - // allowed_ips: Some("10.1.1.0/24".into()), - // } - // .save(&pool) - // .await - // .unwrap(); - // } - - // let connected_at = device - // .last_connected_at(&pool, location.id) - // .await - // .unwrap() - // .unwrap(); - // assert_eq!( - // connected_at, - // // PostgreSQL stores 6 sub-second digits while chrono stores 9. - // (now - TimeDelta::minutes(30)).trunc_subsecs(6), - // ); - // } - - // FIXME(mwojcik): rewrite for new stats implementation - // #[sqlx::test] - // async fn test_connected_at_always_connected(_: PgPoolOptions, options: PgConnectOptions) { - // let pool = setup_pool(options).await; - // let mut location = WireguardNetwork::default(); - // location.try_set_address("10.1.1.1/29").unwrap(); - // let location = location.save(&pool).await.unwrap(); - - // let user = User::new( - // "testuser", - // Some("hunter2"), - // "Tester", - // "Test", - // "test@test.com", - // None, - // ) - // .save(&pool) - // .await - // .unwrap(); - // let device = Device::new( - // String::new(), - // String::new(), - // user.id, - // DeviceType::User, - // None, - // true, - // ) - // .save(&pool) - // .await - // .unwrap(); - - // // insert stats - // let samples = 60; // 1 hour of samples - // let now = Utc::now().naive_utc(); - // for i in 0..=samples { - // WireguardPeerStats { - // id: NoId, - // device_id: device.id, - // collected_at: now - TimeDelta::minutes(i), - // network: location.id, - // endpoint: Some("11.22.33.44".into()), - // upload: (samples - i) * 10, - // download: (samples - i) * 20, - // latest_handshake: now - TimeDelta::minutes(i), // handshake every minute - // allowed_ips: Some("10.1.1.0/24".into()), - // } - // .save(&pool) - // .await - // .unwrap(); - // } - - // let connected_at = device - // .last_connected_at(&pool, location.id) - // .await - // .unwrap() - // .unwrap(); - // assert_eq!( - // connected_at, - // // PostgreSQL stores 6 sub-second digits while chrono stores 9. - // (now - TimeDelta::minutes(samples)).trunc_subsecs(6), - // ); - // } - - #[sqlx::test] - async fn test_get_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default() - .try_set_address("10.1.1.1/29") - .unwrap(); - network.allow_all_groups = true; - let network = network.save(&pool).await.unwrap(); - - let user1 = User::new( - "user1", - Some("pass1"), - "Test", - "User1", - "user1@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user2 = User::new( - "user2", - Some("pass2"), - "Test", - "User2", - "user2@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let device1 = Device::new( - "device1".into(), - "key1".into(), - user1.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device2 = Device::new( - "device2".into(), - "key2".into(), - user1.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device3 = Device::new( - "device3".into(), - "key3".into(), - user2.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let devices = network - .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), user1.id) - .await - .unwrap(); - assert_eq!(devices.len(), 2); - assert!(devices.iter().any(|d| d.id == device1.id)); - assert!(devices.iter().any(|d| d.id == device2.id)); - - let devices = network - .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), user2.id) - .await - .unwrap(); - assert_eq!(devices.len(), 1); - assert!(devices.iter().any(|d| d.id == device3.id)); - - let devices = network - .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), Id::from(999)) - .await - .unwrap(); - assert!(devices.is_empty()); - } - - #[sqlx::test] - async fn test_get_allowed_devices_for_user_with_groups( - _: PgPoolOptions, - options: PgConnectOptions, - ) { - let pool = setup_pool(options).await; - let network = WireguardNetwork::default() - .try_set_address("10.1.1.1/29") - .unwrap() - .save(&pool) - .await - .unwrap(); - - let user1 = User::new( - "user1", - Some("pass1"), - "Test", - "User1", - "user1@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user2 = User::new( - "user2", - Some("pass2"), - "Test", - "User2", - "user2@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let group1 = Group::new("group1").save(&pool).await.unwrap(); - let group2 = Group::new("group2").save(&pool).await.unwrap(); - - user1.add_to_group(&pool, &group1).await.unwrap(); - user2.add_to_group(&pool, &group2).await.unwrap(); - - let device1 = Device::new( - "device1".into(), - "key1".into(), - user1.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - Device::new( - "device2".into(), - "key2".into(), - user2.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let mut transaction = pool.begin().await.unwrap(); - - network - .set_allowed_groups(&mut transaction, &[group1.name]) - .await - .unwrap(); - - let devices = network - .get_allowed_devices_for_user(&mut transaction, user1.id) - .await - .unwrap(); - assert_eq!(devices.len(), 1); - assert_eq!(devices[0].id, device1.id); - - let devices = network - .get_allowed_devices_for_user(&mut transaction, user2.id) - .await - .unwrap(); - assert!(devices.is_empty()); - } - - #[sqlx::test] - async fn test_can_assign_ips(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let network = WireguardNetwork::new( - "network".to_string(), - 50051, - String::new(), - None, - [IpNetwork::from_str("10.1.1.0/24").unwrap()], - false, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ) - .set_address([IpNetwork::from_str("10.1.1.1/24").unwrap()]) - .unwrap() - .save(&pool) - .await - .unwrap(); - - // assign free address - let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Ok(()) - ); - - // assign multiple free addresses - let addrs = vec![ - IpAddr::from_str("10.1.1.2").unwrap(), - IpAddr::from_str("10.1.1.3").unwrap(), - ]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Ok(()) - ); - - // try to assign address from another network - let addrs = vec![IpAddr::from_str("10.2.1.2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::NoContainingNetwork(..)) - ); - - // try to assign already assigned address - let user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let device = Device::new( - "device".to_string(), - String::new(), - user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - WireguardNetworkDevice::new( - network.id, - device.id, - vec![IpAddr::from_str("10.1.1.2").unwrap()], - ) - .insert(&pool) - .await - .unwrap(); - let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::AddressAlreadyAssigned(..)) - ); - - // assign with exception for the device - let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, Some(device.id)) - .await, - Ok(()) - ); - - // try to assign gateway address - let addrs = vec![IpAddr::from_str("10.1.1.1").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::ReservedForGateway(..)) - ); - - // try to assign network address - let addrs = vec![IpAddr::from_str("10.1.1.0").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::IsNetworkAddress(..)) - ); - - // try to assign broadcast address - let addrs = vec![IpAddr::from_str("10.1.1.255").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::IsBroadcastAddress(..)) - ); - } - - #[sqlx::test] - async fn test_can_assign_ips_multiple_addresses(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let network = WireguardNetwork::new( - "network".to_string(), - 50051, - String::new(), - None, - [IpNetwork::from_str("10.1.1.0/24").unwrap()], - false, - false, - false, - LocationMfaMode::Disabled, - ServiceLocationMode::Disabled, - ) - .set_address([ - IpNetwork::from_str("10.1.1.1/24").unwrap(), - IpNetwork::from_str("fc00::1/112").unwrap(), - ]) - .unwrap() - .save(&pool) - .await - .unwrap(); - - // assign free addresses - let addrs = vec![ - IpAddr::from_str("10.1.1.2").unwrap(), - IpAddr::from_str("fc00::2").unwrap(), - ]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Ok(()) - ); - - // assign multiple free addresses - let addrs = vec![ - IpAddr::from_str("10.1.1.2").unwrap(), - IpAddr::from_str("10.1.1.3").unwrap(), - IpAddr::from_str("fc00::2").unwrap(), - IpAddr::from_str("fc00::3").unwrap(), - ]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Ok(()) - ); - - // try to assign address from another network - let addrs = vec![IpAddr::from_str("fa::2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::NoContainingNetwork(..)) - ); - - // try to assign already assigned address - let user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let device = Device::new( - "device".to_string(), - String::new(), - user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - WireguardNetworkDevice::new( - network.id, - device.id, - vec![ - IpAddr::from_str("10.1.1.2").unwrap(), - IpAddr::from_str("fc00::2").unwrap(), - ], - ) - .insert(&pool) - .await - .unwrap(); - let addrs = vec![IpAddr::from_str("fc00::2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::AddressAlreadyAssigned(..)) - ); - - // assign with exception for the device - let addrs = vec![IpAddr::from_str("fc00::2").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, Some(device.id)) - .await, - Ok(()) - ); - - // try to assign gateway address - let addrs = vec![IpAddr::from_str("fc00::1").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::ReservedForGateway(..)) - ); - - // try to assign network address - let addrs = vec![IpAddr::from_str("fc00::0").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::IsNetworkAddress(..)) - ); - - // try to assign broadcast address - let addrs = vec![IpAddr::from_str("fc00::ffff").unwrap()]; - assert_matches!( - network - .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) - .await, - Err(NetworkAddressError::IsBroadcastAddress(..)) - ); - } -} +mod tests; diff --git a/crates/defguard_common/src/db/models/wireguard/tests.rs b/crates/defguard_common/src/db/models/wireguard/tests.rs new file mode 100644 index 0000000000..c4ec50584f --- /dev/null +++ b/crates/defguard_common/src/db/models/wireguard/tests.rs @@ -0,0 +1,583 @@ +use std::str::FromStr; + +use matches::assert_matches; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + +use super::*; +use crate::db::setup_pool; + +// FIXME(mwojcik): rewrite for new stats implementation +// #[sqlx::test] +// async fn test_connected_at_reconnection(_: PgPoolOptions, options: PgConnectOptions) { +// let pool = setup_pool(options).await; +// let mut location = WireguardNetwork::default(); +// location.try_set_address("10.1.1.1/29").unwrap(); +// let location = location.save(&pool).await.unwrap(); + +// let user = User::new( +// "testuser", +// Some("hunter2"), +// "Tester", +// "Test", +// "test@test.com", +// None, +// ) +// .save(&pool) +// .await +// .unwrap(); +// let device = Device::new( +// String::new(), +// String::new(), +// user.id, +// DeviceType::User, +// None, +// true, +// ) +// .save(&pool) +// .await +// .unwrap(); + +// // insert stats +// let samples = 60; // 1 hour of samples +// let now = Utc::now().naive_utc(); +// for i in 0..=samples { +// // simulate connection 30 minutes ago +// let handshake_minutes = i * if i < 31 { 1 } else { 10 }; +// WireguardPeerStats { +// id: NoId, +// device_id: device.id, +// collected_at: now - TimeDelta::minutes(i), +// network: location.id, +// endpoint: Some("11.22.33.44".into()), +// upload: (samples - i) * 10, +// download: (samples - i) * 20, +// latest_handshake: now - TimeDelta::minutes(handshake_minutes), +// allowed_ips: Some("10.1.1.0/24".into()), +// } +// .save(&pool) +// .await +// .unwrap(); +// } + +// let connected_at = device +// .last_connected_at(&pool, location.id) +// .await +// .unwrap() +// .unwrap(); +// assert_eq!( +// connected_at, +// // PostgreSQL stores 6 sub-second digits while chrono stores 9. +// (now - TimeDelta::minutes(30)).trunc_subsecs(6), +// ); +// } + +// FIXME(mwojcik): rewrite for new stats implementation +// #[sqlx::test] +// async fn test_connected_at_always_connected(_: PgPoolOptions, options: PgConnectOptions) { +// let pool = setup_pool(options).await; +// let mut location = WireguardNetwork::default(); +// location.try_set_address("10.1.1.1/29").unwrap(); +// let location = location.save(&pool).await.unwrap(); + +// let user = User::new( +// "testuser", +// Some("hunter2"), +// "Tester", +// "Test", +// "test@test.com", +// None, +// ) +// .save(&pool) +// .await +// .unwrap(); +// let device = Device::new( +// String::new(), +// String::new(), +// user.id, +// DeviceType::User, +// None, +// true, +// ) +// .save(&pool) +// .await +// .unwrap(); + +// // insert stats +// let samples = 60; // 1 hour of samples +// let now = Utc::now().naive_utc(); +// for i in 0..=samples { +// WireguardPeerStats { +// id: NoId, +// device_id: device.id, +// collected_at: now - TimeDelta::minutes(i), +// network: location.id, +// endpoint: Some("11.22.33.44".into()), +// upload: (samples - i) * 10, +// download: (samples - i) * 20, +// latest_handshake: now - TimeDelta::minutes(i), // handshake every minute +// allowed_ips: Some("10.1.1.0/24".into()), +// } +// .save(&pool) +// .await +// .unwrap(); +// } + +// let connected_at = device +// .last_connected_at(&pool, location.id) +// .await +// .unwrap() +// .unwrap(); +// assert_eq!( +// connected_at, +// // PostgreSQL stores 6 sub-second digits while chrono stores 9. +// (now - TimeDelta::minutes(samples)).trunc_subsecs(6), +// ); +// } + +#[sqlx::test] +async fn test_get_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let mut network = WireguardNetwork::default() + .try_set_address("10.1.1.1/29") + .unwrap(); + network.allow_all_groups = true; + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "user1", + Some("pass1"), + "Test", + "User1", + "user1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "user2", + Some("pass2"), + "Test", + "User2", + "user2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "key2".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "device3".into(), + "key3".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let devices = network + .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), user1.id) + .await + .unwrap(); + assert_eq!(devices.len(), 2); + assert!(devices.iter().any(|d| d.id == device1.id)); + assert!(devices.iter().any(|d| d.id == device2.id)); + + let devices = network + .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), user2.id) + .await + .unwrap(); + assert_eq!(devices.len(), 1); + assert!(devices.iter().any(|d| d.id == device3.id)); + + let devices = network + .get_allowed_devices_for_user(&mut pool.acquire().await.unwrap(), Id::from(999)) + .await + .unwrap(); + assert!(devices.is_empty()); +} + +#[sqlx::test] +async fn test_get_allowed_devices_for_user_with_groups( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + let network = WireguardNetwork::default() + .try_set_address("10.1.1.1/29") + .unwrap() + .save(&pool) + .await + .unwrap(); + + let user1 = User::new( + "user1", + Some("pass1"), + "Test", + "User1", + "user1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "user2", + Some("pass2"), + "Test", + "User2", + "user2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let group1 = Group::new("group1").save(&pool).await.unwrap(); + let group2 = Group::new("group2").save(&pool).await.unwrap(); + + user1.add_to_group(&pool, &group1).await.unwrap(); + user2.add_to_group(&pool, &group2).await.unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + Device::new( + "device2".into(), + "key2".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + network + .set_allowed_groups(&mut transaction, &[group1.name]) + .await + .unwrap(); + + let devices = network + .get_allowed_devices_for_user(&mut transaction, user1.id) + .await + .unwrap(); + assert_eq!(devices.len(), 1); + assert_eq!(devices[0].id, device1.id); + + let devices = network + .get_allowed_devices_for_user(&mut transaction, user2.id) + .await + .unwrap(); + assert!(devices.is_empty()); +} + +#[sqlx::test] +async fn test_can_assign_ips(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let network = WireguardNetwork::new( + "network".to_string(), + 50051, + String::new(), + None, + [IpNetwork::from_str("10.1.1.0/24").unwrap()], + false, + false, + false, + LocationMfaMode::Disabled, + ServiceLocationMode::Disabled, + ) + .set_address([IpNetwork::from_str("10.1.1.1/24").unwrap()]) + .unwrap() + .save(&pool) + .await + .unwrap(); + + // assign free address + let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Ok(()) + ); + + // assign multiple free addresses + let addrs = vec![ + IpAddr::from_str("10.1.1.2").unwrap(), + IpAddr::from_str("10.1.1.3").unwrap(), + ]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Ok(()) + ); + + // try to assign address from another network + let addrs = vec![IpAddr::from_str("10.2.1.2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::NoContainingNetwork(..)) + ); + + // try to assign already assigned address + let user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device = Device::new( + "device".to_string(), + String::new(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + WireguardNetworkDevice::new( + network.id, + device.id, + vec![IpAddr::from_str("10.1.1.2").unwrap()], + ) + .insert(&pool) + .await + .unwrap(); + let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::AddressAlreadyAssigned(..)) + ); + + // assign with exception for the device + let addrs = vec![IpAddr::from_str("10.1.1.2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, Some(device.id)) + .await, + Ok(()) + ); + + // try to assign gateway address + let addrs = vec![IpAddr::from_str("10.1.1.1").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::ReservedForGateway(..)) + ); + + // try to assign network address + let addrs = vec![IpAddr::from_str("10.1.1.0").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::IsNetworkAddress(..)) + ); + + // try to assign broadcast address + let addrs = vec![IpAddr::from_str("10.1.1.255").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::IsBroadcastAddress(..)) + ); +} + +#[sqlx::test] +async fn test_can_assign_ips_multiple_addresses(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let network = WireguardNetwork::new( + "network".to_string(), + 50051, + String::new(), + None, + [IpNetwork::from_str("10.1.1.0/24").unwrap()], + false, + false, + false, + LocationMfaMode::Disabled, + ServiceLocationMode::Disabled, + ) + .set_address([ + IpNetwork::from_str("10.1.1.1/24").unwrap(), + IpNetwork::from_str("fc00::1/112").unwrap(), + ]) + .unwrap() + .save(&pool) + .await + .unwrap(); + + // assign free addresses + let addrs = vec![ + IpAddr::from_str("10.1.1.2").unwrap(), + IpAddr::from_str("fc00::2").unwrap(), + ]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Ok(()) + ); + + // assign multiple free addresses + let addrs = vec![ + IpAddr::from_str("10.1.1.2").unwrap(), + IpAddr::from_str("10.1.1.3").unwrap(), + IpAddr::from_str("fc00::2").unwrap(), + IpAddr::from_str("fc00::3").unwrap(), + ]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Ok(()) + ); + + // try to assign address from another network + let addrs = vec![IpAddr::from_str("fa::2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::NoContainingNetwork(..)) + ); + + // try to assign already assigned address + let user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device = Device::new( + "device".to_string(), + String::new(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + WireguardNetworkDevice::new( + network.id, + device.id, + vec![ + IpAddr::from_str("10.1.1.2").unwrap(), + IpAddr::from_str("fc00::2").unwrap(), + ], + ) + .insert(&pool) + .await + .unwrap(); + let addrs = vec![IpAddr::from_str("fc00::2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::AddressAlreadyAssigned(..)) + ); + + // assign with exception for the device + let addrs = vec![IpAddr::from_str("fc00::2").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, Some(device.id)) + .await, + Ok(()) + ); + + // try to assign gateway address + let addrs = vec![IpAddr::from_str("fc00::1").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::ReservedForGateway(..)) + ); + + // try to assign network address + let addrs = vec![IpAddr::from_str("fc00::0").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::IsNetworkAddress(..)) + ); + + // try to assign broadcast address + let addrs = vec![IpAddr::from_str("fc00::ffff").unwrap()]; + assert_matches!( + network + .can_assign_ips(&mut pool.acquire().await.unwrap(), &addrs, None) + .await, + Err(NetworkAddressError::IsBroadcastAddress(..)) + ); +} From 3eca8ec30fa5b21ee028bbe9bfd8748fcf204fe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 17 Mar 2026 14:14:14 +0100 Subject: [PATCH 06/10] Add test for set_address --- .../src/db/models/wireguard.rs | 1 - .../src/db/models/wireguard/tests.rs | 31 ++++++++++++++++++- crates/defguard_setup/src/auto_adoption.rs | 12 +++---- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 65afd3f5ca..ebde087953 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -253,7 +253,6 @@ impl WireguardNetwork { } /// Try to set `address` from comma-separated string of addresses. - /// If there is an error parsing the address list, `address` will be partially set. pub fn try_set_address(mut self, address: &str) -> Result { self.address = Vec::new(); for addr in address.split(',') { diff --git a/crates/defguard_common/src/db/models/wireguard/tests.rs b/crates/defguard_common/src/db/models/wireguard/tests.rs index c4ec50584f..48421c1ef8 100644 --- a/crates/defguard_common/src/db/models/wireguard/tests.rs +++ b/crates/defguard_common/src/db/models/wireguard/tests.rs @@ -1,4 +1,4 @@ -use std::str::FromStr; +use std::{net::Ipv6Addr, str::FromStr}; use matches::assert_matches; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; @@ -6,6 +6,35 @@ use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; use crate::db::setup_pool; +#[test] +fn test_set_address() { + // This is fine. + let result = WireguardNetwork::default().set_address([ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 10, 10, 10)), 10).unwrap(), + IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x1010, 0, 0, 0, 0, 0, 0, 0x1010)), + 10, + ) + .unwrap(), + ]); + assert!(result.is_ok()); + + // This should return error. + let result = WireguardNetwork::default().set_address([IpNetwork::new( + IpAddr::V4(Ipv4Addr::new(10, 10, 10, 0)), + 24, + ) + .unwrap()]); + assert!(result.is_err()); + + let result = WireguardNetwork::default().set_address([IpNetwork::new( + IpAddr::V6(Ipv6Addr::new(0x1010, 0, 0, 0, 0, 0, 0, 0)), + 112, + ) + .unwrap()]); + assert!(result.is_err()); +} + // FIXME(mwojcik): rewrite for new stats implementation // #[sqlx::test] // async fn test_connected_at_reconnection(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_setup/src/auto_adoption.rs b/crates/defguard_setup/src/auto_adoption.rs index e004bb2a6b..08264a41e4 100644 --- a/crates/defguard_setup/src/auto_adoption.rs +++ b/crates/defguard_setup/src/auto_adoption.rs @@ -846,13 +846,11 @@ pub async fn attempt_auto_adoption( pool: &PgPool, config: &DefGuardConfig, ) -> Result<(), anyhow::Error> { - let (edge_endpoint, gateway_endpoint) = match (&config.adopt_edge, &config.adopt_gateway) { - (Some(e), Some(g)) => (e, g), - _ => { - anyhow::bail!( - "Both --adopt-edge and --adopt-gateway must be set to run the auto-adoption wizard" - ); - } + let (Some(edge_endpoint), Some(gateway_endpoint)) = (&config.adopt_edge, &config.adopt_gateway) + else { + anyhow::bail!( + "Both --adopt-edge and --adopt-gateway must be set to run the auto-adoption wizard" + ); }; let mut auto_state = AutoAdoptionWizardState::get(pool) From c754ebffe46077ed11e60a001c664f9dba4e35d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 18 Mar 2026 07:13:45 +0100 Subject: [PATCH 07/10] Fortify try_set_address --- Cargo.lock | 16 ++++++++-------- .../defguard_common/src/db/models/wireguard.rs | 14 +++++++++++--- .../src/db/models/wireguard/tests.rs | 13 +++++++++++++ .../tests/integration/grpc/gateway.rs | 6 ------ 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 13844a5a7e..70e1be1d34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6447,18 +6447,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.0.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "9b320e741db58cac564e26c607d3cc1fdc4a88fd36c879568c07856ed83ff3e9" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.25.4+spec-1.1.0" +version = "0.25.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" +checksum = "8ca1a40644a28bce036923f6a431df0b34236949d111cc07cb6dca830c9ef2e1" dependencies = [ "indexmap 2.13.0", "toml_datetime", @@ -6468,9 +6468,9 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.0.10+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "7df25b4befd31c4816df190124375d5a20c6b6921e2cad937316de3fccd63420" dependencies = [ "winnow", ] @@ -7603,9 +7603,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.15" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" dependencies = [ "memchr", ] diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 23db7e232b..8fce989d48 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -267,10 +267,18 @@ impl WireguardNetwork { /// Try to set `address` from comma-separated string of addresses. pub fn try_set_address(mut self, address: &str) -> Result { self.address = Vec::new(); - for addr in address.split(',') { - self.address.push(addr.trim().parse()?); + for addr_str in address.split(',') { + let addr = addr_str.trim().parse::()?; + let ip = addr.ip(); + if ip == addr.network() { + return Err(IpNetworkError::InvalidAddr("address is network".into())); + } + if ip == addr.broadcast() { + return Err(IpNetworkError::InvalidAddr("address is broadcast".into())); + } + self.address.push(addr); } - if address.is_empty() { + if self.address.is_empty() { Err(IpNetworkError::InvalidAddr("empty address".into())) } else { Ok(self) diff --git a/crates/defguard_common/src/db/models/wireguard/tests.rs b/crates/defguard_common/src/db/models/wireguard/tests.rs index 48421c1ef8..390b5076cc 100644 --- a/crates/defguard_common/src/db/models/wireguard/tests.rs +++ b/crates/defguard_common/src/db/models/wireguard/tests.rs @@ -35,6 +35,19 @@ fn test_set_address() { assert!(result.is_err()); } +#[test] +fn test_try_set_address() { + // Valid host address should be accepted. + let result = WireguardNetwork::default().try_set_address("10.10.10.10/24"); + assert!(result.is_ok()); + // Network address should be rejected. + let result = WireguardNetwork::default().try_set_address("10.10.10.0/24"); + assert!(result.is_err()); + // Broadcast address should be rejected. + let result = WireguardNetwork::default().try_set_address("10.10.10.255/24"); + assert!(result.is_err()); +} + // FIXME(mwojcik): rewrite for new stats implementation // #[sqlx::test] // async fn test_connected_at_reconnection(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index e5de5b8268..0d0fe35502 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -55,8 +55,6 @@ async fn setup_test_server( 1000, "endpoint1".to_string(), None, - 1420, - 0, Vec::new(), false, 100, @@ -406,8 +404,6 @@ async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions 1000, "endpoint2".to_string(), None, - 1420, - 0, Vec::new(), false, 100, @@ -526,8 +522,6 @@ async fn test_gateway_config(_: PgPoolOptions, options: PgConnectOptions) { 1000, "endpoint2".to_string(), None, - 1420, - 0, Vec::new(), false, 100, From 0bc14ca3e32ce34d6434bebd07f4c1b6d7ba19b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 18 Mar 2026 07:37:25 +0100 Subject: [PATCH 08/10] Apply suggestions --- .../defguard_common/src/db/models/wireguard.rs | 17 +++++------------ .../enterprise/firewall/tests/all_locations.rs | 2 +- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 8fce989d48..081cb61e75 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -265,23 +265,16 @@ impl WireguardNetwork { } /// Try to set `address` from comma-separated string of addresses. - pub fn try_set_address(mut self, address: &str) -> Result { - self.address = Vec::new(); + pub fn try_set_address(self, address: &str) -> Result { + let mut parsed_addresses = Vec::new(); for addr_str in address.split(',') { let addr = addr_str.trim().parse::()?; - let ip = addr.ip(); - if ip == addr.network() { - return Err(IpNetworkError::InvalidAddr("address is network".into())); - } - if ip == addr.broadcast() { - return Err(IpNetworkError::InvalidAddr("address is broadcast".into())); - } - self.address.push(addr); + parsed_addresses.push(addr); } - if self.address.is_empty() { + if parsed_addresses.is_empty() { Err(IpNetworkError::InvalidAddr("empty address".into())) } else { - Ok(self) + self.set_address(parsed_addresses) } } } diff --git a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs index a5f2d1e3f9..3ee0ea139e 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/all_locations.rs @@ -26,7 +26,7 @@ async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectO // Create another test location let mut location_2 = WireguardNetwork::default() - .set_address(["fb00::1/112".parse().unwrap()]) + .set_address(["192.168.0.1/24".parse().unwrap()]) .unwrap(); location_2.acl_enabled = true; let location_2 = location_2.save(&pool).await.unwrap(); From 1f12c426a5bf2e67f7c3cf917b56c5de81140f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 18 Mar 2026 08:11:49 +0100 Subject: [PATCH 09/10] Remove parse_addresses() --- .../src/db/models/wireguard.rs | 31 ++++++++++--------- .../defguard_core/src/handlers/wireguard.rs | 28 ++--------------- 2 files changed, 19 insertions(+), 40 deletions(-) diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 081cb61e75..d61d012d5d 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -263,20 +263,6 @@ impl WireguardNetwork { service_location_mode, } } - - /// Try to set `address` from comma-separated string of addresses. - pub fn try_set_address(self, address: &str) -> Result { - let mut parsed_addresses = Vec::new(); - for addr_str in address.split(',') { - let addr = addr_str.trim().parse::()?; - parsed_addresses.push(addr); - } - if parsed_addresses.is_empty() { - Err(IpNetworkError::InvalidAddr("empty address".into())) - } else { - self.set_address(parsed_addresses) - } - } } impl WireguardNetwork { @@ -292,6 +278,9 @@ impl WireguardNetwork { { let address = address.into(); for addr in &address { + if addr.prefix() == 0 { + return Err(IpNetworkError::InvalidAddr("prefix is zero".into())); + } let ip = addr.ip(); if ip == addr.network() { return Err(IpNetworkError::InvalidAddr("address is network".into())); @@ -304,6 +293,20 @@ impl WireguardNetwork { Ok(self) } + + /// Try to set `address` from comma-separated string of addresses. + pub fn try_set_address(self, address: &str) -> Result { + let mut parsed_addresses = Vec::new(); + for addr_str in address.split(',') { + let addr = addr_str.trim().parse::()?; + parsed_addresses.push(addr); + } + if parsed_addresses.is_empty() { + Err(IpNetworkError::InvalidAddr("empty address".into())) + } else { + self.set_address(parsed_addresses) + } + } } impl WireguardNetwork { diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index e5c3e26ed7..214c209307 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -14,7 +14,7 @@ use defguard_common::{ wireguard::{LocationMfaMode, MappedDevice, ServiceLocationMode}, }, }, - utils::{parse_address_list, parse_network_address_list}, + utils::parse_network_address_list, }; use defguard_mail::templates::{TemplateLocation, new_device_added_mail}; use ipnetwork::IpNetwork; @@ -87,29 +87,6 @@ impl WireguardNetworkData { .map_or(Vec::new(), |ips| parse_network_address_list(ips)) } - pub(crate) fn parse_addresses(&self) -> Result, WebError> { - // first parse the addresses - let subnets = parse_address_list(self.address.as_ref()); - - // check if address list is not empty - if subnets.is_empty() { - return Err(WebError::BadRequest( - "Must provide at least one valid network address".to_owned(), - )); - } - - // check if any subnet has an invalid /0 netmask - for subnet in &subnets { - if subnet.prefix() == 0 { - return Err(WebError::BadRequest(format!( - "{subnet} is not a valid address" - ))); - } - } - - Ok(subnets) - } - pub(crate) fn validate_peer_disconnect_threshold(&self) -> Result<(), WebError> { if self.location_mfa_mode == LocationMfaMode::Disabled { return Ok(()); @@ -352,8 +329,7 @@ pub(crate) async fn modify_network( let network = find_network(network_id, &appstate.pool).await?; // store network before mods let before = network.clone(); - let new_addresses = data.parse_addresses()?; - let mut network = network.set_address(new_addresses)?; + let mut network = network.try_set_address(&data.address)?; network.allowed_ips = data.parse_allowed_ips(); network.name = data.name; From 12f2832ccf03a3534961ceb4b25323df69001371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 18 Mar 2026 08:47:44 +0100 Subject: [PATCH 10/10] Removed commented-out tests --- .../src/db/models/wireguard/tests.rs | 128 ------------------ 1 file changed, 128 deletions(-) diff --git a/crates/defguard_common/src/db/models/wireguard/tests.rs b/crates/defguard_common/src/db/models/wireguard/tests.rs index 390b5076cc..4db9443cef 100644 --- a/crates/defguard_common/src/db/models/wireguard/tests.rs +++ b/crates/defguard_common/src/db/models/wireguard/tests.rs @@ -48,134 +48,6 @@ fn test_try_set_address() { assert!(result.is_err()); } -// FIXME(mwojcik): rewrite for new stats implementation -// #[sqlx::test] -// async fn test_connected_at_reconnection(_: PgPoolOptions, options: PgConnectOptions) { -// let pool = setup_pool(options).await; -// let mut location = WireguardNetwork::default(); -// location.try_set_address("10.1.1.1/29").unwrap(); -// let location = location.save(&pool).await.unwrap(); - -// let user = User::new( -// "testuser", -// Some("hunter2"), -// "Tester", -// "Test", -// "test@test.com", -// None, -// ) -// .save(&pool) -// .await -// .unwrap(); -// let device = Device::new( -// String::new(), -// String::new(), -// user.id, -// DeviceType::User, -// None, -// true, -// ) -// .save(&pool) -// .await -// .unwrap(); - -// // insert stats -// let samples = 60; // 1 hour of samples -// let now = Utc::now().naive_utc(); -// for i in 0..=samples { -// // simulate connection 30 minutes ago -// let handshake_minutes = i * if i < 31 { 1 } else { 10 }; -// WireguardPeerStats { -// id: NoId, -// device_id: device.id, -// collected_at: now - TimeDelta::minutes(i), -// network: location.id, -// endpoint: Some("11.22.33.44".into()), -// upload: (samples - i) * 10, -// download: (samples - i) * 20, -// latest_handshake: now - TimeDelta::minutes(handshake_minutes), -// allowed_ips: Some("10.1.1.0/24".into()), -// } -// .save(&pool) -// .await -// .unwrap(); -// } - -// let connected_at = device -// .last_connected_at(&pool, location.id) -// .await -// .unwrap() -// .unwrap(); -// assert_eq!( -// connected_at, -// // PostgreSQL stores 6 sub-second digits while chrono stores 9. -// (now - TimeDelta::minutes(30)).trunc_subsecs(6), -// ); -// } - -// FIXME(mwojcik): rewrite for new stats implementation -// #[sqlx::test] -// async fn test_connected_at_always_connected(_: PgPoolOptions, options: PgConnectOptions) { -// let pool = setup_pool(options).await; -// let mut location = WireguardNetwork::default(); -// location.try_set_address("10.1.1.1/29").unwrap(); -// let location = location.save(&pool).await.unwrap(); - -// let user = User::new( -// "testuser", -// Some("hunter2"), -// "Tester", -// "Test", -// "test@test.com", -// None, -// ) -// .save(&pool) -// .await -// .unwrap(); -// let device = Device::new( -// String::new(), -// String::new(), -// user.id, -// DeviceType::User, -// None, -// true, -// ) -// .save(&pool) -// .await -// .unwrap(); - -// // insert stats -// let samples = 60; // 1 hour of samples -// let now = Utc::now().naive_utc(); -// for i in 0..=samples { -// WireguardPeerStats { -// id: NoId, -// device_id: device.id, -// collected_at: now - TimeDelta::minutes(i), -// network: location.id, -// endpoint: Some("11.22.33.44".into()), -// upload: (samples - i) * 10, -// download: (samples - i) * 20, -// latest_handshake: now - TimeDelta::minutes(i), // handshake every minute -// allowed_ips: Some("10.1.1.0/24".into()), -// } -// .save(&pool) -// .await -// .unwrap(); -// } - -// let connected_at = device -// .last_connected_at(&pool, location.id) -// .await -// .unwrap() -// .unwrap(); -// assert_eq!( -// connected_at, -// // PostgreSQL stores 6 sub-second digits while chrono stores 9. -// (now - TimeDelta::minutes(samples)).trunc_subsecs(6), -// ); -// } - #[sqlx::test] async fn test_get_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await;