From 2363c85778defa50441c2cc5a6bc483ba9730f6d Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Wed, 21 May 2025 15:52:51 +0200 Subject: [PATCH] change vpn init command --- src/config.rs | 2 + src/db/models/wireguard.rs | 2 + src/lib.rs | 95 ++++++++++++++++++++++++++++---------- 3 files changed, 75 insertions(+), 24 deletions(-) diff --git a/src/config.rs b/src/config.rs index 19fd9ecb52..f5bc6ef94b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -188,6 +188,8 @@ pub struct InitVpnLocationArgs { pub dns: Option, #[arg(long)] pub allowed_ips: Vec, + #[arg(long)] + pub id: Option, } impl DefGuardConfig { diff --git a/src/db/models/wireguard.rs b/src/db/models/wireguard.rs index df1f81f35c..6b54ba8ec5 100644 --- a/src/db/models/wireguard.rs +++ b/src/db/models/wireguard.rs @@ -597,7 +597,9 @@ impl WireguardNetwork { } /// Refresh network IPs for all relevant devices + /// /// If the list of allowed devices has changed add/remove devices accordingly + /// /// If the network address has changed readdress existing devices pub(crate) async fn sync_allowed_devices( &self, diff --git a/src/lib.rs b/src/lib.rs index 10bafcb4f5..9263c39c18 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -739,36 +739,83 @@ pub async fn init_dev_env(config: &DefGuardConfig) { /// Create a new VPN location. /// Meant to be used to automate setting up a new defguard instance. -/// Does not handle assigning device IPs, since no device should exist at this point. +/// If the network ID has been specified, it will be assumed that the user wants to update the existing network or create a new one with a predefined ID. +/// This is mainly used for deployment purposes where the network ID must be known beforehand. +/// +/// If there is no ID specified, the function will only create the network if no other network exists. +/// In other words, multiple networks can be created, but only if the ID is predefined for each network. pub async fn init_vpn_location( pool: &PgPool, args: &InitVpnLocationArgs, ) -> Result { - // check if a VPN location exists already - let networks = WireguardNetwork::all(pool).await?; - if !networks.is_empty() { - return Err(anyhow!( - "Failed to initialize first VPN location. A location already exists." - )); + // The ID is predefined + let network = if let Some(location_id) = args.id { + let mut transaction = pool.begin().await?; + // If the network already exists, update it, assuming that's the user's intent. + let network = if let Some(mut network) = + WireguardNetwork::find_by_id(&mut *transaction, location_id).await? + { + network.name = args.name.clone(); + network.address = vec![args.address]; + network.port = args.port; + network.endpoint = args.endpoint.clone(); + network.dns = args.dns.clone(); + network.allowed_ips = args.allowed_ips.clone(); + network.save(&mut *transaction).await?; + network.sync_allowed_devices(&mut *transaction, None); + network + } + // Otherwise create it with the predefined ID + else { + let mut network = WireguardNetwork::new( + args.name.clone(), + vec![args.address], + args.port, + args.endpoint.clone(), + args.dns.clone(), + args.allowed_ips.clone(), + false, + DEFAULT_KEEPALIVE_INTERVAL, + DEFAULT_DISCONNECT_THRESHOLD, + false, + false, + )? + .with_id(location_id); + network.save(&mut *transaction).await?; + network.add_all_allowed_devices(&mut transaction).await?; + network + }; + transaction.commit().await?; + network + } + // No predefined ID, add the network if no other networks are present + else { + // check if a VPN location exists already + let networks = WireguardNetwork::all(pool).await?; + if !networks.is_empty() { + return Err(anyhow!( + "Failed to initialize first VPN location. A location already exists." + )); + }; + + // create a new network + WireguardNetwork::new( + args.name.clone(), + vec![args.address], + args.port, + args.endpoint.clone(), + args.dns.clone(), + args.allowed_ips.clone(), + false, + DEFAULT_KEEPALIVE_INTERVAL, + DEFAULT_DISCONNECT_THRESHOLD, + false, + false, + )? + .save(pool) + .await? }; - // create a new network - let network = WireguardNetwork::new( - args.name.clone(), - vec![args.address], - args.port, - args.endpoint.clone(), - args.dns.clone(), - args.allowed_ips.clone(), - false, - DEFAULT_KEEPALIVE_INTERVAL, - DEFAULT_DISCONNECT_THRESHOLD, - false, - false, - )? - .save(pool) - .await?; - // generate gateway token let token = Claims::new( ClaimsType::Gateway,