Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/defguard_core/src/db/models/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ impl Device<Id> {
/// - `transaction`: Active PostgreSQL connection to check and insert assignments.
/// - `network`: The `WireguardNetwork<Id>` whose subnets will be assigned.
/// - `reserved_ips`: Optional slice of IPs that must not be assigned, even if otherwise free.
/// - `current_ips`: Optional slice of IPs already assigned to the device - won't be reassigned if they are still valid.
///
/// # Returns
///
Expand All @@ -821,6 +822,7 @@ impl Device<Id> {
transaction: &mut PgConnection,
network: &WireguardNetwork<Id>,
reserved_ips: Option<&[IpAddr]>,
current_ips: Option<&[IpAddr]>,
) -> Result<WireguardNetworkDevice, ModelError> {
debug!(
"Assiging IP addresses for device: {} in network {}",
Expand All @@ -835,6 +837,17 @@ impl Device<Id> {
"Assigning address to device {} in network {} {address}",
self.name, network.name,
);
// Don't reassign addresses for networks that didn't change
if let Some(ip) =
current_ips.and_then(|ips| ips.iter().find(|ip| address.contains(**ip)))
{
debug!(
"Skipping reassignment of already assigned valid IP {ip} for device {} in network {} with addresses {:?}",
self.name, network.name, network.address
);
ips.push(*ip);
continue;
}
let mut picked = None;
for ip in address {
if network
Expand Down
24 changes: 15 additions & 9 deletions crates/defguard_core/src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ impl WireguardNetwork<Id> {
let devices = self.get_allowed_devices(&mut *transaction).await?;
for device in devices {
device
.assign_next_network_ip(&mut *transaction, self, None)
.assign_next_network_ip(&mut *transaction, self, None, None)
.await?;
}
Ok(())
Expand All @@ -454,7 +454,7 @@ impl WireguardNetwork<Id> {
let allowed_device_ids: Vec<i64> = allowed_devices.iter().map(|dev| dev.id).collect();
if allowed_device_ids.contains(&device.id) {
let wireguard_network_device = device
.assign_next_network_ip(&mut *transaction, self, reserved_ips)
.assign_next_network_ip(&mut *transaction, self, reserved_ips, None)
.await?;
Ok(wireguard_network_device)
} else {
Expand All @@ -475,27 +475,33 @@ impl WireguardNetwork<Id> {
self.address.iter().find(|net| net.contains(addr)).copied()
}

/// Works out which devices need to be added, removed, or readdressed
/// based on the list of currently configured devices and the list of devices which should be allowed
/// Works out which devices need to be added, removed, or readdressed based on the list
/// of currently configured devices and the list of devices which should be allowed.
async fn process_device_access_changes(
&self,
transaction: &mut PgConnection,
mut allowed_devices: HashMap<Id, Device<Id>>,
currently_configured_devices: Vec<WireguardNetworkDevice>,
reserved_ips: Option<&[IpAddr]>,
) -> Result<Vec<GatewayEvent>, WireguardNetworkError> {
// Loop through current device configurations; remove no longer allowed, readdress when necessary; remove processed entry from all devices list
// initial list should now contain only devices to be added
// Loop through current device configurations; remove no longer allowed, readdress
// when necessary; remove processed entry from all devices list initial list should
// now contain only devices to be added.
let mut events: Vec<GatewayEvent> = Vec::new();
for device_network_config in currently_configured_devices {
// Device is allowed and an IP was already assigned
if let Some(device) = allowed_devices.remove(&device_network_config.device_id) {
// Network address changed and IP addresses need to be updated
// Network address has changed and IP addresses need to be updated
if !self.contains_all(&device_network_config.wireguard_ips)
|| self.address.len() != device_network_config.wireguard_ips.len()
{
let wireguard_network_device = device
.assign_next_network_ip(&mut *transaction, self, reserved_ips)
.assign_next_network_ip(
&mut *transaction,
self,
reserved_ips,
Some(&device_network_config.wireguard_ips),
)
.await?;
events.push(GatewayEvent::DeviceModified(DeviceInfo {
device,
Expand Down Expand Up @@ -537,7 +543,7 @@ impl WireguardNetwork<Id> {
// Add configs for new allowed devices
for device in allowed_devices.into_values() {
let wireguard_network_device = device
.assign_next_network_ip(&mut *transaction, self, reserved_ips)
.assign_next_network_ip(&mut *transaction, self, reserved_ips, None)
.await?;
events.push(GatewayEvent::DeviceCreated(DeviceInfo {
device,
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ pub async fn init_dev_env(config: &DefGuardConfig) {
.await
.expect("Could not save device");
device
.assign_next_network_ip(&mut transaction, &network, None)
.assign_next_network_ip(&mut transaction, &network, None, None)
.await
.expect("Could not assign IP to device");
}
Expand Down
116 changes: 116 additions & 0 deletions crates/defguard_core/tests/integration/wireguard.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

use defguard_core::{
db::{
models::{
Expand Down Expand Up @@ -275,6 +277,120 @@ async fn test_device(_: PgPoolOptions, options: PgConnectOptions) {
assert!(devices.is_empty());
}

#[sqlx::test]
async fn test_network_address_reassignment(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, client_state) = make_test_client(pool).await;

let auth = Auth::new("admin", "pass123");
let response = &client.post("/api/v1/auth").json(&auth).send().await;
assert_eq!(response.status(), StatusCode::OK);

// create network
let network = json!({
"name": "network",
"address": "10.1.1.1/24",
"port": 55555,
"endpoint": "192.168.4.14",
"allowed_ips": "10.1.1.0/24",
"dns": "1.1.1.1",
"allowed_groups": [],
"mfa_enabled": false,
"keepalive_interval": 25,
"peer_disconnect_threshold": 180,
"acl_enabled": false,
"acl_default_allow": false
});
let response = client.post("/api/v1/network").json(&network).send().await;
assert_eq!(response.status(), StatusCode::CREATED);

// network details
let response = client.get("/api/v1/network/1").send().await;
assert_eq!(response.status(), StatusCode::OK);
let network_from_details: WireguardNetwork<Id> = response.json().await;

// create devices
let device = json!({
"name": "device1",
"wireguard_pubkey": "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=",
});
let response = client
.post("/api/v1/device/admin")
.json(&device)
.send()
.await;
assert_eq!(response.status(), StatusCode::CREATED);
let device = json!({
"name": "device2",
"wireguard_pubkey": "ZqDlG4LQZRO9v57Sd27AHdtTLxegbMp5oVThjYrg21I=",
});
let response = client
.post("/api/v1/device/admin")
.json(&device)
.send()
.await;
assert_eq!(response.status(), StatusCode::CREATED);

// ensure IPs were assigned for new devices
let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 1)
.await
.unwrap()
.unwrap();
assert_eq!(
network_devices[0].wireguard_ips,
vec![IpAddr::V4(Ipv4Addr::new(10, 1, 1, 2))],
);
let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 2)
.await
.unwrap()
.unwrap();
assert_eq!(
network_devices[0].wireguard_ips,
vec![IpAddr::V4(Ipv4Addr::new(10, 1, 1, 3))],
);

// delete the first device
let response = client.delete("/api/v1/device/1").json(&device).send().await;
assert_eq!(response.status(), StatusCode::OK);

// modify network addresses
let network = json!({
"id": network_from_details.id,
"name": "network",
"address": "10.1.1.1/24,fc00::1/112",
"port": 55555,
"endpoint": "192.168.4.14",
"allowed_ips": "10.1.1.0/24",
"dns": "1.1.1.1",
"allowed_groups": [],
"mfa_enabled": false,
"keepalive_interval": 25,
"peer_disconnect_threshold": 180,
"acl_enabled": false,
"acl_default_allow": false
});
let response = client
.put(format!("/api/v1/network/{}", network_from_details.id))
.json(&network)
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);

// ensure IPv4 address wasn't reassigned
let network_devices = WireguardNetworkDevice::find_by_device(&client_state.pool, 2)
.await
.unwrap()
.unwrap();
assert_eq!(
network_devices[0].wireguard_ips,
vec![
IpAddr::V4(Ipv4Addr::new(10, 1, 1, 3)),
IpAddr::V6(Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 2)),
],
);
}

#[sqlx::test]
async fn test_device_permissions(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;
Expand Down