diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index fbc0e5f7c0..49e1b67790 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -108,7 +108,12 @@ async fn main() -> Result<(), anyhow::Error> { // initialize global settings struct initialize_current_settings(&pool).await?; - let has_auto_adopt_flags = config.adopt_edge.is_some() || config.adopt_gateway.is_some(); + // Both flags must be provided together + if let Err(msg) = config.validate_adopt_flags() { + anyhow::bail!("{msg}"); + } + + let has_auto_adopt_flags = config.adopt_edge.is_some() && config.adopt_gateway.is_some(); let wizard = Wizard::init(&pool, has_auto_adopt_flags).await?; let mut ini_server_config = true; diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index e9f39ab14c..3ef3100145 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -238,6 +238,23 @@ impl DefGuardConfig { Self::parse_from::<[_; 0], String>([]) } + /// Validate that the auto-adoption flags are consistent. + /// + /// Both `--adopt-edge` and `--adopt-gateway` must be supplied together. + pub fn validate_adopt_flags(&self) -> Result<(), String> { + match (&self.adopt_edge, &self.adopt_gateway) { + (Some(_), None) => Err("--adopt-edge (DEFGUARD_ADOPT_EDGE) was provided but \ + --adopt-gateway (DEFGUARD_ADOPT_GATEWAY) is missing. \ + Both flags must be provided together to launch the auto-adoption wizard." + .to_string()), + (None, Some(_)) => Err("--adopt-gateway (DEFGUARD_ADOPT_GATEWAY) was provided but \ + --adopt-edge (DEFGUARD_ADOPT_EDGE) is missing. \ + Both flags must be provided together to launch the auto-adoption wizard." + .to_string()), + _ => Ok(()), + } + } + /// Initialize values that depend on Settings. pub fn initialize_post_settings(&mut self) { let url = Settings::url().expect("Unable to parse Defguard URL."); @@ -316,4 +333,36 @@ mod tests { assert_eq!(config.cookie_domain, Some("example.com".to_string())); } + + fn make_config(adopt_edge: Option<&str>, adopt_gateway: Option<&str>) -> DefGuardConfig { + let mut config = DefGuardConfig::new_test_config(); + config.adopt_edge = adopt_edge.map(str::to_string); + config.adopt_gateway = adopt_gateway.map(str::to_string); + config + } + + #[test] + fn test_validate_adopt_flags() { + // neither flag: valid, no auto-adoption requested + assert!(make_config(None, None).validate_adopt_flags().is_ok()); + + // both flags: valid + assert!( + make_config(Some("edge.example.com:8080"), Some("gw.example.com:8080")) + .validate_adopt_flags() + .is_ok() + ); + + // only one flag at a time: must be an error + assert!( + make_config(Some("edge.example.com:8080"), None) + .validate_adopt_flags() + .is_err() + ); + assert!( + make_config(None, Some("gw.example.com:8080")) + .validate_adopt_flags() + .is_err() + ); + } } diff --git a/crates/defguard_setup/src/auto_adoption.rs b/crates/defguard_setup/src/auto_adoption.rs index bdb19cfa47..658198198e 100644 --- a/crates/defguard_setup/src/auto_adoption.rs +++ b/crates/defguard_setup/src/auto_adoption.rs @@ -844,11 +844,23 @@ async fn create_proxy( Ok(()) } -/// Stores and updates startup auto-adoption states for components requested via CLI flags. +/// Stores and updates startup auto-adoption states for both components. +/// +/// Both `config.adopt_edge` and `config.adopt_gateway` must be set before calling this +/// function. Callers (i.e. `main`) are responsible for enforcing that invariant. pub async fn attempt_auto_adoption( pool: &PgPool, config: &DefGuardConfig, ) -> Result<(), anyhow::Error> { + let (edge_endpoint, gateway_endpoint) = match (&config.adopt_edge, &config.adopt_gateway) { + (Some(e), Some(g)) => (e, g), + _ => { + anyhow::bail!( + "Both --adopt-edge and --adopt-gateway must be set to run the auto-adoption wizard" + ); + } + }; + let mut auto_state = AutoAdoptionWizardState::get(pool) .await .context("Failed to load auto-adoption wizard state")? @@ -863,62 +875,60 @@ pub async fn attempt_auto_adoption( .get(&SetupAutoAdoptionComponent::Gateway) .is_some_and(|result| result.success); - let should_run_edge = config.adopt_edge.is_some() && !edge_already_succeeded; - let should_run_gateway = config.adopt_gateway.is_some() && !gateway_already_succeeded; - let auto_mode_requested = should_run_edge || should_run_gateway; - if auto_mode_requested { + if !edge_already_succeeded || !gateway_already_succeeded { ensure_ca_for_auto_adoption(pool).await?; } - if let Some(endpoint) = &config.adopt_edge { - if edge_already_succeeded { - info!( - "Skipping startup auto-adoption for Edge component endpoint={endpoint} as it was already completed" + if edge_already_succeeded { + info!( + "Skipping startup auto-adoption for Edge component endpoint={edge_endpoint} as it was already completed" + ); + } else { + info!("Starting startup auto-adoption for Edge component endpoint={edge_endpoint}"); + if let Err(err) = + process_startup_auto_adoption(pool, SetupAutoAdoptionComponent::Edge, edge_endpoint) + .await + { + auto_state.adoption_result.insert( + SetupAutoAdoptionComponent::Edge, + AutoAdoptionComponentResult { + success: false, + logs: vec![format!("Startup auto-adoption failed: {err}")], + updated_at: chrono::Utc::now().naive_utc(), + }, ); + auto_state.save(pool).await?; } else { - info!("Starting startup auto-adoption for Edge component endpoint={endpoint}"); - if let Err(err) = - process_startup_auto_adoption(pool, SetupAutoAdoptionComponent::Edge, endpoint) - .await - { - auto_state.adoption_result.insert( - SetupAutoAdoptionComponent::Edge, - AutoAdoptionComponentResult { - success: false, - logs: vec![format!("Startup auto-adoption failed: {err}")], - updated_at: chrono::Utc::now().naive_utc(), - }, - ); - auto_state.save(pool).await?; - } else { - info!("Startup auto-adoption for Edge component completed endpoint={endpoint}"); - } + info!("Startup auto-adoption for Edge component completed endpoint={edge_endpoint}"); } } - if let Some(endpoint) = &config.adopt_gateway { - if gateway_already_succeeded { - info!( - "Skipping startup auto-adoption for Gateway component endpoint={endpoint} as it was already completed" + if gateway_already_succeeded { + info!( + "Skipping startup auto-adoption for Gateway component endpoint={gateway_endpoint} as it was already completed" + ); + } else { + info!("Starting startup auto-adoption for Gateway component endpoint={gateway_endpoint}"); + if let Err(err) = process_startup_auto_adoption( + pool, + SetupAutoAdoptionComponent::Gateway, + gateway_endpoint, + ) + .await + { + auto_state.adoption_result.insert( + SetupAutoAdoptionComponent::Gateway, + AutoAdoptionComponentResult { + success: false, + logs: vec![format!("Startup auto-adoption failed: {err}")], + updated_at: chrono::Utc::now().naive_utc(), + }, ); + auto_state.save(pool).await?; } else { - info!("Starting startup auto-adoption for Gateway component endpoint={endpoint}"); - if let Err(err) = - process_startup_auto_adoption(pool, SetupAutoAdoptionComponent::Gateway, endpoint) - .await - { - auto_state.adoption_result.insert( - SetupAutoAdoptionComponent::Gateway, - AutoAdoptionComponentResult { - success: false, - logs: vec![format!("Startup auto-adoption failed: {err}")], - updated_at: chrono::Utc::now().naive_utc(), - }, - ); - auto_state.save(pool).await?; - } else { - info!("Startup auto-adoption for Gateway component completed endpoint={endpoint}"); - } + info!( + "Startup auto-adoption for Gateway component completed endpoint={gateway_endpoint}" + ); } } diff --git a/crates/defguard_setup/tests/auto_adoption_wizard.rs b/crates/defguard_setup/tests/auto_adoption_wizard.rs index aa09acb6b1..c60e98a361 100644 --- a/crates/defguard_setup/tests/auto_adoption_wizard.rs +++ b/crates/defguard_setup/tests/auto_adoption_wizard.rs @@ -1,13 +1,17 @@ -use defguard_common::db::{ - models::{ - Settings, WireguardNetwork, - settings::initialize_current_settings, - setup_auto_adoption::{AutoAdoptionWizardState, AutoAdoptionWizardStep}, - wireguard::{LocationMfaMode, ServiceLocationMode}, - wizard::{ActiveWizard, Wizard}, +use defguard_common::{ + config::DefGuardConfig, + db::{ + models::{ + Settings, WireguardNetwork, + settings::initialize_current_settings, + setup_auto_adoption::{AutoAdoptionWizardState, AutoAdoptionWizardStep}, + wireguard::{LocationMfaMode, ServiceLocationMode}, + wizard::{ActiveWizard, Wizard}, + }, + setup_pool, }, - setup_pool, }; +use defguard_setup::auto_adoption::attempt_auto_adoption; use ipnetwork::IpNetwork; use reqwest::{ Client, StatusCode, @@ -380,3 +384,46 @@ async fn test_auto_adoption_vpn_settings_missing_network( // Step must NOT have advanced past VpnSettings assert_auto_adoption_step(&pool, AutoAdoptionWizardStep::VpnSettings).await; } + +fn config_with_flags(adopt_edge: Option<&str>, adopt_gateway: Option<&str>) -> DefGuardConfig { + let mut config = DefGuardConfig::new_test_config(); + config.adopt_edge = adopt_edge.map(str::to_string); + config.adopt_gateway = adopt_gateway.map(str::to_string); + config +} + +/// attempt_auto_adoption must fail immediately when fewer than both flags are set. +#[sqlx::test] +async fn test_attempt_auto_adoption_requires_both_flags( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = defguard_common::db::setup_pool(options).await; + initialize_current_settings(&pool) + .await + .expect("Failed to initialize settings"); + + // only adopt_edge + assert!( + attempt_auto_adoption( + &pool, + &config_with_flags(Some("edge.example.com:8080"), None) + ) + .await + .is_err() + ); + + // only adopt_gateway + assert!( + attempt_auto_adoption(&pool, &config_with_flags(None, Some("gw.example.com:8080"))) + .await + .is_err() + ); + + // neither flag + assert!( + attempt_auto_adoption(&pool, &config_with_flags(None, None)) + .await + .is_err() + ); +} diff --git a/crates/defguard_setup/tests/session_info.rs b/crates/defguard_setup/tests/session_info.rs index b90ad54d80..f27d8d034a 100644 --- a/crates/defguard_setup/tests/session_info.rs +++ b/crates/defguard_setup/tests/session_info.rs @@ -84,7 +84,7 @@ async fn test_session_info_auto_adoption_wizard(_: PgPoolOptions, options: PgCon initialize_current_settings(&pool) .await .expect("Failed to initialize settings"); - // has_auto_adopt_flags = true: AutoAdoption wizard + // has_auto_adopt_flags = true (both flags provided): AutoAdoption wizard Wizard::init(&pool, true) .await .expect("Failed to initialize wizard"); diff --git a/crates/defguard_setup/tests/wizard_init.rs b/crates/defguard_setup/tests/wizard_init.rs index 8881acd769..7084c604ae 100644 --- a/crates/defguard_setup/tests/wizard_init.rs +++ b/crates/defguard_setup/tests/wizard_init.rs @@ -46,7 +46,7 @@ async fn test_wizard_init_auto_adopt_flags(_: PgPoolOptions, options: PgConnectO .await .expect("Failed to initialize settings"); - // Fresh DB + auto-adopt flags: AutoAdoption wizard + // Fresh DB + both auto-adopt flags provided: AutoAdoption wizard let wizard = Wizard::init(&pool, true) .await .expect("Failed to init wizard"); @@ -136,7 +136,7 @@ async fn test_wizard_init_idempotent(_: PgPoolOptions, options: PgConnectOptions assert_eq!( third.active_wizard, ActiveWizard::Initial, - "Already-active wizard should not be switched by flags" + "Already-active wizard should not be switched even when both adopt flags are set" ); // Simulate completion: mark wizard as done