diff --git a/.github/workflows/build-docker.yml b/.github/workflows/build-docker.yml index 7fc2e420a2..989c13471b 100644 --- a/.github/workflows/build-docker.yml +++ b/.github/workflows/build-docker.yml @@ -72,7 +72,7 @@ jobs: cache-to: type=registry,mode=max,ref=${{ env.GHCR_REPO }}:cache-${{ matrix.tag }}-${{ env.SAFE_REF }} - name: Scan image with Trivy - uses: aquasecurity/trivy-action@0.33.1 + uses: aquasecurity/trivy-action@0.34.2 env: TRIVY_SHOW_SUPPRESSED: 1 TRIVY_IGNOREFILE: "./.trivyignore.yaml" @@ -96,7 +96,7 @@ jobs: steps: - name: Install Cosign - uses: sigstore/cosign-installer@v3.9.2 + uses: sigstore/cosign-installer@v3.10.1 - name: Docker meta id: meta diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a4932a04e4..e76ad954c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,7 +59,7 @@ jobs: submodules: recursive - name: Scan code with Trivy - uses: aquasecurity/trivy-action@0.33.1 + uses: aquasecurity/trivy-action@0.34.2 env: TRIVY_SHOW_SUPPRESSED: 1 TRIVY_IGNOREFILE: "./.trivyignore.yaml" diff --git a/.github/workflows/sbom.yml b/.github/workflows/sbom.yml index 05951846ca..306f8677e3 100644 --- a/.github/workflows/sbom.yml +++ b/.github/workflows/sbom.yml @@ -35,7 +35,7 @@ jobs: submodules: recursive - name: Create SBOM with Trivy - uses: aquasecurity/trivy-action@0.33.1 + uses: aquasecurity/trivy-action@0.34.2 with: scan-type: 'fs' format: 'spdx-json' @@ -46,7 +46,7 @@ jobs: skip-dirs: "e2e" - name: Create docker image SBOM with Trivy - uses: aquasecurity/trivy-action@0.33.1 + uses: aquasecurity/trivy-action@0.34.2 with: image-ref: "ghcr.io/defguard/defguard:${{ steps.vars.outputs.VERSION }}" scan-type: 'image' @@ -56,7 +56,7 @@ jobs: scanners: "vuln" - name: Create security advisory file with Trivy - uses: aquasecurity/trivy-action@0.33.1 + uses: aquasecurity/trivy-action@0.34.2 with: scan-type: 'fs' format: 'json' @@ -67,7 +67,7 @@ jobs: skip-dirs: "e2e" - name: Create docker image security advisory file with Trivy - uses: aquasecurity/trivy-action@0.33.1 + uses: aquasecurity/trivy-action@0.34.2 with: image-ref: "ghcr.io/defguard/defguard:${{ steps.vars.outputs.VERSION }}" scan-type: 'image' diff --git a/.sqlx/query-27e7e18a7014af541fe5f8f051f78d61eebe6a79945324e98ca452b50d6abc90.json b/.sqlx/query-27e7e18a7014af541fe5f8f051f78d61eebe6a79945324e98ca452b50d6abc90.json new file mode 100644 index 0000000000..7f15265f00 --- /dev/null +++ b/.sqlx/query-27e7e18a7014af541fe5f8f051f78d61eebe6a79945324e98ca452b50d6abc90.json @@ -0,0 +1,86 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT * FROM proxy WHERE enabled AND id NOT IN (SELECT id FROM proxy WHERE enabled LIMIT 1\n )", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "address", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "port", + "type_info": "Int4" + }, + { + "ordinal": 4, + "name": "connected_at", + "type_info": "Timestamp" + }, + { + "ordinal": 5, + "name": "disconnected_at", + "type_info": "Timestamp" + }, + { + "ordinal": 6, + "name": "certificate_expiry", + "type_info": "Timestamp" + }, + { + "ordinal": 7, + "name": "version", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "modified_at", + "type_info": "Timestamp" + }, + { + "ordinal": 9, + "name": "certificate", + "type_info": "Text" + }, + { + "ordinal": 10, + "name": "enabled", + "type_info": "Bool" + }, + { + "ordinal": 11, + "name": "modified_by", + "type_info": "Text" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + false, + true, + true, + true, + true, + false, + true, + false, + false + ] + }, + "hash": "27e7e18a7014af541fe5f8f051f78d61eebe6a79945324e98ca452b50d6abc90" +} diff --git a/.sqlx/query-2ce93887379d80ff03753caaf94ec1ab4c6f0ead212fc74bb881e1d5c0d96080.json b/.sqlx/query-2ce93887379d80ff03753caaf94ec1ab4c6f0ead212fc74bb881e1d5c0d96080.json index aa509dfc3e..c27dd4058c 100644 --- a/.sqlx/query-2ce93887379d80ff03753caaf94ec1ab4c6f0ead212fc74bb881e1d5c0d96080.json +++ b/.sqlx/query-2ce93887379d80ff03753caaf94ec1ab4c6f0ead212fc74bb881e1d5c0d96080.json @@ -55,13 +55,13 @@ }, { "ordinal": 10, - "name": "modified_by", - "type_info": "Text" + "name": "enabled", + "type_info": "Bool" }, { "ordinal": 11, - "name": "enabled", - "type_info": "Bool" + "name": "modified_by", + "type_info": "Text" } ], "parameters": { diff --git a/.sqlx/query-472e3903cf3df3c5938527c5584a5a53edc5b492a6bb12eac3d97f3ebc5f8506.json b/.sqlx/query-472e3903cf3df3c5938527c5584a5a53edc5b492a6bb12eac3d97f3ebc5f8506.json index 68e6a6c079..107e426de6 100644 --- a/.sqlx/query-472e3903cf3df3c5938527c5584a5a53edc5b492a6bb12eac3d97f3ebc5f8506.json +++ b/.sqlx/query-472e3903cf3df3c5938527c5584a5a53edc5b492a6bb12eac3d97f3ebc5f8506.json @@ -55,13 +55,13 @@ }, { "ordinal": 10, - "name": "modified_by", - "type_info": "Text" + "name": "enabled", + "type_info": "Bool" }, { "ordinal": 11, - "name": "enabled", - "type_info": "Bool" + "name": "modified_by", + "type_info": "Text" } ], "parameters": { diff --git a/.sqlx/query-4b1b06bb9769c0237e467ee7cc5265d9b0c45a5b27368cffd634d858e974d8aa.json b/.sqlx/query-4b1b06bb9769c0237e467ee7cc5265d9b0c45a5b27368cffd634d858e974d8aa.json new file mode 100644 index 0000000000..fa52645561 --- /dev/null +++ b/.sqlx/query-4b1b06bb9769c0237e467ee7cc5265d9b0c45a5b27368cffd634d858e974d8aa.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE gateway SET enabled = false WHERE enabled AND id NOT IN (SELECT id FROM gateway WHERE enabled LIMIT 1\n )", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "4b1b06bb9769c0237e467ee7cc5265d9b0c45a5b27368cffd634d858e974d8aa" +} diff --git a/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json b/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json index 87ce4e720e..3b34801d88 100644 --- a/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json +++ b/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json @@ -55,13 +55,13 @@ }, { "ordinal": 10, - "name": "modified_by", - "type_info": "Text" + "name": "enabled", + "type_info": "Bool" }, { "ordinal": 11, - "name": "enabled", - "type_info": "Bool" + "name": "modified_by", + "type_info": "Text" } ], "parameters": { diff --git a/Cargo.lock b/Cargo.lock index 4a6f696a65..e9530c1277 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4562,9 +4562,9 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ "toml_edit", ] @@ -6448,18 +6448,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.5+spec-1.1.0" +version = "1.0.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.23.10+spec-1.0.0" +version = "0.25.4+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269" +checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" dependencies = [ "indexmap 2.13.0", "toml_datetime", diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 5dd2dec097..afd006fa93 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -212,7 +212,7 @@ async fn main() -> Result<(), anyhow::Error> { failed_logins, api_event_tx, incompatible_components, - proxy_control_tx + proxy_control_tx.clone() ) => error!("Web server returned early: {res:?}"), res = run_periodic_stats_purge( pool.clone(), @@ -220,7 +220,7 @@ async fn main() -> Result<(), anyhow::Error> { config.stats_purge_threshold.into() ), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:?}"), - res = run_periodic_license_check(&pool) => + res = run_periodic_license_check(&pool, proxy_control_tx) => error!("Periodic license check task returned early: {res:?}"), res = run_utility_thread(&pool, gateway_tx.clone()) => error!("Utility thread returned early: {res:?}"), diff --git a/crates/defguard_certs/src/lib.rs b/crates/defguard_certs/src/lib.rs index f92c10f96e..460645642f 100644 --- a/crates/defguard_certs/src/lib.rs +++ b/crates/defguard_certs/src/lib.rs @@ -391,7 +391,7 @@ mod tests { let days = (not_after - not_before).whole_days(); assert!( - (valid_days as i64 - 1..=valid_days as i64 + 1).contains(&days), + (i64::from(valid_days) - 1..=i64::from(valid_days) + 1).contains(&days), "expected validity of {valid_days} days (±1), got {days} days" ); assert!( @@ -448,8 +448,7 @@ mod tests { assert!( email_found, - "Email '{}' should be present in Subject Alternative Names", - expected_email + "Email '{expected_email}' should be present in Subject Alternative Names" ); } diff --git a/crates/defguard_common/src/db/models/gateway.rs b/crates/defguard_common/src/db/models/gateway.rs index 98293331c9..c57c006cb7 100644 --- a/crates/defguard_common/src/db/models/gateway.rs +++ b/crates/defguard_common/src/db/models/gateway.rs @@ -3,7 +3,7 @@ use std::fmt; use chrono::{NaiveDateTime, Timelike, Utc}; use model_derive::Model; use serde::{Deserialize, Serialize}; -use sqlx::{PgExecutor, query, query_as}; +use sqlx::{PgExecutor, query, query_as, query_scalar}; use crate::db::{Id, NoId}; @@ -89,7 +89,7 @@ impl Gateway { } /// Update `connected_at` to the current time and save it to the database. - pub async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + pub async fn touch_connected<'e, E>(&mut self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -106,7 +106,7 @@ impl Gateway { } /// Set `disconnected_at` to the current time and save it to the database. - pub async fn touch_disconnected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + pub async fn touch_disconnected<'e, E>(&mut self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -122,11 +122,11 @@ impl Gateway { Ok(()) } - pub async fn delete_by_id<'e, E>(executor: E, id: Id) -> Result<(), sqlx::Error> + pub async fn delete_by_id<'e, E>(executor: E, id: Id) -> sqlx::Result<()> where E: PgExecutor<'e>, { - sqlx::query!("DELETE FROM \"gateway\" WHERE id = $1", id,) + query!("DELETE FROM \"gateway\" WHERE id = $1", id,) .execute(executor) .await?; @@ -138,7 +138,7 @@ impl Gateway { executor: E, address: &str, port: u16, - ) -> Result, sqlx::Error> + ) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -154,10 +154,29 @@ impl Gateway { Ok(record) } + /// Return address and port as URL with HTTP scheme. #[must_use] pub fn url(&self) -> String { format!("http://{}:{}", self.address, self.port) } + + /// Disable all Gateways except one. Used for expired licence. + pub async fn leave_one_enabled<'e, E>(executor: E) -> sqlx::Result<()> + where + E: PgExecutor<'e>, + { + let result = query_scalar!( + "UPDATE gateway SET enabled = false WHERE enabled AND id NOT IN (\ + SELECT id FROM gateway WHERE enabled LIMIT 1 + )" + ) + .execute(executor) + .await?; + + tracing::debug!("Disabled {} Gateways", result.rows_affected()); + + Ok(()) + } } impl fmt::Display for Gateway { diff --git a/crates/defguard_common/src/db/models/proxy.rs b/crates/defguard_common/src/db/models/proxy.rs index c2b359f8ea..011c63efc8 100644 --- a/crates/defguard_common/src/db/models/proxy.rs +++ b/crates/defguard_common/src/db/models/proxy.rs @@ -112,4 +112,19 @@ impl Proxy { Ok(()) } + + /// Fetch all enabled, but one. Used for expired licence. + pub async fn leave_one_enabled<'e, E>(executor: E) -> sqlx::Result> + where + E: sqlx::PgExecutor<'e>, + { + sqlx::query_as!( + Self, + "SELECT * FROM proxy WHERE enabled AND id NOT IN (\ + SELECT id FROM proxy WHERE enabled LIMIT 1 + )" + ) + .fetch_all(executor) + .await + } } diff --git a/crates/defguard_common/src/db/models/user.rs b/crates/defguard_common/src/db/models/user.rs index c128655198..5d2fe78023 100644 --- a/crates/defguard_common/src/db/models/user.rs +++ b/crates/defguard_common/src/db/models/user.rs @@ -1160,6 +1160,7 @@ impl User { .await } + #[must_use] pub fn fullname(&self) -> String { format!("{} {}", self.first_name, self.last_name) } diff --git a/crates/defguard_common/src/db/models/wizard.rs b/crates/defguard_common/src/db/models/wizard.rs index 3e77fd37ab..60a54c2612 100644 --- a/crates/defguard_common/src/db/models/wizard.rs +++ b/crates/defguard_common/src/db/models/wizard.rs @@ -197,16 +197,14 @@ impl Wizard { let step = self .initial_setup_state .as_ref() - .map(|s| s.step) - .unwrap_or(InitialSetupStep::Welcome); + .map_or(InitialSetupStep::Welcome, |s| s.step); step > InitialSetupStep::AdminUser } ActiveWizard::AutoAdoption => { let step = self .auto_adoption_state .as_ref() - .map(|s| s.step) - .unwrap_or(AutoAdoptionWizardStep::Welcome); + .map_or(AutoAdoptionWizardStep::Welcome, |s| s.step); step > AutoAdoptionWizardStep::AdminUser } _ => true, diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index 78c8ca896d..a6613cc71b 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -102,7 +102,7 @@ async fn create_test_users_and_devices( let device = device.save(pool).await.unwrap(); // Add device to locations' VPN network - for location in test_locations.iter() { + for location in &test_locations { let wireguard_ips = location .address .iter() diff --git a/crates/defguard_core/src/enterprise/ldap/tests.rs b/crates/defguard_core/src/enterprise/ldap/tests.rs index dbb46c0083..4ef6c09c70 100644 --- a/crates/defguard_core/src/enterprise/ldap/tests.rs +++ b/crates/defguard_core/src/enterprise/ldap/tests.rs @@ -2488,9 +2488,7 @@ async fn test_sync_ldap_to_defguard_does_not_exceed_user_license_limit( assert!( user_count_after_sync <= user_limit, - "LDAP sync exceeded user license limit: users={}, limit={}", - user_count_after_sync, - user_limit + "LDAP sync exceeded user license limit: users={user_count_after_sync}, limit={user_limit}" ); let skipped_user = User::find_by_username(&pool, "ldap_only_user_limit") diff --git a/crates/defguard_core/src/enterprise/license.rs b/crates/defguard_core/src/enterprise/license.rs index f0895c3a0d..471e64574d 100644 --- a/crates/defguard_core/src/enterprise/license.rs +++ b/crates/defguard_core/src/enterprise/license.rs @@ -6,8 +6,9 @@ use chrono::{DateTime, TimeDelta, Utc}; use defguard_common::{ VERSION, config::server_config, - db::models::{Settings, settings::update_current_settings}, + db::models::{Settings, gateway::Gateway, proxy::Proxy, settings::update_current_settings}, global_value, + types::proxy::ProxyControlMessage, }; use humantime::format_duration; use pgp::{ @@ -222,18 +223,23 @@ impl License { if license.requires_renewal() { if license.is_max_overdue() { warn!( - "The provided license has expired and reached its maximum overdue time, please contact salesdefguard.net" + "The provided license has expired and reached its maximum overdue time, \ + please contact salesdefguard.net" ); } else { warn!( - "The provided license is about to expire and requires a renewal. An automatic renewal process will attempt to renew the license soon. Alternatively, automatic renewal attempt will be also performed at the next defguard start." + "The provided license is about to expire and requires a renewal. An \ + automatic renewal process will attempt to renew the license soon. \ + Alternatively, automatic renewal attempt will be also performed at the \ + next Defguard start." ); } } if !license.subscription && license.is_expired() { warn!( - "The provided license is not a subscription and has expired, please contact salesdefguard.net" + "The provided license is not a subscription and has expired, please \ + contact salesdefguard.net" ); } @@ -260,8 +266,9 @@ impl License { } } - /// Try to load the license from the database, if the license requires a renewal, try to renew it. - /// If the renewal fails, it will return the old license for the renewal service to renew it later. + /// Try to load the license from the database, if the license requires a renewal, try to renew + /// it. If the renewal fails, it will return the old license for the renewal service to renew it + /// later. pub async fn load_or_renew(pool: &PgPool) -> Result, LicenseError> { match Self::load()? { Some(license) => { @@ -275,7 +282,8 @@ impl License { let new_license = License::from_base64(&new_key)?; save_license_key(pool, &new_key).await?; info!( - "Successfully renewed and loaded the license, new license key saved to the database" + "Successfully renewed and loaded the license, new license key \ + saved to the database" ); Ok(Some(new_license)) } @@ -349,7 +357,8 @@ impl License { if self.subscription { self.time_overdue() > MAX_OVERDUE_TIME } else { - // Non-subscription licenses are considered expired immediately, no grace period is required + // Non-subscription licenses are considered expired immediately, no grace period is + // required. self.is_expired() } } @@ -484,8 +493,39 @@ pub fn update_cached_license(key: Option<&str>) -> Result<(), LicenseError> { const RENEWAL_TIME: TimeDelta = TimeDelta::hours(24); const MAX_OVERDUE_TIME: TimeDelta = TimeDelta::days(14); +/// Scale down enabled Gateways and Edges to one (per component). +async fn trim_gateways_and_edges( + pool: &PgPool, + proxy_control_tx: &tokio::sync::mpsc::Sender, +) -> Result<(), LicenseError> { + Gateway::leave_one_enabled(pool).await?; + + let edges = Proxy::leave_one_enabled(pool).await?; + let count = edges.len(); + for mut edge in edges { + edge.enabled = false; + edge.save(pool).await?; + if let Err(err) = proxy_control_tx + .send(ProxyControlMessage::ShutdownConnection(edge.id)) + .await + { + error!( + "Failed to shutdown Proxy {}, it may be disconnected: {err:?}", + edge.id + ); + } + } + + debug!("Disabled {count} Edges"); + + Ok(()) +} + #[instrument(skip_all)] -pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseError> { +pub async fn run_periodic_license_check( + pool: &PgPool, + proxy_control_tx: tokio::sync::mpsc::Sender, +) -> Result<(), LicenseError> { let config = server_config(); let mut check_period: Duration = *config.check_period; info!( @@ -497,6 +537,9 @@ pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseErro // Check if the license is present in the mutex, if not skip the check if get_cached_license().is_none() { debug!("No license found, skipping license check"); + + trim_gateways_and_edges(pool, &proxy_control_tx).await?; + sleep(*config.check_period_no_license).await; continue; } @@ -504,28 +547,31 @@ pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseErro // Check if the license requires renewal, uses the cached value to be more efficient // The block here is to avoid holding the lock through awaits // - // Multiple locks here may cause a race condition if the user decides to update the license key - // while the renewal is in progress. However this seems like a rare case and shouldn't be very problematic. - let requires_renewal = { - let license = get_cached_license(); - debug!("Checking if the license {license:?} requires a renewal..."); - - if let Some(license) = license.as_ref() { + // Multiple locks here may cause a race condition if the user decides to update the license + // key while the renewal is in progress. However this seems like a rare case and shouldn't + // be very problematic. + let (requires_renewal, trim_components) = { + let cached_license = get_cached_license(); + debug!("Checking if the license {cached_license:?} requires a renewal"); + + if let Some(license) = cached_license.as_ref() { if license.requires_renewal() { // check if we are pass the maximum expiration date, after which we don't // want to try to renew the license anymore if license.is_max_overdue() { check_period = *config.check_period; warn!( - "Your license has expired and reached its maximum overdue date, please contact sales at salesdefguard.net" + "Your license has expired and reached its maximum overdue date, please \ + contact sales at salesdefguard.net" ); debug!("Changing check period to {}", format_duration(check_period)); - false + (false, true) } else { debug!( - "License requires renewal, as it is about to expire and is not past the maximum overdue time" + "License requires renewal, as it is about to expire and is not past \ + the maximum overdue time" ); - true + (true, false) } } else { // This if is only for logging purposes, to provide more detailed information @@ -534,14 +580,18 @@ pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseErro } else { debug!("License is not a subscription, skipping renewal check"); } - false + (false, false) } } else { debug!("No license found, skipping license check"); - false + (false, true) } }; + if trim_components { + trim_gateways_and_edges(pool, &proxy_control_tx).await?; + } + if requires_renewal { info!("License requires renewal, renewing license..."); check_period = *config.check_period_renewal_window; @@ -556,8 +606,8 @@ pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseErro } Err(err) => { error!( - "Couldn't save the newly fetched license key to the database, error: {}", - err + "Couldn't save the newly fetched license key to the database, error: \ + {err}" ); } }, @@ -581,6 +631,11 @@ pub(crate) const PUBLIC_KEY: &[u8] = include_bytes!("test_key.asc"); #[cfg(test)] mod test { use chrono::TimeZone; + use defguard_common::db::{ + models::{User, WireguardNetwork}, + setup_pool, + }; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; @@ -756,4 +811,60 @@ mod test { let enterprise_license = License::from_base64(enterprise_license).unwrap(); assert_eq!(enterprise_license.tier, LicenseTier::Enterprise); } + + #[sqlx::test] + async fn test_trim_gateways_and_edges(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let location = WireguardNetwork::default().save(&pool).await.unwrap(); + let user = User::new( + "tester", + Some("hunter2"), + "Tes", + "Ter", + "email@email.com", + None, + ) + .save(&pool) + .await + .unwrap(); + let fullname = user.fullname(); + + Gateway::new(location.id, "Gateway 1", "localhost", 8000, &fullname) + .save(&pool) + .await + .unwrap(); + Gateway::new(location.id, "Gateway 2", "localhost", 8001, &fullname) + .save(&pool) + .await + .unwrap(); + + Proxy::new("Proxy 1", "localhost", 9000, &fullname) + .save(&pool) + .await + .unwrap(); + Proxy::new("Proxy 2", "localhost", 9001, &fullname) + .save(&pool) + .await + .unwrap(); + + let (proxy_control_tx, mut proxy_control_rx) = + tokio::sync::mpsc::channel::(8); + + trim_gateways_and_edges(&pool, &proxy_control_tx) + .await + .unwrap(); + + let all_gateways = Gateway::all(&pool).await.unwrap(); + assert_eq!(1, all_gateways.iter().filter(|gw| gw.enabled).count()); + assert_eq!(1, all_gateways.iter().filter(|gw| !gw.enabled).count()); + + let all_proxies = Proxy::all(&pool).await.unwrap(); + assert_eq!(1, all_proxies.iter().filter(|gw| gw.enabled).count()); + assert_eq!(1, all_proxies.iter().filter(|gw| !gw.enabled).count()); + + // Only one Proxy has to be shut down. + assert!(proxy_control_rx.try_recv().is_ok()); + assert!(proxy_control_rx.try_recv().is_err()); + } } diff --git a/crates/defguard_core/src/handlers/session_info.rs b/crates/defguard_core/src/handlers/session_info.rs index 65361742e6..be2f32c1c7 100644 --- a/crates/defguard_core/src/handlers/session_info.rs +++ b/crates/defguard_core/src/handlers/session_info.rs @@ -27,15 +27,14 @@ pub(crate) async fn get_session_info( }, StatusCode::OK, )); - } else { - return Ok(ApiResponse::json( - SessionInfoResponse { - authorized: false, - wizard_flags: None, - }, - StatusCode::OK, - )); } + return Ok(ApiResponse::json( + SessionInfoResponse { + authorized: false, + wizard_flags: None, + }, + StatusCode::OK, + )); }; let Some(user) = User::find_by_id(pool, session.user_id).await? else { diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index f727b8d11c..005530c367 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -54,7 +54,7 @@ pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork None, 1420, 0, - vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0).unwrap()], + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0).unwrap()], 25, 300, false, diff --git a/crates/defguard_setup/src/handlers/initial_wizard.rs b/crates/defguard_setup/src/handlers/initial_wizard.rs index 5312f920cf..b61cb61de7 100644 --- a/crates/defguard_setup/src/handlers/initial_wizard.rs +++ b/crates/defguard_setup/src/handlers/initial_wizard.rs @@ -51,8 +51,7 @@ async fn advance_initial_wizard_to_step( let current_step = wizard .initial_setup_state .as_ref() - .map(|s| s.step) - .unwrap_or(InitialSetupStep::Welcome); + .map_or(InitialSetupStep::Welcome, |s| s.step); if current_step < step { wizard.initial_setup_state = Some(InitialSetupState { step }); wizard.save(pool).await?; diff --git a/tools/defguard_generator/src/vpn_session_stats.rs b/tools/defguard_generator/src/vpn_session_stats.rs index a83af4e425..de6b4e093e 100644 --- a/tools/defguard_generator/src/vpn_session_stats.rs +++ b/tools/defguard_generator/src/vpn_session_stats.rs @@ -164,7 +164,7 @@ async fn prepare_gateway(pool: &PgPool, location_id: Id) -> Result> match existing_gateways.into_iter().next() { Some(gateway) => Ok(gateway), None => { - let gateway = Gateway::new(location_id, "test", "localhost", 50055, 1) + let gateway = Gateway::new(location_id, "test", "localhost", 50055, "Generator") .save(pool) .await?; Ok(gateway)