Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/defguard/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async fn main() -> Result<(), anyhow::Error> {
}

let (proxy_control_tx, proxy_control_rx) = channel::<ProxyControlMessage>(100);
let proxy_secret_key = settings.secret_key_required()?.to_string();
let proxy_secret_key = settings.secret_key_required()?;
let proxy_manager = ProxyManager::new(
pool.clone(),
ProxyTxSet::new(gateway_tx.clone(), bidi_event_tx.clone()),
Expand Down
155 changes: 112 additions & 43 deletions crates/defguard_common/src/db/models/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use sqlx::{
query_scalar,
};
use thiserror::Error;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use utoipa::ToSchema;

use crate::{
Expand Down Expand Up @@ -534,8 +534,8 @@ impl WireguardNetworkDevice {

#[derive(Debug, Error)]
pub enum DeviceError {
#[error("Device {0} pubkey is the same as gateway pubkey for network {1}")]
PubkeyConflict(Device<Id>, String),
#[error("Device pubkey {0} is the same as gateway pubkey")]
PubkeyConflict(String),
#[error("Database error")]
DatabaseError(#[from] sqlx::Error),
#[error(transparent)]
Expand Down Expand Up @@ -753,13 +753,13 @@ impl Device<Id> {
Ok((device_network_info, device_config))
}

// Add device to all existing networks
/// Add device to all existing networks.
pub async fn add_to_all_networks(
&self,
transaction: &mut PgConnection,
conn: &mut PgConnection,
) -> Result<(Vec<DeviceNetworkInfo>, Vec<DeviceConfig>), DeviceError> {
info!("Adding device {} to all existing networks", self.name);
let networks = WireguardNetwork::all(&mut *transaction).await?;
let networks = WireguardNetwork::all(&mut *conn).await?;

let mut configs = Vec::new();
let mut network_info = Vec::new();
Expand All @@ -770,49 +770,52 @@ impl Device<Id> {
);
// check for pubkey conflicts with networks
if network.pubkey == self.wireguard_pubkey {
return Err(DeviceError::PubkeyConflict(self.clone(), network.name));
return Err(DeviceError::PubkeyConflict(self.wireguard_pubkey.clone()));
}
if WireguardNetworkDevice::find(&mut *transaction, self.id, network.id)
if WireguardNetworkDevice::find(&mut *conn, self.id, network.id)
.await?
.is_some()
{
debug!("Device {self} already has an IP within network {network}. Skipping...",);
debug!("Device {self} already has an IP within network {network}. Skipping...");
continue;
}

if let Ok(wireguard_network_device) = network
.add_device_to_network(&mut *transaction, self, None)
.await
{
debug!(
"Assigned IPs {} for device {} (user {}) in network {network}",
wireguard_network_device.wireguard_ips.as_csv(),
self.name,
self.user_id
);
let device_network_info = DeviceNetworkInfo {
network_id: network.id,
device_wireguard_ips: wireguard_network_device.wireguard_ips.clone(),
preshared_key: wireguard_network_device.preshared_key.clone(),
is_authorized: wireguard_network_device.is_authorized,
};
network_info.push(device_network_info);

let config = Self::create_config(&network, &wireguard_network_device);
configs.push(DeviceConfig {
network_id: network.id,
network_name: network.name,
config,
endpoint: format!("{}:{}", network.endpoint, network.port),
address: wireguard_network_device.wireguard_ips,
allowed_ips: network.allowed_ips,
pubkey: network.pubkey,
dns: network.dns,
keepalive_interval: network.keepalive_interval,
location_mfa_mode: network.location_mfa_mode.clone(),
service_location_mode: network.service_location_mode.clone(),
});
}
// FIXME: don't ignore the error.
let Ok(wireguard_network_device) =
network.add_device_to_network(&mut *conn, self, None).await
else {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Log failure

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

warn!("Failed to add device {self} to network {network}");
continue;
};

debug!(
"Assigned IPs {} for device {} (user {}) in network {network}",
wireguard_network_device.wireguard_ips.as_csv(),
self.name,
self.user_id
);
let device_network_info = DeviceNetworkInfo {
network_id: network.id,
device_wireguard_ips: wireguard_network_device.wireguard_ips.clone(),
preshared_key: wireguard_network_device.preshared_key.clone(),
is_authorized: wireguard_network_device.is_authorized,
};
network_info.push(device_network_info);

let config = Self::create_config(&network, &wireguard_network_device);
configs.push(DeviceConfig {
network_id: network.id,
network_name: network.name,
config,
endpoint: format!("{}:{}", network.endpoint, network.port),
address: wireguard_network_device.wireguard_ips,
allowed_ips: network.allowed_ips,
pubkey: network.pubkey,
dns: network.dns,
keepalive_interval: network.keepalive_interval,
location_mfa_mode: network.location_mfa_mode.clone(),
service_location_mode: network.service_location_mode.clone(),
});
}
Ok((network_info, configs))
}
Expand Down Expand Up @@ -1052,7 +1055,7 @@ impl Device<Id> {

#[cfg(test)]
mod test {
use std::str::FromStr;
use std::{net::Ipv4Addr, str::FromStr};

use claims::{assert_err, assert_ok};
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
Expand Down Expand Up @@ -1526,4 +1529,70 @@ mod test {
assert_eq!(devices.len(), 1);
assert_eq!(devices[0].device_id, device.id);
}

// Mimic what add_device handler does.
#[sqlx::test]
fn test_saturated_network(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let user = User::new("tester", None, "Tester", "Test", "test@test.pl", None)
.save(&pool)
.await
.unwrap();

let network = WireguardNetwork::<NoId> {
address: vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(192, 168, 42, 4)), 29).unwrap()],
..Default::default()
}
.save(&pool)
.await
.unwrap();

let mut conn = pool.begin().await.unwrap();

for (name, pubkey) in [
("device1", "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU="),
("device2", "AJwxGkzvVVn5Q1xjpCDFo5RJSU9KOPHeoEixYaj+20M="),
("device3", "OLQNaEH3FxW0hiodaChEHoETzd+7UzcqIbsLs+X8rD0="),
("device4", "mgVXE8WcfStoD8mRatHcX5aaQ0DlcpjvPXibHEOr9y8="),
("device5", "hNuapt7lOxF93KUqZGUY00oKJxH8LYwwsUVB1uUa0y4="),
] {
let device = Device::new(
name.to_string(),
pubkey.to_string(),
user.id,
DeviceType::User,
None,
true,
)
.save(&mut *conn)
.await
.unwrap();
let (_, _) = device.add_to_all_networks(&mut conn).await.unwrap();
}

// This device won't fit in the address space.
let _device = Device::new(
"device6".to_string(),
"fF9K0tgatZTEJRvzpNUswr0h8HqCIi+v39B45+QZZzE=".to_string(),
user.id,
DeviceType::User,
None,
true,
)
.save(&mut *conn)
.await
.unwrap();
// FIXME: uncomment when `add_to_all_networks` is fixed.
// assert!(device.add_to_all_networks(&mut conn).await.is_err());

conn.commit().await.unwrap();

let devices = Device::all(&pool).await.unwrap();
assert_eq!(6, devices.len(), "{devices:#?}");
let network_devices = WireguardNetworkDevice::all_for_network(&pool, network.id)
.await
.unwrap();
assert_eq!(5, network_devices.len(), "{network_devices:#?}");
}
}
55 changes: 15 additions & 40 deletions crates/defguard_common/src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ use super::{
user::User,
};
use crate::{
auth::claims::{Claims, ClaimsType},
db::{
Id, NoId,
models::{
Expand Down Expand Up @@ -191,8 +190,6 @@ pub enum WireguardNetworkError {
DeviceNotAllowed(String),
#[error("Device error")]
DeviceError(#[from] DeviceError),
#[error(transparent)]
TokenError(#[from] jsonwebtoken::errors::Error),
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -269,10 +266,7 @@ impl WireguardNetwork {
}

impl WireguardNetwork<Id> {
pub async fn find_by_name<'e, E>(
executor: E,
name: &str,
) -> Result<Option<Vec<Self>>, WireguardNetworkError>
pub async fn find_by_name<'e, E>(executor: E, name: &str) -> sqlx::Result<Option<Vec<Self>>>
where
E: PgExecutor<'e>,
{
Expand All @@ -295,15 +289,14 @@ impl WireguardNetwork<Id> {
Ok(Some(networks))
}

#[allow(clippy::result_large_err)]
/// Check if given number of devices can fit in networks used by this location.
/// Note: `device_count` should include network and broadcast addresses.
pub fn validate_network_size(&self, device_count: usize) -> Result<(), WireguardNetworkError> {
debug!("Checking if {device_count} devices can fit in networks used by location {self}");
// if given location uses multiple subnets validate devices can fit them all
// If a given location uses multiple subnets, validate devices can fit them all.
for subnet in &self.address {
debug!("Checking if {device_count} devices can fit in network {subnet}");
let network_size = subnet.size();
// include address, network, and broadcast in the calculation
match network_size {
match subnet.size() {
NetworkSize::V4(size) => {
if device_count as u32 > size {
return Err(WireguardNetworkError::NetworkTooSmall);
Expand Down Expand Up @@ -452,23 +445,23 @@ impl WireguardNetwork<Id> {
/// Generate network IPs for a device if it's allowed in network
pub(crate) async fn add_device_to_network(
&self,
transaction: &mut PgConnection,
conn: &mut PgConnection,
device: &Device<Id>,
reserved_ips: Option<&[IpAddr]>,
) -> Result<WireguardNetworkDevice, WireguardNetworkError> {
info!("Assigning IP in network {self} for {device}");
let allowed_devices = self.get_allowed_devices(&mut *transaction).await?;
let allowed_device_ids: Vec<i64> = allowed_devices.iter().map(|dev| dev.id).collect();
let used_ips = self.all_used_ips_for_network(&mut *transaction).await?;
let allowed_devices = self.get_allowed_devices(&mut *conn).await?;
let allowed_device_ids = allowed_devices.iter().map(|dev| dev.id).collect::<Vec<_>>();
let used_ips = self.all_used_ips_for_network(&mut *conn).await?;

if allowed_device_ids.contains(&device.id) {
let wireguard_network_device = device
.assign_next_network_ip(&mut *transaction, self, &used_ips, reserved_ips, None)
.assign_next_network_ip(&mut *conn, self, &used_ips, reserved_ips, None)
.await?;
Ok(wireguard_network_device)
} else {
info!("Device {device} not allowed in network {self}");
Err(WireguardNetworkError::DeviceNotAllowed(format!("{device}")))
Err(WireguardNetworkError::DeviceNotAllowed(device.to_string()))
}
}

Expand All @@ -487,7 +480,7 @@ impl WireguardNetwork<Id> {
}

/// Update `connected_at` to the current time and save it to the database.
pub async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
pub async fn touch_connected<'e, E>(&mut self, executor: E) -> sqlx::Result<()>
where
E: PgExecutor<'e>,
{
Expand All @@ -510,7 +503,7 @@ impl WireguardNetwork<Id> {
devices: &[Device<Id>],
from: &NaiveDateTime,
aggregation: &DateTimeAggregation,
) -> Result<Vec<WireguardDeviceStatsRow>, sqlx::Error> {
) -> sqlx::Result<Vec<WireguardDeviceStatsRow>> {
if devices.is_empty() {
return Ok(Vec::new());
}
Expand Down Expand Up @@ -1169,9 +1162,7 @@ impl WireguardNetwork<Id> {
}

// fetch all locations using external MFA
pub async fn all_using_external_mfa<'e, E>(
executor: E,
) -> Result<Vec<Self>, WireguardNetworkError>
pub async fn all_using_external_mfa<'e, E>(executor: E) -> sqlx::Result<Vec<Self>>
where
E: PgExecutor<'e>,
{
Expand All @@ -1189,24 +1180,8 @@ impl WireguardNetwork<Id> {
Ok(locations)
}

/// Generates auth token for a VPN gateway
#[allow(clippy::result_large_err)]
pub fn generate_gateway_token(&self) -> Result<String, WireguardNetworkError> {
let location_id = self.id;

let token = Claims::new(
ClaimsType::Gateway,
format!("DEFGUARD-NETWORK-{location_id}"),
location_id.to_string(),
u32::MAX.into(),
)
.to_jwt()?;

Ok(token)
}

/// Fetch a list of all allowed groups for a given network from DB
pub async fn fetch_allowed_groups<'e, E>(&self, executor: E) -> Result<Vec<String>, ModelError>
pub async fn fetch_allowed_groups<'e, E>(&self, executor: E) -> sqlx::Result<Vec<String>>
where
E: PgExecutor<'e>,
{
Expand Down
Loading
Loading