diff --git a/Cargo.lock b/Cargo.lock index ecb8387668..3f840a5c20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1098,6 +1098,7 @@ dependencies = [ "defguard_event_logger", "defguard_event_router", "defguard_mail", + "defguard_session_manager", "defguard_version", "dotenvy", "secrecy", @@ -1110,8 +1111,11 @@ name = "defguard_common" version = "1.6.0" dependencies = [ "anyhow", + "argon2", + "base32", "base64 0.22.1", "chrono", + "claims", "clap", "ed25519-dalek", "humantime", @@ -1125,14 +1129,18 @@ dependencies = [ "rsa", "secrecy", "serde", + "serde_cbor_2 0.12.0-dev", "sqlx", "struct-patch", "thiserror 2.0.17", "tonic", + "totp-lite", "tracing", "utoipa", "uuid", "vergen-git2", + "webauthn-rs", + "x25519-dalek", ] [[package]] @@ -1264,6 +1272,7 @@ dependencies = [ name = "defguard_proto" version = "0.0.0" dependencies = [ + "defguard_common", "prost", "serde", "tonic", @@ -1271,6 +1280,15 @@ dependencies = [ "tonic-prost-build", ] +[[package]] +name = "defguard_session_manager" +version = "0.0.0" +dependencies = [ + "defguard_common", + "sqlx", + "tokio", +] + [[package]] name = "defguard_version" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 673c673ccc..29c703a168 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ defguard_event_logger = { path = "./crates/defguard_event_logger", version = "0. defguard_event_router = { path = "./crates/defguard_event_router", version = "0.0.0" } defguard_mail = { path = "./crates/defguard_mail", version = "0.0.0" } defguard_proto = { path = "./crates/defguard_proto", version = "0.0.0" } +defguard_session_manager = { path = "./crates/defguard_session_manager", version = "0.0.0" } defguard_version = { path = "./crates/defguard_version", version = "0.0.0" } defguard_web_ui = { path = "./crates/defguard_web_ui", version = "0.0.0" } model_derive = { path = "./crates/model_derive", version = "0.0.0" } diff --git a/crates/defguard/Cargo.toml b/crates/defguard/Cargo.toml index 8399bb0ae5..0dc170ff56 100644 --- a/crates/defguard/Cargo.toml +++ b/crates/defguard/Cargo.toml @@ -14,6 +14,7 @@ defguard_core = { workspace = true } defguard_event_router = { workspace = true } defguard_event_logger = { workspace = true } defguard_mail = { workspace = true } +defguard_session_manager = { workspace = true } defguard_version = { workspace = true } # external dependencies diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 3c7576a2ee..087ac62bdf 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -9,12 +9,17 @@ use defguard_common::{ config::{Command, DefGuardConfig, SERVER_CONFIG}, db::{ init_db, - models::{Settings, settings::initialize_current_settings}, + models::{ + Settings, + User, + settings::initialize_current_settings, + // wireguard_peer_stats::WireguardPeerStats, + }, }, }; use defguard_core::{ auth::failed_login::FailedLoginMap, - db::{AppEvent, GatewayEvent, User}, + db::AppEvent, enterprise::{ activity_log_stream::activity_log_stream_manager::run_activity_log_stream_manager, license::{License, run_periodic_license_check, set_cached_license}, @@ -23,7 +28,7 @@ use defguard_core::{ events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, grpc::{ WorkerState, - gateway::{client_state::ClientMap, map::GatewayMap}, + gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, run_grpc_bidi_stream, run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, @@ -35,6 +40,7 @@ use defguard_core::{ use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; use defguard_event_router::{RouterReceiverSet, run_event_router}; use defguard_mail::{Mail, run_mail_handler}; +// use defguard_session_manager::run_session_manager; use secrecy::ExposeSecret; use tokio::sync::{broadcast, mpsc::unbounded_channel}; @@ -106,6 +112,7 @@ async fn main() -> Result<(), anyhow::Error> { let (wireguard_tx, _wireguard_rx) = broadcast::channel::(256); let (mail_tx, mail_rx) = unbounded_channel::(); let (event_logger_tx, event_logger_rx) = unbounded_channel::(); + // let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(webhook_tx.clone()))); let gateway_state = Arc::new(Mutex::new(GatewayMap::new())); @@ -220,6 +227,10 @@ async fn main() -> Result<(), anyhow::Error> { activity_log_stream_reload_notify.clone(), activity_log_messages_rx ) => error!("Activity log stream manager returned early: {res:?}"), + // res = run_session_manager( + // pool.clone(), + // peer_stats_rx + // ) => error!("VPN client session manager returned early: {res:?}"), } Ok(()) diff --git a/crates/defguard_common/Cargo.toml b/crates/defguard_common/Cargo.toml index 1855788cb0..6ddfceb7f7 100644 --- a/crates/defguard_common/Cargo.toml +++ b/crates/defguard_common/Cargo.toml @@ -11,8 +11,11 @@ rust-version.workspace = true model_derive.workspace = true anyhow.workspace = true +argon2.workspace = true +base32.workspace = true base64.workspace = true chrono.workspace = true +claims.workspace = true clap.workspace = true ed25519-dalek = { version = "2.2", features = ["rand_core"] } humantime.workspace = true @@ -24,13 +27,17 @@ reqwest.workspace = true rsa.workspace = true secrecy.workspace = true serde.workspace = true +serde_cbor.workspace = true sqlx.workspace = true struct-patch.workspace = true thiserror.workspace = true tonic.workspace = true +totp-lite.workspace = true tracing.workspace = true utoipa.workspace = true uuid.workspace = true +webauthn-rs.workspace = true +x25519-dalek.workspace = true [dev-dependencies] matches.workspace = true diff --git a/crates/defguard_core/src/db/models/device.rs b/crates/defguard_common/src/db/models/device.rs similarity index 93% rename from crates/defguard_core/src/db/models/device.rs rename to crates/defguard_common/src/db/models/device.rs index f359eb6383..abd3b17d8b 100644 --- a/crates/defguard_core/src/db/models/device.rs +++ b/crates/defguard_common/src/db/models/device.rs @@ -1,51 +1,52 @@ use std::{fmt, net::IpAddr}; -use base64::{Engine, prelude::BASE64_STANDARD}; -#[cfg(test)] -use chrono::NaiveDate; -use chrono::{NaiveDateTime, Timelike, Utc}; -use defguard_common::{ +use crate::{ + KEY_LENGTH, csv::AsCsv, - db::{Id, NoId, models::ModelError}, + db::{ + Id, NoId, + models::{ + ModelError, WireguardNetwork, + user::User, + wireguard::{ + LocationMfaMode, NetworkAddressError, ServiceLocationMode, WIREGUARD_MAX_HANDSHAKE, + }, + }, + }, }; +use base64::{Engine, prelude::BASE64_STANDARD}; +use chrono::{NaiveDate, NaiveDateTime, Timelike, Utc}; use ipnetwork::IpNetwork; use model_derive::Model; -#[cfg(test)] use rand::{ Rng, distributions::{Alphanumeric, DistString, Standard}, prelude::Distribution, }; +use serde::{Deserialize, Serialize}; use sqlx::{ Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, postgres::types::PgInterval, query, query_as, }; use thiserror::Error; +use tracing::{debug, error, info}; use utoipa::ToSchema; -use super::wireguard::{ - LocationMfaMode, NetworkAddressError, WIREGUARD_MAX_HANDSHAKE, WireguardNetwork, -}; -use crate::{ - KEY_LENGTH, - db::{User, models::wireguard::ServiceLocationMode}, -}; - #[derive(Serialize, ToSchema)] pub struct DeviceConfig { - pub(crate) network_id: Id, - pub(crate) network_name: String, - pub(crate) config: String, + pub network_id: Id, + pub network_name: String, + pub config: String, #[schema(value_type = String)] - pub(crate) address: Vec, - pub(crate) endpoint: String, + pub address: Vec, + pub endpoint: String, #[schema(value_type = String)] - pub(crate) allowed_ips: Vec, - pub(crate) pubkey: String, - pub(crate) dns: Option, - pub(crate) keepalive_interval: i32, - pub(crate) location_mfa_mode: LocationMfaMode, - pub(crate) service_location_mode: ServiceLocationMode, + pub allowed_ips: Vec, + pub pubkey: String, + pub dns: Option, + pub keepalive_interval: i32, + pub location_mfa_mode: LocationMfaMode, + pub service_location_mode: ServiceLocationMode, } // The type of a device: @@ -104,7 +105,6 @@ impl fmt::Display for Device { } } -#[cfg(test)] impl Distribution> for Standard { fn sample(&self, rng: &mut R) -> Device { Device { @@ -154,10 +154,7 @@ pub struct DeviceNetworkInfo { } impl DeviceInfo { - pub(crate) async fn from_device<'e, E>( - executor: E, - device: Device, - ) -> Result + pub async fn from_device<'e, E>(executor: E, device: Device) -> Result where E: PgExecutor<'e>, { @@ -288,7 +285,7 @@ pub struct ModifyDevice { impl WireguardNetworkDevice { #[must_use] - pub(crate) fn new(network_id: Id, device_id: Id, wireguard_ips: I) -> Self + pub fn new(network_id: Id, device_id: Id, wireguard_ips: I) -> Self where I: Into>, { @@ -310,7 +307,7 @@ impl WireguardNetworkDevice { .collect() } - pub(crate) async fn insert<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn insert<'e, E>(&self, executor: E) -> Result<(), SqlxError> where E: PgExecutor<'e>, { @@ -334,7 +331,7 @@ impl WireguardNetworkDevice { Ok(()) } - pub(crate) async fn update<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn update<'e, E>(&self, executor: E) -> Result<(), SqlxError> where E: PgExecutor<'e>, { @@ -355,7 +352,7 @@ impl WireguardNetworkDevice { Ok(()) } - pub(crate) async fn delete<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn delete<'e, E>(&self, executor: E) -> Result<(), SqlxError> where E: PgExecutor<'e>, { @@ -371,7 +368,7 @@ impl WireguardNetworkDevice { Ok(()) } - pub(crate) async fn find<'e, E>( + pub async fn find<'e, E>( executor: E, device_id: Id, network_id: Id, @@ -397,10 +394,7 @@ impl WireguardNetworkDevice { /// Get a first network the device was added to. Useful for network devices to /// make sure they always pull only one network's config. - pub(crate) async fn find_first<'e, E>( - executor: E, - device_id: Id, - ) -> Result, SqlxError> + pub async fn find_first<'e, E>(executor: E, device_id: Id) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -444,10 +438,7 @@ impl WireguardNetworkDevice { }) } - pub(crate) async fn all_for_network<'e, E>( - executor: E, - network_id: Id, - ) -> Result, SqlxError> + pub async fn all_for_network<'e, E>(executor: E, network_id: Id) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -469,7 +460,7 @@ impl WireguardNetworkDevice { /// Get all devices for a given network and user /// Note: doesn't return network devices added by the user /// as they are not considered to be bound to the user - pub(crate) async fn all_for_network_and_user<'e, E>( + pub async fn all_for_network_and_user<'e, E>( executor: E, network_id: Id, user_id: Id, @@ -494,10 +485,7 @@ impl WireguardNetworkDevice { Ok(res) } - pub(crate) async fn network<'e, E>( - &self, - executor: E, - ) -> Result, SqlxError> + pub async fn network<'e, E>(&self, executor: E) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -559,7 +547,7 @@ impl Device { } impl Device { - pub(crate) fn update_from(&mut self, other: ModifyDevice) { + pub fn update_from(&mut self, other: ModifyDevice) { self.name = other.name; self.wireguard_pubkey = other.wireguard_pubkey; self.description = other.description; @@ -567,7 +555,7 @@ impl Device { /// Create WireGuard config for device. #[must_use] - pub(crate) fn create_config( + pub fn create_config( network: &WireguardNetwork, wireguard_network_device: &WireguardNetworkDevice, ) -> String { @@ -606,7 +594,7 @@ impl Device { ) } - pub(crate) async fn find_by_ip<'e, E>( + pub async fn find_by_ip<'e, E>( executor: E, ip: IpAddr, network_id: Id, @@ -628,10 +616,7 @@ impl Device { .await } - pub(crate) async fn find_by_pubkey<'e, E>( - executor: E, - pubkey: &str, - ) -> Result, SqlxError> + pub async fn find_by_pubkey<'e, E>(executor: E, pubkey: &str) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -646,7 +631,7 @@ impl Device { .await } - pub(crate) async fn find_by_id_and_username<'e, E: sqlx::PgExecutor<'e>>( + pub async fn find_by_id_and_username<'e, E: sqlx::PgExecutor<'e>>( executor: E, id: Id, username: &str, @@ -664,10 +649,7 @@ impl Device { .await } - pub(crate) async fn all_for_username( - pool: &PgPool, - username: &str, - ) -> Result, SqlxError> { + pub async fn all_for_username(pool: &PgPool, username: &str) -> Result, SqlxError> { query_as!( Self, "SELECT device.id, name, wireguard_pubkey, user_id, created, description, \ @@ -680,7 +662,7 @@ impl Device { .await } - pub(crate) async fn get_network_configs( + pub async fn get_network_configs( &self, network: &WireguardNetwork, transaction: &mut PgConnection, @@ -714,7 +696,7 @@ impl Device { Ok((device_network_info, device_config)) } - pub(crate) async fn add_to_network( + pub async fn add_to_network( &self, network: &WireguardNetwork, ip: &[IpAddr], @@ -834,7 +816,7 @@ impl Device { /// /// - `Ok(WireguardNetworkDevice)`: A new relation linking this device to its assigned IPs across all subnets. /// - `Err(ModelError::CannotCreate)`: If any subnet lacks an available IP. - pub(crate) async fn assign_next_network_ip( + pub async fn assign_next_network_ip( &self, transaction: &mut PgConnection, network: &WireguardNetwork, @@ -972,7 +954,7 @@ impl Device { Err(format!("{pubkey} is not a valid pubkey")) } - pub(crate) async fn find_by_type<'e, E>( + pub async fn find_by_type<'e, E>( executor: E, device_type: DeviceType, ) -> Result, SqlxError> @@ -987,7 +969,7 @@ impl Device { ).fetch_all(executor).await } - pub(crate) async fn find_by_type_and_network<'e, E>( + pub async fn find_by_type_and_network<'e, E>( executor: E, device_type: DeviceType, network_id: Id, @@ -1006,7 +988,7 @@ impl Device { ).fetch_all(executor).await } - pub(crate) async fn get_owner<'e, E>(&self, executor: E) -> Result, SqlxError> + pub async fn get_owner<'e, E>(&self, executor: E) -> Result, SqlxError> where E: PgExecutor<'e>, { @@ -1027,11 +1009,11 @@ mod test { use std::str::FromStr; use claims::{assert_err, assert_ok}; - use defguard_common::db::setup_pool; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + use crate::db::setup_pool; + use super::*; - use crate::db::User; impl Device { /// Create new device and assign IP in a given network diff --git a/crates/defguard_core/src/db/models/group.rs b/crates/defguard_common/src/db/models/group.rs similarity index 58% rename from crates/defguard_core/src/db/models/group.rs rename to crates/defguard_common/src/db/models/group.rs index 017075934c..7415e2f897 100644 --- a/crates/defguard_core/src/db/models/group.rs +++ b/crates/defguard_common/src/db/models/group.rs @@ -1,12 +1,11 @@ use std::fmt; -use defguard_common::db::{Id, NoId, models::ModelError}; +use crate::db::{Id, NoId, models::user::User}; use model_derive::Model; -use sqlx::{Error as SqlxError, FromRow, PgConnection, PgExecutor, query, query_as, query_scalar}; +use serde::Serialize; +use sqlx::{Error as SqlxError, FromRow, PgExecutor, query, query_as, query_scalar}; use utoipa::ToSchema; -use crate::db::{User, WireguardNetwork}; - #[derive(Debug)] pub enum Permission { IsAdmin, @@ -22,12 +21,11 @@ impl fmt::Display for Permission { #[derive(Clone, Debug, Model, ToSchema, FromRow, PartialEq, Serialize)] pub struct Group { - pub(crate) id: I, + pub id: I, pub name: String, pub is_admin: bool, } -#[cfg(test)] impl Default for Group { fn default() -> Self { Self { @@ -124,7 +122,7 @@ impl Group { query_as(&query).fetch_all(executor).await } - pub(crate) async fn has_permission<'e, E>( + pub async fn has_permission<'e, E>( &self, executor: E, permission: Permission, @@ -140,7 +138,7 @@ impl Group { Ok(result.unwrap_or(false)) } - pub(crate) async fn set_permission<'e, E>( + pub async fn set_permission<'e, E>( &self, executor: E, permission: Permission, @@ -159,149 +157,12 @@ impl Group { } } -impl WireguardNetwork { - /// Fetch a list of all allowed groups for a given network from DB - pub async fn fetch_allowed_groups<'e, E>(&self, executor: E) -> Result, ModelError> - where - E: PgExecutor<'e>, - { - debug!("Fetching all allowed groups for network {self}"); - let groups = query_scalar!( - "SELECT name FROM wireguard_network_allowed_group wag \ - JOIN \"group\" g ON wag.group_id = g.id WHERE wag.network_id = $1", - self.id - ) - .fetch_all(executor) - .await?; - - Ok(groups) - } - - /// Return a list of allowed groups for a given network. - /// Admin group should always be included. - /// If no `allowed_groups` are specified for a network then all devices are allowed. - /// In this case `None` is returned to signify that there's no filtering. - /// This helper method is meant for use in all business logic gating - /// access to networks based on allowed groups. - pub async fn get_allowed_groups( - &self, - conn: &mut PgConnection, - ) -> Result>, ModelError> { - debug!("Returning a list of allowed groups for network {self}"); - let admin_groups = Group::find_by_permission(&mut *conn, Permission::IsAdmin).await?; - - // get allowed groups from DB - let mut groups = self.fetch_allowed_groups(&mut *conn).await?; - - // if no allowed groups are set then all groups are allowed - if groups.is_empty() { - return Ok(None); - } - - for group in admin_groups { - if !groups.iter().any(|name| name == &group.name) { - groups.push(group.name); - } - } - - Ok(Some(groups)) - } - - /// Set allowed groups, removing or adding groups as necessary. - pub async fn set_allowed_groups( - &self, - transaction: &mut PgConnection, - allowed_groups: Vec, - ) -> Result<(), ModelError> { - info!("Setting allowed groups for network {self} to: {allowed_groups:?}"); - if allowed_groups.is_empty() { - return self.clear_allowed_groups(transaction).await; - } - - // get list of current allowed groups - let mut current_groups = self.fetch_allowed_groups(&mut *transaction).await?; - - // add to group if not already a member - for group in &allowed_groups { - if !current_groups.contains(group) { - self.add_to_group(transaction, group).await?; - } - } - - // remove groups which are no longer present - current_groups.retain(|group| !allowed_groups.contains(group)); - if !current_groups.is_empty() { - self.remove_from_groups(transaction, current_groups).await?; - } - - Ok(()) - } - - pub async fn add_to_group( - &self, - transaction: &mut PgConnection, - group: &str, - ) -> Result<(), ModelError> { - info!("Adding allowed group {group} for network {self}"); - query!( - "INSERT INTO wireguard_network_allowed_group (network_id, group_id) \ - SELECT $1, g.id FROM \"group\" g WHERE g.name = $2", - self.id, - group - ) - .execute(transaction) - .await?; - Ok(()) - } - - pub async fn remove_from_groups( - &self, - transaction: &mut PgConnection, - groups: Vec, - ) -> Result<(), ModelError> { - info!("Removing allowed groups {groups:?} for network {self}"); - let result = query!( - "DELETE FROM wireguard_network_allowed_group \ - WHERE network_id = $1 AND group_id IN ( \ - SELECT id FROM \"group\" \ - WHERE name IN (SELECT * FROM UNNEST($2::text[])) \ - )", - self.id, - &groups - ) - .execute(transaction) - .await?; - info!( - "Removed {} allowed groups for network {self}", - result.rows_affected(), - ); - Ok(()) - } - - /// Remove all allowed groups for a given network - async fn clear_allowed_groups(&self, transaction: &mut PgConnection) -> Result<(), ModelError> { - info!("Removing all allowed groups for network {self}"); - let result = query!( - "DELETE FROM wireguard_network_allowed_group WHERE network_id=$1", - self.id - ) - .execute(transaction) - .await?; - info!( - "Removed {} allowed groups for network {self}", - result.rows_affected(), - ); - Ok(()) - } -} - #[cfg(test)] mod test { - use defguard_common::db::setup_pool; + use crate::db::setup_pool; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; - use crate::db::User; #[sqlx::test] async fn test_group(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_common/src/db/models/mfa_info.rs b/crates/defguard_common/src/db/models/mfa_info.rs new file mode 100644 index 0000000000..b7925ce97a --- /dev/null +++ b/crates/defguard_common/src/db/models/mfa_info.rs @@ -0,0 +1,58 @@ +use crate::db::{ + Id, + models::{MFAMethod, user::User}, +}; +use serde::{Deserialize, Serialize}; +use sqlx::{Error as SqlxError, PgPool, query_as}; + +#[derive(Deserialize, Serialize)] +pub struct MFAInfo { + pub mfa_method: MFAMethod, + totp_available: bool, + webauthn_available: bool, + email_available: bool, +} + +impl MFAInfo { + pub async fn for_user(pool: &PgPool, user: &User) -> Result, SqlxError> { + query_as!( + Self, + "SELECT mfa_method \"mfa_method: _\", totp_enabled totp_available, \ + email_mfa_enabled email_available, \ + (SELECT count(*) > 0 FROM webauthn WHERE user_id = $1) \"webauthn_available!\" \ + FROM \"user\" WHERE \"user\".id = $1", + user.id + ) + .fetch_optional(pool) + .await + } + + #[must_use] + pub fn mfa_available(&self) -> bool { + self.webauthn_available || self.totp_available || self.email_available + } + + #[must_use] + pub fn current_mfa_method(&self) -> &MFAMethod { + &self.mfa_method + } + + #[must_use] + pub fn list_available_methods(&self) -> Option> { + if !self.mfa_available() { + return None; + } + + let mut methods = Vec::new(); + if self.webauthn_available { + methods.push(MFAMethod::Webauthn); + } + if self.totp_available { + methods.push(MFAMethod::OneTimePassword); + } + if self.email_available { + methods.push(MFAMethod::Email); + } + Some(methods) + } +} diff --git a/crates/defguard_common/src/db/models/mod.rs b/crates/defguard_common/src/db/models/mod.rs index 0aa3a601bb..653e62d2ab 100644 --- a/crates/defguard_common/src/db/models/mod.rs +++ b/crates/defguard_common/src/db/models/mod.rs @@ -1,15 +1,35 @@ pub mod auth_code; pub mod authentication_key; pub mod biometric_auth; +pub mod device; pub mod device_login; pub mod error; +pub mod group; +pub mod mfa_info; +pub mod oauth2authorizedapp; +pub mod oauth2client; +pub mod oauth2token; +pub mod polling_token; +pub mod session; pub mod settings; pub mod user; +pub mod webauthn; +pub mod wireguard; +pub mod wireguard_peer_stats; +pub mod yubikey; pub use auth_code::AuthCode; pub use authentication_key::{AuthenticationKey, AuthenticationKeyType}; pub use biometric_auth::{BiometricAuth, BiometricChallenge}; +pub use device::{Device, DeviceConfig, DeviceError, DeviceNetworkInfo, DeviceType}; pub use device_login::DeviceLoginEvent; pub use error::ModelError; +pub use mfa_info::MFAInfo; +pub use oauth2authorizedapp::OAuth2AuthorizedApp; +pub use oauth2token::OAuth2Token; +pub use session::{Session, SessionState}; pub use settings::{Settings, SettingsEssentials}; -pub use user::MFAMethod; +pub use user::{MFAMethod, User}; +pub use webauthn::WebAuthn; +pub use wireguard::{WireguardNetwork, WireguardNetworkError}; +pub use yubikey::YubiKey; diff --git a/crates/defguard_core/src/db/models/oauth2authorizedapp.rs b/crates/defguard_common/src/db/models/oauth2authorizedapp.rs similarity index 96% rename from crates/defguard_core/src/db/models/oauth2authorizedapp.rs rename to crates/defguard_common/src/db/models/oauth2authorizedapp.rs index 0b7bb1af90..e6f5119abd 100644 --- a/crates/defguard_core/src/db/models/oauth2authorizedapp.rs +++ b/crates/defguard_common/src/db/models/oauth2authorizedapp.rs @@ -1,4 +1,4 @@ -use defguard_common::db::{Id, NoId}; +use crate::db::{Id, NoId}; use model_derive::Model; use sqlx::{Error as SqlxError, PgPool, query_as}; diff --git a/crates/defguard_core/src/db/models/oauth2client.rs b/crates/defguard_common/src/db/models/oauth2client.rs similarity index 83% rename from crates/defguard_core/src/db/models/oauth2client.rs rename to crates/defguard_common/src/db/models/oauth2client.rs index d7e2cbdb53..37256afc74 100644 --- a/crates/defguard_core/src/db/models/oauth2client.rs +++ b/crates/defguard_common/src/db/models/oauth2client.rs @@ -1,13 +1,11 @@ -use defguard_common::{ - db::{Id, NoId}, +use crate::{ + db::{Id, NoId, models::OAuth2Token}, random::gen_alphanumeric, }; use model_derive::Model; +use serde::{Deserialize, Serialize}; use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as}; -use super::NewOpenIDClient; -use crate::db::OAuth2Token; - #[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] pub struct OAuth2Client { pub id: I, @@ -22,7 +20,7 @@ pub struct OAuth2Client { pub enabled: bool, } -impl OAuth2Client { +impl OAuth2Client { #[must_use] pub fn new(redirect_uri: Vec, scope: Vec, name: String) -> Self { let client_id = gen_alphanumeric(16); @@ -37,26 +35,11 @@ impl OAuth2Client { enabled: true, } } - - #[must_use] - pub fn from_new(new: NewOpenIDClient) -> Self { - let client_id = gen_alphanumeric(16); - let client_secret = gen_alphanumeric(32); - Self { - id: NoId, - client_id, - client_secret, - redirect_uri: new.redirect_uri, - scope: new.scope, - name: new.name, - enabled: new.enabled, - } - } } impl OAuth2Client { /// Find client by 'client_id`. - pub(crate) async fn find_by_client_id<'e, E>( + pub async fn find_by_client_id<'e, E>( executor: E, client_id: &str, ) -> Result, SqlxError> @@ -73,7 +56,7 @@ impl OAuth2Client { .await } - pub(crate) async fn clear_authorizations<'e, E>(&self, executor: E) -> Result<(), SqlxError> + pub async fn clear_authorizations<'e, E>(&self, executor: E) -> Result<(), SqlxError> where E: PgExecutor<'e>, { @@ -87,7 +70,7 @@ impl OAuth2Client { } /// Find using `client_id` and `client_secret`; must be `enabled`. - pub(crate) async fn find_by_auth( + pub async fn find_by_auth( pool: &PgPool, client_id: &str, client_secret: &str, @@ -103,7 +86,7 @@ impl OAuth2Client { .await } - pub(crate) async fn find_by_token( + pub async fn find_by_token( pool: &PgPool, token: &OAuth2Token, ) -> Result, SqlxError> { @@ -122,7 +105,7 @@ impl OAuth2Client { } /// Checks if `url` matches client config (ignoring trailing slashes). - pub(crate) fn contains_redirect_url(&self, url: &str) -> bool { + pub fn contains_redirect_url(&self, url: &str) -> bool { let url_trimmed = url.trim_end_matches('/'); for redirect in &self.redirect_uri { diff --git a/crates/defguard_core/src/db/models/oauth2token.rs b/crates/defguard_common/src/db/models/oauth2token.rs similarity index 98% rename from crates/defguard_core/src/db/models/oauth2token.rs rename to crates/defguard_common/src/db/models/oauth2token.rs index abc400d711..468e83f64e 100644 --- a/crates/defguard_core/src/db/models/oauth2token.rs +++ b/crates/defguard_common/src/db/models/oauth2token.rs @@ -1,5 +1,5 @@ +use crate::{config::server_config, db::Id, random::gen_alphanumeric}; use chrono::{TimeDelta, Utc}; -use defguard_common::{config::server_config, db::Id, random::gen_alphanumeric}; use sqlx::{Error as SqlxError, PgPool, query, query_as}; pub struct OAuth2Token { diff --git a/crates/defguard_core/src/db/models/polling_token.rs b/crates/defguard_common/src/db/models/polling_token.rs similarity index 98% rename from crates/defguard_core/src/db/models/polling_token.rs rename to crates/defguard_common/src/db/models/polling_token.rs index b4d911936d..f60928c546 100644 --- a/crates/defguard_core/src/db/models/polling_token.rs +++ b/crates/defguard_common/src/db/models/polling_token.rs @@ -1,8 +1,8 @@ -use chrono::{NaiveDateTime, Utc}; -use defguard_common::{ +use crate::{ db::{Id, NoId}, random::gen_alphanumeric, }; +use chrono::{NaiveDateTime, Utc}; use model_derive::Model; use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as}; diff --git a/crates/defguard_core/src/db/models/session.rs b/crates/defguard_common/src/db/models/session.rs similarity index 94% rename from crates/defguard_core/src/db/models/session.rs rename to crates/defguard_common/src/db/models/session.rs index ee1cda00b0..e1859844e8 100644 --- a/crates/defguard_core/src/db/models/session.rs +++ b/crates/defguard_common/src/db/models/session.rs @@ -1,6 +1,5 @@ +use crate::{config::server_config, db::Id, random::gen_alphanumeric}; use chrono::{NaiveDateTime, TimeDelta, Utc}; -use defguard_common::{config::server_config, db::Id, random::gen_alphanumeric}; -use defguard_mail::templates::SessionContext; use sqlx::{Error as SqlxError, PgExecutor, PgPool, Type, query, query_as}; use webauthn_rs::prelude::{PasskeyAuthentication, PasskeyRegistration}; @@ -27,15 +26,6 @@ pub struct Session { pub device_info: Option, } -impl From for SessionContext { - fn from(value: Session) -> Self { - Self { - ip_address: value.ip_address, - device_info: value.device_info, - } - } -} - impl Session { #[must_use] pub fn new( diff --git a/crates/defguard_common/src/db/models/user.rs b/crates/defguard_common/src/db/models/user.rs index d632bb9483..96c266dfef 100644 --- a/crates/defguard_common/src/db/models/user.rs +++ b/crates/defguard_common/src/db/models/user.rs @@ -1,9 +1,57 @@ -use std::fmt; +use std::{fmt, time::SystemTime}; +use crate::{ + config::server_config, + db::{ + Id, NoId, + models::{MFAInfo, Session, WebAuthn}, + }, + random::{gen_alphanumeric, gen_totp_secret}, + types::user_info::OAuth2AuthorizedAppInfo, +}; +use argon2::{ + Argon2, + password_hash::{ + PasswordHash, PasswordHasher, PasswordVerifier, SaltString, errors::Error as HashError, + rand_core::OsRng, + }, +}; +use model_derive::Model; +use rand::{ + Rng, + distributions::{Alphanumeric, DistString, Standard}, + prelude::Distribution, +}; use serde::{Deserialize, Serialize}; -use sqlx::Type; +use sqlx::{ + Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, query, query_as, + query_scalar, +}; +use thiserror::Error; +use totp_lite::{Sha1, totp_custom}; +use tracing::{debug, error, info, warn}; use utoipa::ToSchema; +use super::{ + device::{Device, DeviceType, UserDevice}, + group::{Group, Permission}, +}; + +const RECOVERY_CODES_COUNT: usize = 8; +pub const TOTP_CODE_VALIDITY_PERIOD: u64 = 30; +pub const EMAIL_CODE_DIGITS: u32 = 6; +pub const TOTP_CODE_DIGITS: u32 = 6; + +#[derive(Debug, Error)] +pub enum UserError { + #[error("Invalid MFA state for user {username}")] + InvalidMfaState { username: String }, + #[error(transparent)] + DbError(#[from] SqlxError), + #[error("{0}")] + EmailMfaError(String), +} + #[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash, ToSchema, Type)] #[sqlx(type_name = "mfa_method", rename_all = "snake_case")] pub enum MFAMethod { @@ -28,3 +76,1557 @@ impl fmt::Display for MFAMethod { ) } } + +/// Only `id` and `name` from [`WebAuthn`]. +#[derive(Debug, Deserialize, Serialize, ToSchema)] +pub struct SecurityKey { + pub id: Id, + pub name: String, +} + +// User information ready to be sent as part of diagnostic data. +#[derive(Serialize)] +pub struct UserDiagnostic { + pub id: Id, + pub mfa_enabled: bool, + pub totp_enabled: bool, + pub email_mfa_enabled: bool, + pub mfa_method: MFAMethod, + pub is_active: bool, + pub enrolled: bool, +} + +#[derive(Clone, Model, PartialEq, Eq, Hash, Serialize, FromRow)] +pub struct User { + pub id: I, + pub username: String, + pub password_hash: Option, + pub last_name: String, + pub first_name: String, + pub email: String, + pub phone: Option, + pub mfa_enabled: bool, + pub is_active: bool, + /// Indicates whether the user has been created via the LDAP integration. + pub from_ldap: bool, + /// Indicates whether a user has a random password set in LDAP, if so, the user + /// will be prompted to change it on their profile page. + /// + /// The random password is set if we are creating a new user in LDAP from a Defguard user + /// and we don't have access to the plain text password, e.g. during Defguard -> LDAP user import. + pub ldap_pass_randomized: bool, + /// The user's LDAP RDN value. This is the first part of the DN. + /// For example, if the DN is `cn=John Doe,ou=users,dc=example,dc=com`, + /// the RDN is `cn=John Doe`. + /// This is used to identify the user in LDAP as we sometimes can't use the Defguard's username + /// since the RDN may contain spaces or other special characters and the username may not. + pub ldap_rdn: Option, + /// Rest of the user's DN + pub ldap_user_path: Option, + /// The user's sub claim returned by the OpenID provider. Also indicates whether the user has + /// used OpenID to log in. + // FIXME: must be unique + pub openid_sub: Option, + // secret has been verified and TOTP can be used + pub totp_enabled: bool, + pub email_mfa_enabled: bool, + pub totp_secret: Option>, + pub email_mfa_secret: Option>, + #[model(enum)] + pub mfa_method: MFAMethod, + #[model(ref)] + pub recovery_codes: Vec, + /// Indicates that an administrator has requested an enrollment token for this user. + /// Uninitialized clients should then guide the user through enrollment process. + /// Related issue: https://github.com/DefGuard/client/issues/647. + pub enrollment_pending: bool, +} + +// TODO: Refactor the user struct to use SecretStringWrapper instead of this +impl fmt::Debug for User { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self { + id, + username, + password_hash: _, + last_name, + first_name, + email, + phone, + mfa_enabled, + is_active, + from_ldap, + ldap_pass_randomized, + ldap_rdn, + ldap_user_path, + openid_sub, + totp_enabled, + email_mfa_enabled, + totp_secret: _, + email_mfa_secret: _, + mfa_method, + recovery_codes, + enrollment_pending, + } = self; + + f.debug_struct("User") + .field("id", id) + .field("username", username) + .field("last_name", last_name) + .field("first_name", first_name) + .field("email", email) + .field("phone", phone) + .field("mfa_enabled", mfa_enabled) + .field("is_active", is_active) + .field("from_ldap", from_ldap) + .field("ldap_pass_randomized", ldap_pass_randomized) + .field("ldap_rdn", ldap_rdn) + .field("ldap_user_path", ldap_user_path) // sensitive data + .field("openid_sub", openid_sub) + .field("totp_enabled", totp_enabled) + .field("email_mfa_enabled", email_mfa_enabled) + .field("mfa_method", mfa_method) + .field( + "recovery_codes", + &format_args!("{} items", recovery_codes.len()), + ) + .field("password_hash", &"***") + .field("totp_secret", &"***") + .field("email_mfa_secret", &"***") + .field("enrollment_pending", enrollment_pending) + .finish() + } +} + +fn hash_password(password: &str) -> Result { + let salt = SaltString::generate(&mut OsRng); + Ok(Argon2::default() + .hash_password(password.as_bytes(), &salt)? + .to_string()) +} + +impl User { + #[must_use] + pub fn new>( + username: S, + password: Option<&str>, + last_name: S, + first_name: S, + email: S, + phone: Option, + ) -> Self { + let password_hash = password.and_then(|password_hash| hash_password(password_hash).ok()); + let username: String = username.into(); + Self { + id: NoId, + username: username.clone(), + password_hash, + last_name: last_name.into(), + first_name: first_name.into(), + email: email.into(), + phone, + mfa_enabled: false, + totp_enabled: false, + email_mfa_enabled: false, + totp_secret: None, + email_mfa_secret: None, + mfa_method: MFAMethod::None, + recovery_codes: Vec::new(), + is_active: true, + openid_sub: None, + from_ldap: false, + ldap_pass_randomized: false, + ldap_rdn: Some(username.clone()), + ldap_user_path: None, + enrollment_pending: false, + } + } +} + +impl fmt::Display for User { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.username) + } +} + +impl User { + pub fn set_password(&mut self, password: &str) { + self.password_hash = hash_password(password).ok(); + } + + pub fn verify_password(&self, password: &str) -> Result<(), HashError> { + debug!("Checking if password matches for user {}", self.username); + if let Some(hash) = &self.password_hash { + let parsed_hash = PasswordHash::new(hash)?; + Argon2::default().verify_password(password.as_bytes(), &parsed_hash) + } else { + info!("User {} has no password set", self.username); + Err(HashError::Password) + } + } + + #[must_use] + pub fn has_password(&self) -> bool { + self.password_hash.is_some() + } + + #[must_use] + pub fn name(&self) -> String { + format!("{} {}", self.first_name, self.last_name) + } + + /// Determines whether the user is considered enrolled. + /// + /// A user is treated as enrolled if: + /// - The `enrollment_pending` flag is **not** set, i.e. enrollment was not requested by an + /// administrator (https://github.com/DefGuard/client/issues/647). + /// - They either have a password configured, have authenticated via an external OIDC provider + /// or were synced from LDAP. + #[must_use] + pub fn is_enrolled(&self) -> bool { + !self.enrollment_pending + && (self.password_hash.is_some() || self.openid_sub.is_some() || self.from_ldap) + } + + #[must_use] + pub fn ldap_rdn_value(&self) -> &str { + if let Some(ldap_rdn) = &self.ldap_rdn { + ldap_rdn + } else { + warn!( + "LDAP RDN is not set for user {}. Using username as a fallback.", + self.username + ); + &self.username + } + } +} + +impl User { + /// Generate new TOTP secret, save it, then return it as RFC 4648 base32-encoded string. + pub async fn new_totp_secret<'e, E>(&mut self, executor: E) -> Result + where + E: PgExecutor<'e>, + { + let secret = gen_totp_secret(); + query!( + "UPDATE \"user\" SET totp_secret = $1 WHERE id = $2", + secret, + self.id + ) + .execute(executor) + .await?; + + let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret); + self.totp_secret = Some(secret); + Ok(secret_base32) + } + + /// Generate new email secret, similar to TOTP secret above, but don't return generated value. + pub async fn new_email_secret<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + let email_secret = gen_totp_secret(); + query!( + "UPDATE \"user\" SET email_mfa_secret = $1 WHERE id = $2", + email_secret, + self.id + ) + .execute(executor) + .await?; + + self.email_mfa_secret = Some(email_secret); + + Ok(()) + } + + pub async fn set_mfa_method<'e, E>( + &mut self, + executor: E, + mfa_method: MFAMethod, + ) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + info!( + "Setting MFA method for user {} to {mfa_method:?}", + self.username + ); + query!( + "UPDATE \"user\" SET mfa_method = $2 WHERE id = $1", + self.id, + &mfa_method as &MFAMethod + ) + .execute(executor) + .await?; + self.mfa_method = mfa_method; + + Ok(()) + } + + /// Check if any of the multi-factor authentication methods is on. + /// - TOTP is enabled + /// - a security key for Webauthn + async fn check_mfa_enabled<'e, E>(&self, executor: E) -> Result + where + E: PgExecutor<'e>, + { + // short-cut + if self.totp_enabled || self.email_mfa_enabled { + return Ok(true); + } + + query_scalar!( + "SELECT totp_enabled OR email_mfa_enabled \ + OR count(webauthn.id) > 0 \"bool!\" FROM \"user\" \ + LEFT JOIN webauthn ON webauthn.user_id = \"user\".id \ + WHERE \"user\".id = $1 GROUP BY totp_enabled, email_mfa_enabled;", + self.id + ) + .fetch_one(executor) + .await + } + + /// Verify the state of MFA flags are correct. + /// Recovers from invalid mfa_method + /// Use this function after removing any of the authentication factors. + pub async fn verify_mfa_state(&mut self, pool: &PgPool) -> Result<(), UserError> { + if let Some(info) = MFAInfo::for_user(pool, self).await? { + let factors_present = info.mfa_available(); + if self.mfa_enabled != factors_present { + // store correct value for MFA flag in the DB + if self.mfa_enabled { + // last factor was removed so we have to disable MFA + self.disable_mfa(pool).await?; + } else { + // first factor was added so MFA needs to be enabled + query!( + "UPDATE \"user\" SET mfa_enabled = $2 WHERE id = $1", + self.id, + factors_present + ) + .execute(pool) + .await?; + } + + if !factors_present && self.mfa_method != MFAMethod::None { + debug!( + "MFA for user {} disabled, updating MFA method to None", + self.username + ); + self.set_mfa_method(pool, MFAMethod::None).await?; + } + + self.mfa_enabled = factors_present; + } + + // set correct value for default method + if factors_present { + match info.list_available_methods() { + None => { + error!("Incorrect MFA info state for user {}", self.username); + return Err(UserError::InvalidMfaState { + username: self.username.clone(), + }); + } + Some(methods) => { + info!( + "Checking if {:?} in in available methods {methods:?}, {}", + info.mfa_method, + methods.contains(&info.mfa_method) + ); + if !methods.contains(&info.mfa_method) { + // FIXME: do not panic + self.set_mfa_method(pool, methods.into_iter().next().unwrap()) + .await?; + } + } + } + } + } + Ok(()) + } + + /// Enable MFA. At least one of the authenticator factors must be configured. + pub async fn enable_mfa(&mut self, pool: &PgPool) -> Result<(), UserError> { + if !self.mfa_enabled { + self.verify_mfa_state(pool).await?; + } + Ok(()) + } + + /// Get recovery codes. If recovery codes exist, this function returns `None`. + /// That way recovery codes are returned only once - when MFA is turned on. + pub async fn get_recovery_codes<'e, E>( + &mut self, + executor: E, + ) -> Result>, SqlxError> + where + E: PgExecutor<'e>, + { + if !self.recovery_codes.is_empty() { + return Ok(None); + } + + for _ in 0..RECOVERY_CODES_COUNT { + let code = gen_alphanumeric(16); + self.recovery_codes.push(code); + } + query!( + "UPDATE \"user\" SET recovery_codes = $2 WHERE id = $1", + self.id, + &self.recovery_codes + ) + .execute(executor) + .await?; + + Ok(Some(self.recovery_codes.clone())) + } + + /// Disable MFA; discard recovery codes, TOTP secret, and security keys. + pub async fn disable_mfa(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + query!( + "UPDATE \"user\" SET mfa_enabled = FALSE, mfa_method = 'none', totp_enabled = FALSE, email_mfa_enabled = FALSE, \ + totp_secret = NULL, email_mfa_secret = NULL, recovery_codes = '{}' WHERE id = $1", + self.id + ) + .execute(pool) + .await?; + WebAuthn::delete_all_for_user(pool, self.id).await?; + + self.totp_secret = None; + self.email_mfa_secret = None; + self.totp_enabled = false; + self.email_mfa_enabled = false; + self.mfa_method = MFAMethod::None; + self.recovery_codes.clear(); + + Ok(()) + } + + /// Enable TOTP + pub async fn enable_totp<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + if !self.totp_enabled { + query!( + "UPDATE \"user\" SET totp_enabled = TRUE WHERE id = $1", + self.id + ) + .execute(executor) + .await?; + self.totp_enabled = true; + } + + Ok(()) + } + + /// Disable TOTP; discard the secret. + pub async fn disable_totp(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + if self.totp_enabled { + // FIXME: check if this flag is set correctly when TOTP is the only method + self.mfa_enabled = self.check_mfa_enabled(pool).await?; + self.totp_enabled = false; + self.totp_secret = None; + + query!( + "UPDATE \"user\" SET mfa_enabled = $2, totp_enabled = $3 AND totp_secret = $4 \ + WHERE id = $1", + self.id, + self.mfa_enabled, + self.totp_enabled, + self.totp_secret, + ) + .execute(pool) + .await?; + } + + Ok(()) + } + + /// Enable email MFA + pub async fn enable_email_mfa<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + if !self.email_mfa_enabled { + query!( + "UPDATE \"user\" SET email_mfa_enabled = TRUE WHERE id = $1", + self.id + ) + .execute(executor) + .await?; + + self.email_mfa_enabled = true; + } + + Ok(()) + } + + /// Disable email MFA; discard the secret. + pub async fn disable_email_mfa(&mut self, pool: &PgPool) -> Result<(), SqlxError> { + if self.email_mfa_enabled { + self.mfa_enabled = self.check_mfa_enabled(pool).await?; + self.email_mfa_enabled = false; + self.email_mfa_secret = None; + + query!( + "UPDATE \"user\" SET mfa_enabled = $2, email_mfa_enabled = $3 AND email_mfa_secret = $4 \ + WHERE id = $1", + self.id, + self.mfa_enabled, + self.email_mfa_enabled, + self.email_mfa_secret, + ) + .execute(pool) + .await?; + } + + Ok(()) + } + + /// Select all users without sensitive data. + // FIXME: Remove it when Model macro will support SecretString + pub async fn all_without_sensitive_data( + pool: &PgPool, + ) -> Result, SqlxError> { + let users = query!( + "SELECT id, mfa_enabled, totp_enabled, email_mfa_enabled, \ + mfa_method \"mfa_method: MFAMethod\", password_hash, is_active, openid_sub, \ + from_ldap, ldap_pass_randomized, ldap_rdn \ + FROM \"user\"" + ) + .fetch_all(pool) + .await?; + let res: Vec = users + .iter() + .map(|u| UserDiagnostic { + mfa_method: u.mfa_method.clone(), + totp_enabled: u.totp_enabled, + email_mfa_enabled: u.email_mfa_enabled, + mfa_enabled: u.mfa_enabled, + id: u.id, + is_active: u.is_active, + enrolled: u.password_hash.is_some() || u.openid_sub.is_some() || u.from_ldap, + }) + .collect(); + + Ok(res) + } + + /// Return all members of group + pub async fn find_by_group_name( + pool: &PgPool, + group_name: &str, + ) -> Result>, SqlxError> { + let users = query_as!( + Self, + "SELECT \"user\".id, username, password_hash, last_name, first_name, email, \ + phone, mfa_enabled, totp_enabled, totp_secret, \ + email_mfa_enabled, email_mfa_secret, \ + mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ + from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" \ + INNER JOIN \"group_user\" ON \"user\".id = \"group_user\".user_id \ + INNER JOIN \"group\" ON \"group_user\".group_id = \"group\".id \ + WHERE \"group\".name = $1", + group_name + ) + .fetch_all(pool) + .await?; + + Ok(users) + } + + /// Check if TOTP `code` is valid. + #[must_use] + pub fn verify_totp_code(&self, code: &str) -> bool { + if let Some(totp_secret) = &self.totp_secret { + if let Ok(timestamp) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { + let expected_code = totp_custom::( + TOTP_CODE_VALIDITY_PERIOD, + TOTP_CODE_DIGITS, + totp_secret, + timestamp.as_secs(), + ); + return code == expected_code; + } + } + + false + } + + /// Generate MFA code for email verification. + /// + /// NOTE: This code will be valid for two time frames. See comment for verify_email_mfa_code(). + pub fn generate_email_mfa_code(&self) -> Result { + if let Some(email_mfa_secret) = &self.email_mfa_secret { + let timeout = &server_config().mfa_code_timeout; + if let Ok(timestamp) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { + let code = totp_custom::( + timeout.as_secs(), + EMAIL_CODE_DIGITS, + email_mfa_secret, + timestamp.as_secs(), + ); + Ok(code) + } else { + Err(UserError::EmailMfaError( + "SystemTime before UNIX epoch".into(), + )) + } + } else { + Err(UserError::EmailMfaError(format!( + "Email MFA secret not configured for user {}", + self.username + ))) + } + } + + /// Check if email MFA `code` is valid. + /// + /// IMPORTANT: because current implementation uses TOTP for email verification, + /// allow the code for the previous time frame. This approach pretends the code is valid + /// for a certain *period of time* (as opposed to a TOTP code which is valid for a certain time *frame*). + /// + /// ```text + /// |<---- frame #0 ---->|<---- frame #1 ---->|<---- frame #2 ---->| + /// |................[*]email sent.................................| + /// |......................[*]email code verified..................| + /// ``` + #[must_use] + pub fn verify_email_mfa_code(&self, code: &str) -> bool { + if let Some(email_mfa_secret) = &self.email_mfa_secret { + let timeout = server_config().mfa_code_timeout.as_secs(); + if let Ok(timestamp) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { + let expected_code = totp_custom::( + timeout, + EMAIL_CODE_DIGITS, + email_mfa_secret, + timestamp.as_secs(), + ); + if code == expected_code { + return true; + } + debug!( + "Email MFA verification TOTP code for user {} doesn't fit current time \ + frame, checking the previous one. \ + Expected: {expected_code}, got: {code}", + self.username + ); + + let previous_code = totp_custom::( + timeout, + EMAIL_CODE_DIGITS, + email_mfa_secret, + timestamp.as_secs() - timeout, + ); + + if code == previous_code { + return true; + } + debug!( + "Email MFA verification TOTP code for user {} doesn't fit previous time frame, \ + expected: {previous_code}, got: {code}", + self.username + ); + return false; + } + debug!( + "Couldn't calculate current timestamp when verifying email MFA code for user {}", + self.username + ); + } else { + debug!("Email MFA secret not configured for user {}", self.username); + } + false + } + + /// Verify recovery code. If it is valid, consume it, so it can't be used again. + pub async fn verify_recovery_code( + &mut self, + pool: &PgPool, + code: &str, + ) -> Result { + if let Some(index) = self.recovery_codes.iter().position(|c| c == code) { + // Note: swap_remove() should be faster than remove(). + self.recovery_codes.swap_remove(index); + + query!( + "UPDATE \"user\" SET recovery_codes = $2 WHERE id = $1", + self.id, + &self.recovery_codes + ) + .execute(pool) + .await?; + + Ok(true) + } else { + Ok(false) + } + } + + pub async fn find_by_username<'e, E>( + executor: E, + username: &str, + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT id, username, password_hash, last_name, first_name, email, phone, mfa_enabled, \ + totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ + from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" WHERE username = $1", + username + ) + .fetch_optional(executor) + .await + } + + pub async fn find_by_email<'e, E>(executor: E, email: &str) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT id, username, password_hash, last_name, first_name, email, phone, mfa_enabled, \ + totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, from_ldap, \ + ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" WHERE email ILIKE $1", + email + ) + .fetch_optional(executor) + .await + } + + /// Attempts to find user by username and then by email, if none is initially found. + pub async fn find_by_username_or_email( + conn: &mut PgConnection, + username_or_email: &str, + ) -> Result, SqlxError> { + let maybe_user = Self::find_by_username(&mut *conn, username_or_email).await?; + if let Some(user) = maybe_user { + Ok(Some(user)) + } else { + debug!( + "Failed to find user by username {username_or_email}. Attempting to find by email" + ); + Ok(Self::find_by_email(&mut *conn, username_or_email).await?) + } + } + + pub async fn find_many_by_emails<'e, E>( + executor: E, + emails: &[&str], + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as( + "SELECT id, username, password_hash, last_name, first_name, email, phone, \ + mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method, recovery_codes, is_active, openid_sub, from_ldap, ldap_pass_randomized, \ + ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" WHERE email = ANY($1)", + ) + .bind(emails) + .fetch_all(executor) + .await + } + + pub async fn find_by_sub<'e, E>(executor: E, sub: &str) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT id, username, password_hash, last_name, first_name, email, phone, \ + mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ + from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" WHERE openid_sub = $1", + sub + ) + .fetch_optional(executor) + .await + } + + pub async fn member_of_names<'e, E>(&self, executor: E) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_scalar!( + "SELECT \"group\".name FROM \"group\" JOIN group_user ON \"group\".id = group_user.group_id \ + WHERE group_user.user_id = $1", + self.id + ) + .fetch_all(executor) + .await + } + + pub async fn member_of<'e, E>(&self, executor: E) -> Result>, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Group, + "SELECT id, name, is_admin FROM \"group\" JOIN group_user ON \"group\".id = group_user.group_id \ + WHERE group_user.user_id = $1", + self.id + ) + .fetch_all(executor) + .await + } + + /// Returns a vector of [`UserDevice`]s (hence the name). + /// [`UserDevice`] is a struct containing additional network info about a device. + /// If you only need [`Device`]s, use [`User::devices()`] instead. + pub async fn user_devices(&self, pool: &PgPool) -> Result, SqlxError> { + let devices = self.devices(pool).await?; + let mut user_devices = Vec::new(); + for device in devices { + if let Some(user_device) = UserDevice::from_device(pool, device).await? { + user_devices.push(user_device); + } + } + + Ok(user_devices) + } + + /// Returns a vector of [`Device`]s related to a user. If you want to get [`UserDevice`]s (which contain additional network info), + /// use [`User::user_devices()`] instead. + pub async fn devices<'e, E>(&self, executor: E) -> Result>, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Device, + "SELECT device.id, name, wireguard_pubkey, user_id, created, description, \ + device_type \"device_type: DeviceType\", configured \ + FROM device WHERE user_id = $1 and device_type = 'user'::device_type \ + ORDER BY id", + self.id + ) + .fetch_all(executor) + .await + } + + pub async fn oauth2authorizedapps<'e, E>( + &self, + executor: E, + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + OAuth2AuthorizedAppInfo, + "SELECT oauth2client.id \"oauth2client_id!\", oauth2client.name \"oauth2client_name\", \ + oauth2authorizedapp.user_id \"user_id\" \ + FROM oauth2authorizedapp \ + JOIN oauth2client ON oauth2client.id = oauth2authorizedapp.oauth2client_id \ + WHERE oauth2authorizedapp.user_id = $1", + self.id + ) + .fetch_all(executor) + .await + } + + pub async fn security_keys(&self, pool: &PgPool) -> Result, SqlxError> { + query_as!( + SecurityKey, + "SELECT id \"id!\", name FROM webauthn WHERE user_id = $1", + self.id + ) + .fetch_all(pool) + .await + } + + pub async fn add_to_group<'e, E>(&self, executor: E, group: &Group) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + query!( + "INSERT INTO group_user (group_id, user_id) VALUES ($1, $2) \ + ON CONFLICT DO NOTHING", + group.id, + self.id + ) + .execute(executor) + .await?; + Ok(()) + } + + pub async fn remove_from_group<'e, E>( + &self, + executor: E, + group: &Group, + ) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + query!( + "DELETE FROM group_user WHERE group_id = $1 AND user_id = $2", + group.id, + self.id + ) + .execute(executor) + .await?; + Ok(()) + } + + /// Remove authorized apps by their client id's from user + pub async fn remove_oauth2_authorized_apps<'e, E>( + &self, + executor: E, + app_client_ids: &[i64], + ) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + query!( + "DELETE FROM oauth2authorizedapp WHERE user_id = $1 AND oauth2client_id = ANY($2)", + self.id, + app_client_ids + ) + .execute(executor) + .await?; + + Ok(()) + } + + /// Create admin user if one doesn't exist yet + pub async fn init_admin_user( + pool: &PgPool, + default_admin_pass: &str, + ) -> Result<(), anyhow::Error> { + debug!("Checking if some admin user already exists and creating one if not..."); + let admins = User::find_admins(pool).await?; + if admins.is_empty() { + let admin_groups = Group::find_by_permission(pool, Permission::IsAdmin).await?; + if admin_groups.is_empty() { + return Err(anyhow::anyhow!( + "No admin group and users found, or they are all disabled. \ + You'll need to create and assign the admin group manually, \ + as there must be at least one active admin user." + )); + } + + // create admin user + let password_hash = hash_password(default_admin_pass)?; + let result = query_scalar!( + "INSERT INTO \"user\" (username, password_hash, last_name, first_name, email, ldap_rdn) \ + VALUES ('admin', $1, 'Administrator', 'DefGuard', 'admin@defguard', 'admin') \ + ON CONFLICT DO NOTHING \ + RETURNING id", + password_hash + ) + .fetch_optional(pool) + .await?; + + // if new user was created add them to admin group, first one you find + // the groups are sorted by ID desceding, so it will often be the 1st one = the default admin group + if let Some(new_user_id) = result { + let admin_group_id = admin_groups + .first() + .ok_or(anyhow::anyhow!( + "No admin group found, can't create admin user" + ))? + .id; + info!("New admin user has been created, adding to Admin group..."); + query("INSERT INTO group_user (group_id, user_id) VALUES ($1, $2)") + .bind(admin_group_id) + .bind(new_user_id) + .execute(pool) + .await?; + info!("Admin user has been created as there was no other admin user"); + } else { + return Err(anyhow::anyhow!( + "A conflict occurred while trying to add a missing admin. \ + There is already a user with username 'admin' but he is not an admin or he is disabled. \ + You will need to assign someone the admin group manually or enable this admin user, \ + as there must be at least one active admin." + )); + } + } else { + debug!("Admin users already exists, skipping creation of the default admin user"); + } + Ok(()) + } + + pub async fn logout_all_sessions<'e, E>(&self, executor: E) -> Result<(), SqlxError> + where + E: PgExecutor<'e>, + { + Session::delete_all_for_user(executor, self.id).await?; + Ok(()) + } + + pub async fn find_by_device_id<'e, E>( + executor: E, + device_id: Id, + ) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT u.id, u.username, u.password_hash, u.last_name, u.first_name, u.email, \ + u.phone, u.mfa_enabled, u.totp_enabled, u.email_mfa_enabled, \ + u.totp_secret, u.email_mfa_secret, u.mfa_method \"mfa_method: _\", u.recovery_codes, \ + u.is_active, u.openid_sub, from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, \ + enrollment_pending \ + FROM \"user\" u \ + JOIN \"device\" d ON u.id = d.user_id \ + WHERE d.id = $1", + device_id + ) + .fetch_optional(executor) + .await + } + + /// Find users which emails are NOT in `user_emails`. + pub async fn exclude<'e, E>(executor: E, user_emails: &[&str]) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + // This can't be a macro since sqlx can't handle an array of slices in a macro. + query_as( + "SELECT id, username, password_hash, last_name, first_name, email, phone, \ + mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method, recovery_codes, is_active, openid_sub, from_ldap, ldap_pass_randomized, \ + ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" WHERE email NOT IN (SELECT * FROM UNNEST($1::TEXT[]))", + ) + .bind(user_emails) + .fetch_all(executor) + .await + } + + pub async fn is_admin<'e, E>(&self, executor: E) -> Result + where + E: PgExecutor<'e>, + { + query_scalar!("SELECT EXISTS (SELECT 1 FROM group_user gu LEFT JOIN \"group\" g ON gu.group_id = g.id \ + WHERE is_admin = true AND user_id = $1) \"bool!\"", self.id) + .fetch_one(executor) + .await + } + + /// Find all users that are admins and are active. + pub async fn find_admins<'e, E>(executor: E) -> Result, SqlxError> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + " + SELECT u.id, u.username, u.password_hash, u.last_name, u.first_name, u.email, \ + u.phone, u.mfa_enabled, u.totp_enabled, u.email_mfa_enabled, \ + u.totp_secret, u.email_mfa_secret, u.mfa_method \"mfa_method: _\", u.recovery_codes, u.is_active, u.openid_sub, \ + from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ + FROM \"user\" u \ + WHERE EXISTS (SELECT 1 FROM group_user gu LEFT JOIN \"group\" g ON gu.group_id = g.id \ + WHERE is_admin = true AND user_id = u.id) AND u.is_active = true" + ) + .fetch_all(executor) + .await + } +} + +impl Distribution> for Standard { + fn sample(&self, rng: &mut R) -> User { + User { + id: rng.r#gen(), + username: Alphanumeric.sample_string(rng, 8), + password_hash: rng + .r#gen::() + .then_some(Alphanumeric.sample_string(rng, 8)), + last_name: Alphanumeric.sample_string(rng, 8), + first_name: Alphanumeric.sample_string(rng, 8), + email: format!("{}@defguard.net", Alphanumeric.sample_string(rng, 6)), + // FIXME: generate an actual phone number + phone: rng + .r#gen::() + .then_some(Alphanumeric.sample_string(rng, 9)), + mfa_enabled: rng.r#gen(), + is_active: true, + openid_sub: rng + .r#gen::() + .then_some(Alphanumeric.sample_string(rng, 8)), + totp_enabled: rng.r#gen(), + email_mfa_enabled: rng.r#gen(), + totp_secret: (0..20).map(|_| rng.r#gen()).collect(), + email_mfa_secret: (0..20).map(|_| rng.r#gen()).collect(), + mfa_method: match rng.r#gen_range(0..4) { + 0 => MFAMethod::None, + 1 => MFAMethod::Webauthn, + 2 => MFAMethod::OneTimePassword, + _ => MFAMethod::Email, + }, + recovery_codes: (0..3).map(|_| Alphanumeric.sample_string(rng, 6)).collect(), + from_ldap: false, + ldap_pass_randomized: false, + ldap_rdn: None, + ldap_user_path: None, + enrollment_pending: false, + } + } +} + +impl Distribution> for Standard { + fn sample(&self, rng: &mut R) -> User { + User { + id: NoId, + username: Alphanumeric.sample_string(rng, 8), + password_hash: rng + .r#gen::() + .then_some(Alphanumeric.sample_string(rng, 8)), + last_name: Alphanumeric.sample_string(rng, 8), + first_name: Alphanumeric.sample_string(rng, 8), + email: format!("{}@defguard.net", Alphanumeric.sample_string(rng, 6)), + // FIXME: generate an actual phone number + phone: rng + .r#gen::() + .then_some(Alphanumeric.sample_string(rng, 9)), + mfa_enabled: rng.r#gen(), + is_active: true, + openid_sub: rng + .r#gen::() + .then_some(Alphanumeric.sample_string(rng, 8)), + totp_enabled: rng.r#gen(), + email_mfa_enabled: rng.r#gen(), + totp_secret: (0..20).map(|_| rng.r#gen()).collect(), + email_mfa_secret: (0..20).map(|_| rng.r#gen()).collect(), + mfa_method: match rng.r#gen_range(0..4) { + 0 => MFAMethod::None, + 1 => MFAMethod::Webauthn, + 2 => MFAMethod::OneTimePassword, + _ => MFAMethod::Email, + }, + recovery_codes: (0..3).map(|_| Alphanumeric.sample_string(rng, 6)).collect(), + from_ldap: false, + ldap_pass_randomized: false, + ldap_rdn: None, + ldap_user_path: None, + enrollment_pending: false, + } + } +} + +#[cfg(test)] +mod test { + use crate::{ + config::{DefGuardConfig, SERVER_CONFIG}, + db::{models::settings::initialize_current_settings, setup_pool}, + }; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + + use super::*; + + #[sqlx::test] + async fn test_mfa_code(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let config = DefGuardConfig::new_test_config(); + let _ = SERVER_CONFIG.set(config.clone()); + initialize_current_settings(&pool).await.unwrap(); + + let mut user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + user.new_email_secret(&pool).await.unwrap(); + assert!(user.email_mfa_secret.is_some()); + let code = user.generate_email_mfa_code().unwrap(); + assert!( + user.verify_email_mfa_code(&code), + "code={code}, secret={:?}", + user.email_mfa_secret.unwrap() + ); + } + + #[sqlx::test] + async fn test_user(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let mut user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let fetched_user = User::find_by_username(&pool, "hpotter").await.unwrap(); + assert!(fetched_user.is_some()); + assert_eq!(fetched_user.unwrap().email, "h.potter@hogwart.edu.uk"); + + user.email = "harry.potter@hogwart.edu.uk".into(); + user.save(&pool).await.unwrap(); + + let fetched_user = User::find_by_username(&pool, "hpotter").await.unwrap(); + assert!(fetched_user.is_some()); + assert_eq!(fetched_user.unwrap().email, "harry.potter@hogwart.edu.uk"); + + assert!(user.verify_password("pass123").is_ok()); + + let fetched_user = User::find_by_username(&pool, "rweasley").await.unwrap(); + assert!(fetched_user.is_none()); + } + + #[sqlx::test] + async fn test_all_users(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let albus = User::new( + "adumbledore", + Some("magic!"), + "Dumbledore", + "Albus", + "a.dumbledore@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let users = User::all(&pool).await.unwrap(); + assert_eq!(users.len(), 2); + + albus.delete(&pool).await.unwrap(); + + let users = User::all(&pool).await.unwrap(); + assert_eq!(users.len(), 1); + } + + #[sqlx::test] + async fn test_recovery_codes(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let mut harry = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + harry.get_recovery_codes(&pool).await.unwrap(); + assert_eq!(harry.recovery_codes.len(), RECOVERY_CODES_COUNT); + + let fetched_user = User::find_by_username(&pool, "hpotter").await.unwrap(); + assert!(fetched_user.is_some()); + + let mut user = fetched_user.unwrap(); + assert_eq!(user.recovery_codes.len(), RECOVERY_CODES_COUNT); + assert!( + !user + .verify_recovery_code(&pool, "invalid code") + .await + .unwrap() + ); + let codes = user.recovery_codes.clone(); + for code in &codes { + assert!(user.verify_recovery_code(&pool, code).await.unwrap()); + } + assert_eq!(user.recovery_codes.len(), 0); + } + + #[sqlx::test] + async fn test_email_case_insensitivity(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let harry = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ); + assert!(harry.save(&pool).await.is_ok()); + + let henry = User::new( + "h.potter", + Some("pass123"), + "Potter", + "Henry", + "h.potter@hogwart.edu.uk", + None, + ); + assert!(henry.save(&pool).await.is_err()); + } + + #[sqlx::test] + async fn test_is_admin(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let config = DefGuardConfig::new_test_config(); + let _ = SERVER_CONFIG.set(config.clone()); + + let user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let is_admin = user.is_admin(&pool).await.unwrap(); + + assert!(!is_admin); + + query!( + "INSERT INTO group_user (group_id, user_id) VALUES (1, $1)", + user.id + ) + .execute(&pool) + .await + .unwrap(); + + let is_admin = user.is_admin(&pool).await.unwrap(); + + assert!(is_admin); + } + + #[sqlx::test] + async fn test_find_admins(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let config = DefGuardConfig::new_test_config(); + let _ = SERVER_CONFIG.set(config.clone()); + + let user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "hpotter2", + Some("pass123"), + "Potter", + "Harry", + "h.potter2@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + User::new( + "hpotter3", + Some("pass123"), + "Potter", + "Harry", + "h.potter3@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + query!( + "INSERT INTO group_user (group_id, user_id) VALUES (1, $1), (1, $2)", + user.id, + user2.id, + ) + .execute(&pool) + .await + .unwrap(); + + let admins = User::find_admins(&pool).await.unwrap(); + assert_eq!(admins.len(), 2); + assert!(admins.iter().any(|u| u.id == user.id)); + assert!(admins.iter().any(|u| u.id == user2.id)); + } + + #[sqlx::test] + async fn test_get_missing(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user1 = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + let user2 = User::new( + "hpotter2", + Some("pass1234"), + "Potter2", + "Harry2", + "h.potter2@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + let albus = User::new( + "adumbledore", + Some("magic!"), + "Dumbledore", + "Albus", + "a.dumbledore@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user_emails = vec![user1.email.as_str(), albus.email.as_str()]; + let users = User::exclude(&pool, &user_emails).await.unwrap(); + assert_eq!(users.len(), 1); + assert_eq!(users[0].id, user2.id); + } + + #[sqlx::test] + async fn test_find_many_by_emails(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user1 = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + User::new( + "hpotter2", + Some("pass1234"), + "Potter2", + "Harry2", + "h.potter2@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + let albus = User::new( + "adumbledore", + Some("magic!"), + "Dumbledore", + "Albus", + "a.dumbledore@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user_emails = vec![user1.email.as_str(), albus.email.as_str()]; + let users = User::find_many_by_emails(&pool, &user_emails) + .await + .unwrap(); + assert_eq!(users.len(), 2); + assert_eq!(users[0].id, user1.id); + assert_eq!(users[1].id, albus.id); + } + + #[sqlx::test] + async fn test_user_is_enrolled(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let user = User::new( + "test", + Some("31071980"), + "harry", + "potter", + "harry@hogwart.edu.uk", + None, + ); + let mut user = user.save(&pool).await.unwrap(); + + user.enrollment_pending = false; + user.password_hash = Some(hash_password("31071980").unwrap()); + user.openid_sub = Some("sub".to_string()); + user.from_ldap = true; + user.save(&pool).await.unwrap(); + assert!(user.is_enrolled()); + + user.enrollment_pending = false; + user.password_hash = None; + user.openid_sub = Some("sub".to_string()); + user.from_ldap = true; + user.save(&pool).await.unwrap(); + assert!(user.is_enrolled()); + + user.enrollment_pending = false; + user.password_hash = None; + user.openid_sub = None; + user.from_ldap = true; + user.save(&pool).await.unwrap(); + assert!(user.is_enrolled()); + + user.enrollment_pending = false; + user.password_hash = None; + user.openid_sub = None; + user.from_ldap = false; + user.save(&pool).await.unwrap(); + assert!(!user.is_enrolled()); + + user.enrollment_pending = true; + user.password_hash = None; + user.openid_sub = None; + user.from_ldap = false; + user.save(&pool).await.unwrap(); + assert!(!user.is_enrolled()); + + user.enrollment_pending = true; + user.password_hash = Some(hash_password("31071980").unwrap()); + user.openid_sub = Some("sub".to_string()); + user.from_ldap = true; + user.save(&pool).await.unwrap(); + assert!(!user.is_enrolled()); + } +} diff --git a/crates/defguard_core/src/db/models/webauthn.rs b/crates/defguard_common/src/db/models/webauthn.rs similarity index 94% rename from crates/defguard_core/src/db/models/webauthn.rs rename to crates/defguard_common/src/db/models/webauthn.rs index bd5a912f70..2861a13b10 100644 --- a/crates/defguard_core/src/db/models/webauthn.rs +++ b/crates/defguard_common/src/db/models/webauthn.rs @@ -1,4 +1,4 @@ -use defguard_common::db::{Id, NoId, models::ModelError}; +use crate::db::{Id, NoId, models::ModelError}; use model_derive::Model; use sqlx::{Error as SqlxError, PgExecutor, PgPool, query, query_as, query_scalar}; use webauthn_rs::prelude::Passkey; @@ -26,7 +26,7 @@ impl WebAuthn { impl WebAuthn { /// Serialize [`Passkey`] from binary data. - pub(crate) fn passkey(&self) -> Result { + pub fn passkey(&self) -> Result { let passkey = serde_cbor::from_slice(&self.passkey).map_err(|_| ModelError::CannotCreate)?; diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs similarity index 62% rename from crates/defguard_core/src/db/models/wireguard.rs rename to crates/defguard_common/src/db/models/wireguard.rs index 33c26e4989..04c4de1db3 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -5,44 +5,36 @@ use std::{ net::{IpAddr, Ipv4Addr}, }; -use base64::prelude::{BASE64_STANDARD, Engine}; -use chrono::{NaiveDateTime, TimeDelta, Utc}; -use defguard_common::{ +use crate::{ auth::claims::{Claims, ClaimsType}, - csv::AsCsv, - db::{Id, NoId, models::ModelError}, -}; -use defguard_proto::{ - enterprise::firewall::FirewallConfig, - gateway::Peer, - proxy::{ - LocationMfaMode as ProtoLocationMfaMode, ServiceLocationMode as ProtoServiceLocationMode, + db::{ + Id, NoId, + models::{ + ModelError, + group::{Group, Permission}, + wireguard_peer_stats::WireguardPeerStats, + }, }, + types::user_info::UserInfo, }; +use base64::prelude::{BASE64_STANDARD, Engine}; +use chrono::{NaiveDateTime, TimeDelta, Utc}; use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; use model_derive::Model; use rand::rngs::OsRng; +use serde::{Deserialize, Serialize}; use sqlx::{ Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, - postgres::types::PgInterval, query_as, query_scalar, + postgres::types::PgInterval, query, query_as, query_scalar, }; use thiserror::Error; -use tokio::sync::broadcast::Sender; +use tracing::{debug, info}; use utoipa::ToSchema; use x25519_dalek::{PublicKey, StaticSecret}; use super::{ - UserInfo, - device::{ - Device, DeviceError, DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice, - }, + device::{Device, DeviceError, DeviceType, WireguardNetworkDevice}, user::User, - wireguard_peer_stats::WireguardPeerStats, -}; -use crate::{ - enterprise::{firewall::FirewallError, is_enterprise_enabled}, - grpc::gateway::{send_multiple_wireguard_events, state::GatewayState}, - wg_config::ImportedDevice, }; pub const DEFAULT_KEEPALIVE_INTERVAL: i32 = 25; @@ -76,18 +68,6 @@ impl DateTimeAggregation { } } -#[derive(Clone, Debug)] -pub enum GatewayEvent { - NetworkCreated(Id, WireguardNetwork), - NetworkModified(Id, WireguardNetwork, Vec, Option), - NetworkDeleted(Id, String), - DeviceCreated(DeviceInfo), - DeviceModified(DeviceInfo), - DeviceDeleted(DeviceInfo), - FirewallConfigChanged(Id, FirewallConfig), - FirewallDisabled(Id), -} - #[derive(Clone, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize, ToSchema, Type)] #[sqlx(type_name = "location_mfa_mode", rename_all = "lowercase")] #[serde(rename_all = "lowercase")] @@ -108,27 +88,6 @@ impl Display for LocationMfaMode { } } -impl From for LocationMfaMode { - fn from(value: ProtoLocationMfaMode) -> Self { - match value { - ProtoLocationMfaMode::Unspecified | ProtoLocationMfaMode::Disabled => { - LocationMfaMode::Disabled - } - ProtoLocationMfaMode::Internal => LocationMfaMode::Internal, - ProtoLocationMfaMode::External => LocationMfaMode::External, - } - } -} -impl From for ProtoLocationMfaMode { - fn from(value: LocationMfaMode) -> Self { - match value { - LocationMfaMode::Disabled => ProtoLocationMfaMode::Disabled, - LocationMfaMode::Internal => ProtoLocationMfaMode::Internal, - LocationMfaMode::External => ProtoLocationMfaMode::External, - } - } -} - #[derive(Clone, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize, ToSchema, Type)] #[sqlx(type_name = "service_location_mode", rename_all = "lowercase")] #[serde(rename_all = "lowercase")] @@ -139,28 +98,6 @@ pub enum ServiceLocationMode { AlwaysOn, } -impl From for ServiceLocationMode { - fn from(value: ProtoServiceLocationMode) -> Self { - match value { - ProtoServiceLocationMode::Unspecified | ProtoServiceLocationMode::Disabled => { - ServiceLocationMode::Disabled - } - ProtoServiceLocationMode::Prelogon => ServiceLocationMode::PreLogon, - ProtoServiceLocationMode::Alwayson => ServiceLocationMode::AlwaysOn, - } - } -} - -impl From for ProtoServiceLocationMode { - fn from(value: ServiceLocationMode) -> Self { - match value { - ServiceLocationMode::Disabled => ProtoServiceLocationMode::Disabled, - ServiceLocationMode::PreLogon => ProtoServiceLocationMode::Prelogon, - ServiceLocationMode::AlwaysOn => ProtoServiceLocationMode::Alwayson, - } - } -} - /// Stores configuration required to setup a WireGuard network #[derive(Clone, Deserialize, Eq, Hash, Model, PartialEq, Serialize, ToSchema)] #[table(wireguard_network)] @@ -272,8 +209,6 @@ pub enum WireguardNetworkError { DeviceNotAllowed(String), #[error("Device error")] DeviceError(#[from] DeviceError), - #[error("Firewall config error: {0}")] - FirewallError(#[from] FirewallError), #[error(transparent)] TokenError(#[from] jsonwebtoken::errors::Error), } @@ -297,6 +232,7 @@ pub enum NetworkAddressError { } impl WireguardNetwork { + #[allow(clippy::too_many_arguments)] #[must_use] pub fn new( name: String, @@ -336,9 +272,8 @@ impl WireguardNetwork { } /// Try to set `address` from `&str`. - #[cfg(test)] - pub(crate) fn try_set_address(&mut self, address: &str) -> Result<(), IpNetworkError> { - use crate::handlers::wireguard::parse_address_list; + pub fn try_set_address(&mut self, address: &str) -> Result<(), IpNetworkError> { + use crate::utils::parse_address_list; let address = parse_address_list(address); if address.is_empty() { @@ -351,7 +286,7 @@ impl WireguardNetwork { } impl WireguardNetwork { - pub(crate) async fn find_by_name<'e, E>( + pub async fn find_by_name<'e, E>( executor: E, name: &str, ) -> Result>, WireguardNetworkError> @@ -377,36 +312,8 @@ impl WireguardNetwork { Ok(Some(networks)) } - // run sync_allowed_devices on all wireguard networks - pub(crate) async fn sync_all_networks( - conn: &mut PgConnection, - wireguard_tx: &Sender, - ) -> Result<(), WireguardNetworkError> { - info!("Syncing allowed devices for all WireGuard locations"); - let networks = Self::all(&mut *conn).await?; - for network in networks { - // sync allowed devices for location - let mut gateway_events = network.sync_allowed_devices(&mut *conn, None).await?; - - // send firewall config update if ACLs are enabled for a given location - if let Some(firewall_config) = network.try_get_firewall_config(&mut *conn).await? { - gateway_events.push(GatewayEvent::FirewallConfigChanged( - network.id, - firewall_config, - )); - } - // check if any gateway events need to be sent - if !gateway_events.is_empty() { - send_multiple_wireguard_events(gateway_events, wireguard_tx); - } - } - Ok(()) - } - - pub(crate) fn validate_network_size( - &self, - device_count: usize, - ) -> Result<(), WireguardNetworkError> { + #[allow(clippy::result_large_err)] + pub fn validate_network_size(&self, device_count: usize) -> Result<(), WireguardNetworkError> { debug!("Checking if {device_count} devices can fit in networks used by location {self}"); // if given location uses multiple subnets validate devices can fit them all for subnet in &self.address { @@ -432,7 +339,7 @@ impl WireguardNetwork { /// Utility method to create WireGuard keypair #[must_use] - pub(crate) fn genkey() -> WireguardKey { + pub fn genkey() -> WireguardKey { let private = StaticSecret::random_from_rng(OsRng); let public = PublicKey::from(&private); WireguardKey { @@ -444,7 +351,7 @@ impl WireguardNetwork { /// Get a list of all devices belonging to users in allowed groups. /// Admin users should always be allowed to access a network. /// Note: Doesn't check if the devices are really in the network. - pub(crate) async fn get_allowed_devices( + pub async fn get_allowed_devices( &self, transaction: &mut PgConnection, ) -> Result>, ModelError> { @@ -491,7 +398,7 @@ impl WireguardNetwork { /// Get a list of devices belonging to a user which are also in the network's allowed groups. /// Admin users should always be allowed to access a network. /// Note: Doesn't check if the devices are really in the network. - async fn get_allowed_devices_for_user( + pub async fn get_allowed_devices_for_user( &self, transaction: &mut PgConnection, user_id: Id, @@ -541,7 +448,7 @@ impl WireguardNetwork { /// Generate network IPs for all existing devices /// If `allowed_groups` is set, devices should be filtered accordingly - pub(crate) async fn add_all_allowed_devices( + pub async fn add_all_allowed_devices( &self, transaction: &mut PgConnection, ) -> Result<(), ModelError> { @@ -593,342 +500,6 @@ impl WireguardNetwork { self.address.iter().find(|net| net.contains(addr)).copied() } - /// Works out which devices need to be added, removed, or readdressed based on the list - /// of currently configured devices and the list of devices which should be allowed. - async fn process_device_access_changes( - &self, - transaction: &mut PgConnection, - mut allowed_devices: HashMap>, - currently_configured_devices: Vec, - reserved_ips: Option<&[IpAddr]>, - ) -> Result, WireguardNetworkError> { - // Loop through current device configurations; remove no longer allowed, readdress - // when necessary; remove processed entry from all devices list initial list should - // now contain only devices to be added. - let mut events: Vec = Vec::new(); - for device_network_config in currently_configured_devices { - // Device is allowed and an IP was already assigned - if let Some(device) = allowed_devices.remove(&device_network_config.device_id) { - // Network address has changed and IP addresses need to be updated - if !self.contains_all(&device_network_config.wireguard_ips) - || self.address.len() != device_network_config.wireguard_ips.len() - { - let wireguard_network_device = device - .assign_next_network_ip( - &mut *transaction, - self, - reserved_ips, - Some(&device_network_config.wireguard_ips), - ) - .await?; - events.push(GatewayEvent::DeviceModified(DeviceInfo { - device, - network_info: vec![DeviceNetworkInfo { - network_id: self.id, - device_wireguard_ips: wireguard_network_device.wireguard_ips, - preshared_key: wireguard_network_device.preshared_key, - is_authorized: wireguard_network_device.is_authorized, - }], - })); - } - // Device is no longer allowed - } else { - debug!( - "Device {} no longer allowed, removing network config for {self}", - device_network_config.device_id - ); - device_network_config.delete(&mut *transaction).await?; - if let Some(device) = - Device::find_by_id(&mut *transaction, device_network_config.device_id).await? - { - events.push(GatewayEvent::DeviceDeleted(DeviceInfo { - device, - network_info: vec![DeviceNetworkInfo { - network_id: self.id, - device_wireguard_ips: device_network_config.wireguard_ips, - preshared_key: device_network_config.preshared_key, - is_authorized: device_network_config.is_authorized, - }], - })); - } else { - let msg = format!("Device {} does not exist", device_network_config.device_id); - error!(msg); - return Err(WireguardNetworkError::Unexpected(msg)); - } - } - } - - // Add configs for new allowed devices - for device in allowed_devices.into_values() { - let wireguard_network_device = device - .assign_next_network_ip(&mut *transaction, self, reserved_ips, None) - .await?; - events.push(GatewayEvent::DeviceCreated(DeviceInfo { - device, - network_info: vec![DeviceNetworkInfo { - network_id: self.id, - device_wireguard_ips: wireguard_network_device.wireguard_ips, - preshared_key: wireguard_network_device.preshared_key, - is_authorized: wireguard_network_device.is_authorized, - }], - })); - } - - Ok(events) - } - - /// Refresh network IPs for all relevant devices of a given user - /// If the list of allowed devices has changed add/remove devices accordingly - /// If the network address has changed readdress existing devices - pub(crate) async fn sync_allowed_devices_for_user( - &self, - transaction: &mut PgConnection, - user: &User, - reserved_ips: Option<&[IpAddr]>, - ) -> Result, WireguardNetworkError> { - info!("Synchronizing IPs in network {self} for all allowed devices "); - // list all allowed devices - let allowed_devices = self - .get_allowed_devices_for_user(&mut *transaction, user.id) - .await?; - - // convert to a map for easier processing - let allowed_devices: HashMap> = allowed_devices - .into_iter() - .map(|dev| (dev.id, dev)) - .collect(); - - // check if all devices can fit within network - // include address, network, and broadcast in the calculation - let count = allowed_devices.len() + 3; - self.validate_network_size(count)?; - - // list all assigned IPs - let assigned_ips = - WireguardNetworkDevice::all_for_network_and_user(&mut *transaction, self.id, user.id) - .await?; - - let events = self - .process_device_access_changes( - &mut *transaction, - allowed_devices, - assigned_ips, - reserved_ips, - ) - .await?; - - Ok(events) - } - - /// Refresh network IPs for all relevant devices - /// - /// If the list of allowed devices has changed add/remove devices accordingly - /// - /// If the network address has changed readdress existing devices - pub(crate) async fn sync_allowed_devices( - &self, - conn: &mut PgConnection, - reserved_ips: Option<&[IpAddr]>, - ) -> Result, WireguardNetworkError> { - info!("Synchronizing IPs in network {self} for all allowed devices "); - // list all allowed devices - let mut allowed_devices = self.get_allowed_devices(&mut *conn).await?; - - // network devices are always allowed, make sure to take only network devices already assigned to that network - let network_devices = - Device::find_by_type_and_network(&mut *conn, DeviceType::Network, self.id).await?; - allowed_devices.extend(network_devices); - - // convert to a map for easier processing - let allowed_devices: HashMap> = allowed_devices - .into_iter() - .map(|dev| (dev.id, dev)) - .collect(); - - // check if all devices can fit within network - // include address, network, and broadcast in the calculation - let count = allowed_devices.len() + 3; - self.validate_network_size(count)?; - - // list all assigned IPs - let assigned_ips = WireguardNetworkDevice::all_for_network(&mut *conn, self.id).await?; - - let events = self - .process_device_access_changes(&mut *conn, allowed_devices, assigned_ips, reserved_ips) - .await?; - - Ok(events) - } - - /// Check if devices found in an imported config file exist already, - /// if they do assign a specified IP. - /// Return a list of imported devices which need to be manually mapped to a user - /// and a list of WireGuard events to be sent out. - pub(crate) async fn handle_imported_devices( - &self, - transaction: &mut PgConnection, - imported_devices: Vec, - ) -> Result<(Vec, Vec), WireguardNetworkError> { - let allowed_devices = self.get_allowed_devices(&mut *transaction).await?; - // convert to a map for easier processing - let allowed_devices: HashMap> = allowed_devices - .into_iter() - .map(|dev| (dev.id, dev)) - .collect(); - - let mut devices_to_map = Vec::new(); - let mut assigned_device_ids = Vec::new(); - let mut events = Vec::new(); - for imported_device in imported_devices { - // check if device with a given pubkey exists already - match Device::find_by_pubkey(&mut *transaction, &imported_device.wireguard_pubkey) - .await? - { - Some(existing_device) => { - // check if device is allowed in network - match allowed_devices.get(&existing_device.id) { - Some(_) => { - info!( - "Device with pubkey {} exists already, assigning IPs {} for new network: {self}", - existing_device.wireguard_pubkey, - imported_device.wireguard_ips.as_csv() - ); - let wireguard_network_device = WireguardNetworkDevice::new( - self.id, - existing_device.id, - imported_device.wireguard_ips, - ); - wireguard_network_device.insert(&mut *transaction).await?; - // store ID of device with already generated config - assigned_device_ids.push(existing_device.id); - // send device to connected gateways - events.push(GatewayEvent::DeviceModified(DeviceInfo { - device: existing_device, - network_info: vec![DeviceNetworkInfo { - network_id: self.id, - device_wireguard_ips: wireguard_network_device.wireguard_ips, - preshared_key: wireguard_network_device.preshared_key, - is_authorized: wireguard_network_device.is_authorized, - }], - })); - } - None => { - warn!( - "Device with pubkey {} exists already, but is not allowed in network {self}. Skipping...", - existing_device.wireguard_pubkey - ); - } - } - } - None => devices_to_map.push(imported_device), - } - } - - Ok((devices_to_map, events)) - } - - /// Handle device -> user mapping in second step of network import wizard - pub(crate) async fn handle_mapped_devices( - &self, - transaction: &mut PgConnection, - mapped_devices: Vec, - ) -> Result, WireguardNetworkError> { - info!("Mapping user devices for network {}", self); - // get allowed groups for network - let allowed_groups = self.get_allowed_groups(&mut *transaction).await?; - - let mut events = Vec::new(); - // use a helper hashmap to avoid repeated queries - let mut user_groups = HashMap::new(); - for mapped_device in &mapped_devices { - debug!("Mapping device {}", mapped_device.name); - // validate device pubkey - Device::validate_pubkey(&mapped_device.wireguard_pubkey).map_err(|_| { - WireguardNetworkError::InvalidDevicePubkey(mapped_device.wireguard_pubkey.clone()) - })?; - // save a new device - let device = Device::new( - mapped_device.name.clone(), - mapped_device.wireguard_pubkey.clone(), - mapped_device.user_id, - DeviceType::User, - None, - true, - ) - .save(&mut *transaction) - .await?; - debug!("Saved new device {device}"); - - // get a list of groups user is assigned to - let groups = match user_groups.get(&device.user_id) { - // user info has already been fetched before - Some(groups) => groups, - // fetch user info - None => match User::find_by_id(&mut *transaction, device.user_id).await? { - Some(user) => { - let groups = user.member_of_names(&mut *transaction).await?; - user_groups.insert(device.user_id, groups); - // FIXME: ugly workaround to get around `groups` being dropped - user_groups.get(&device.user_id).unwrap() - } - None => return Err(WireguardNetworkError::from(ModelError::NotFound)), - }, - }; - - let mut network_info = Vec::new(); - match &allowed_groups { - None => { - let wireguard_network_device = WireguardNetworkDevice::new( - self.id, - device.id, - mapped_device.wireguard_ips.clone(), - ); - wireguard_network_device.insert(&mut *transaction).await?; - network_info.push(DeviceNetworkInfo { - network_id: self.id, - device_wireguard_ips: wireguard_network_device.wireguard_ips, - preshared_key: wireguard_network_device.preshared_key, - is_authorized: wireguard_network_device.is_authorized, - }); - } - Some(allowed) => { - // check if user belongs to an allowed group - if allowed.iter().any(|group| groups.contains(group)) { - // assign specified IP in imported network - let wireguard_network_device = WireguardNetworkDevice::new( - self.id, - device.id, - mapped_device.wireguard_ips.clone(), - ); - wireguard_network_device.insert(&mut *transaction).await?; - network_info.push(DeviceNetworkInfo { - network_id: self.id, - device_wireguard_ips: wireguard_network_device.wireguard_ips, - preshared_key: wireguard_network_device.preshared_key, - is_authorized: wireguard_network_device.is_authorized, - }); - } - } - } - - // assign IPs in other networks - let (mut all_network_info, _configs) = - device.add_to_all_networks(&mut *transaction).await?; - - network_info.append(&mut all_network_info); - - // send device to connected gateways - if !network_info.is_empty() { - events.push(GatewayEvent::DeviceCreated(DeviceInfo { - device, - network_info, - })); - } - } - - Ok(events) - } - /// Finds when the device connected based on handshake timestamps. async fn connected_at( &self, @@ -1030,7 +601,7 @@ impl WireguardNetwork { Ok(result) } - pub(crate) async fn distinct_device_stats( + pub async fn distinct_device_stats( &self, conn: &PgPool, from: &NaiveDateTime, @@ -1057,7 +628,7 @@ impl WireguardNetwork { } /// Retrieves network stats grouped by currently active users since `from` timestamp. - pub(crate) async fn user_stats( + pub async fn user_stats( &self, conn: &PgPool, from: &NaiveDateTime, @@ -1166,7 +737,7 @@ impl WireguardNetwork { } /// Retrieves network stats - pub(crate) async fn network_stats( + pub async fn network_stats( &self, conn: &PgPool, from: &NaiveDateTime, @@ -1238,7 +809,7 @@ impl WireguardNetwork { /// /// - `Ok(())`: All addresses passed every check. /// - `Err(NetworkIpAssignmentError)`: The first failing check. - pub(crate) async fn can_assign_ips( + pub async fn can_assign_ips( &self, transaction: &mut PgConnection, ip_addrs: &[IpAddr], @@ -1296,7 +867,7 @@ impl WireguardNetwork { } // fetch all locations using external MFA - pub(crate) async fn all_using_external_mfa<'e, E>( + pub async fn all_using_external_mfa<'e, E>( executor: E, ) -> Result, WireguardNetworkError> where @@ -1317,6 +888,7 @@ impl WireguardNetwork { } /// Generates auth token for a VPN gateway + #[allow(clippy::result_large_err)] pub fn generate_gateway_token(&self) -> Result { let location_id = self.id; @@ -1331,11 +903,138 @@ impl WireguardNetwork { Ok(token) } - /// If this location is marked as a service location, checks if all requirements are met for it to function: - /// - Enterprise is enabled - #[must_use] - pub fn should_prevent_service_location_usage(&self) -> bool { - self.service_location_mode != ServiceLocationMode::Disabled && !is_enterprise_enabled() + /// Fetch a list of all allowed groups for a given network from DB + pub async fn fetch_allowed_groups<'e, E>(&self, executor: E) -> Result, ModelError> + where + E: PgExecutor<'e>, + { + debug!("Fetching all allowed groups for network {self}"); + let groups = query_scalar!( + "SELECT name FROM wireguard_network_allowed_group wag \ + JOIN \"group\" g ON wag.group_id = g.id WHERE wag.network_id = $1", + self.id + ) + .fetch_all(executor) + .await?; + + Ok(groups) + } + + /// Return a list of allowed groups for a given network. + /// Admin group should always be included. + /// If no `allowed_groups` are specified for a network then all devices are allowed. + /// In this case `None` is returned to signify that there's no filtering. + /// This helper method is meant for use in all business logic gating + /// access to networks based on allowed groups. + pub async fn get_allowed_groups( + &self, + conn: &mut PgConnection, + ) -> Result>, ModelError> { + debug!("Returning a list of allowed groups for network {self}"); + let admin_groups = Group::find_by_permission(&mut *conn, Permission::IsAdmin).await?; + + // get allowed groups from DB + let mut groups = self.fetch_allowed_groups(&mut *conn).await?; + + // if no allowed groups are set then all groups are allowed + if groups.is_empty() { + return Ok(None); + } + + for group in admin_groups { + if !groups.iter().any(|name| name == &group.name) { + groups.push(group.name); + } + } + + Ok(Some(groups)) + } + + /// Set allowed groups, removing or adding groups as necessary. + pub async fn set_allowed_groups( + &self, + transaction: &mut PgConnection, + allowed_groups: Vec, + ) -> Result<(), ModelError> { + info!("Setting allowed groups for network {self} to: {allowed_groups:?}"); + if allowed_groups.is_empty() { + return self.clear_allowed_groups(transaction).await; + } + + // get list of current allowed groups + let mut current_groups = self.fetch_allowed_groups(&mut *transaction).await?; + + // add to group if not already a member + for group in &allowed_groups { + if !current_groups.contains(group) { + self.add_to_group(transaction, group).await?; + } + } + + // remove groups which are no longer present + current_groups.retain(|group| !allowed_groups.contains(group)); + if !current_groups.is_empty() { + self.remove_from_groups(transaction, current_groups).await?; + } + + Ok(()) + } + + pub async fn add_to_group( + &self, + transaction: &mut PgConnection, + group: &str, + ) -> Result<(), ModelError> { + info!("Adding allowed group {group} for network {self}"); + query!( + "INSERT INTO wireguard_network_allowed_group (network_id, group_id) \ + SELECT $1, g.id FROM \"group\" g WHERE g.name = $2", + self.id, + group + ) + .execute(transaction) + .await?; + Ok(()) + } + + pub async fn remove_from_groups( + &self, + transaction: &mut PgConnection, + groups: Vec, + ) -> Result<(), ModelError> { + info!("Removing allowed groups {groups:?} for network {self}"); + let result = query!( + "DELETE FROM wireguard_network_allowed_group \ + WHERE network_id = $1 AND group_id IN ( \ + SELECT id FROM \"group\" \ + WHERE name IN (SELECT * FROM UNNEST($2::text[])) \ + )", + self.id, + &groups + ) + .execute(transaction) + .await?; + info!( + "Removed {} allowed groups for network {self}", + result.rows_affected(), + ); + Ok(()) + } + + /// Remove all allowed groups for a given network + async fn clear_allowed_groups(&self, transaction: &mut PgConnection) -> Result<(), ModelError> { + info!("Removing all allowed groups for network {self}"); + let result = query!( + "DELETE FROM wireguard_network_allowed_group WHERE network_id=$1", + self.id + ) + .execute(transaction) + .await?; + info!( + "Removed {} allowed groups for network {self}", + result.rows_affected(), + ); + Ok(()) } } @@ -1363,15 +1062,6 @@ impl Default for WireguardNetwork { } } -#[derive(Serialize, ToSchema)] -pub struct WireguardNetworkInfo { - #[serde(flatten)] - pub network: WireguardNetwork, - pub connected: bool, - pub gateways: Vec, - pub allowed_groups: Vec, -} - #[derive(Clone, Serialize, Deserialize, PartialEq)] pub struct WireguardStatsRow { pub collected_at: Option, @@ -1428,7 +1118,7 @@ pub struct WireguardNetworkStats { pub transfer_series: Vec, } -pub(crate) async fn networks_stats( +pub async fn networks_stats( conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, @@ -1495,13 +1185,12 @@ pub(crate) async fn networks_stats( mod test { use std::str::FromStr; + use crate::db::setup_pool; use chrono::{SubsecRound, TimeDelta, Utc}; - use defguard_common::db::setup_pool; use matches::assert_matches; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; - use crate::db::Group; #[sqlx::test] async fn test_connected_at_reconnection(_: PgPoolOptions, options: PgConnectOptions) { @@ -1803,259 +1492,6 @@ mod test { assert!(devices.is_empty()); } - #[sqlx::test] - async fn test_sync_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.1.1.1/29").unwrap(); - let network = network.save(&pool).await.unwrap(); - - let user1 = User::new( - "testuser1", - Some("pass1"), - "Tester1", - "Test1", - "test1@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user2 = User::new( - "testuser2", - Some("pass2"), - "Tester2", - "Test2", - "test2@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let device1 = Device::new( - "device1".into(), - "key1".into(), - user1.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device2 = Device::new( - "device2".into(), - "key2".into(), - user1.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device3 = Device::new( - "device3".into(), - "key3".into(), - user2.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let mut transaction = pool.begin().await.unwrap(); - - // user1 sync - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user1, None) - .await - .unwrap(); - - assert_eq!(events.len(), 2); - assert!(events.iter().any(|e| match e { - GatewayEvent::DeviceCreated(info) => info.device.id == device1.id, - _ => false, - })); - assert!(events.iter().any(|e| match e { - GatewayEvent::DeviceCreated(info) => info.device.id == device2.id, - _ => false, - })); - - // user 2 sync - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user2, None) - .await - .unwrap(); - - assert_eq!(events.len(), 1); - match &events[0] { - GatewayEvent::DeviceCreated(info) => { - assert_eq!(info.device.id, device3.id); - } - _ => panic!("Expected DeviceCreated event"), - } - - // Second sync should not generate any events - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user1, None) - .await - .unwrap(); - assert_eq!(events.len(), 0); - - transaction.commit().await.unwrap(); - } - - #[sqlx::test] - async fn test_sync_allowed_devices_for_user_with_groups( - _: PgPoolOptions, - options: PgConnectOptions, - ) { - let pool = setup_pool(options).await; - let mut network = WireguardNetwork::default(); - network.try_set_address("10.1.1.1/29").unwrap(); - let network = network.save(&pool).await.unwrap(); - - let user1 = User::new( - "testuser1", - Some("pass1"), - "Tester1", - "Test1", - "test1@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user2 = User::new( - "testuser2", - Some("pass2"), - "Tester2", - "Test2", - "test2@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user3 = User::new( - "testuser3", - Some("pass3"), - "Tester3", - "Test3", - "test3@test.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let device1 = Device::new( - "device1".into(), - "key1".into(), - user1.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device2 = Device::new( - "device2".into(), - "key2".into(), - user2.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device3 = Device::new( - "device3".into(), - "key3".into(), - user3.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let group1 = Group::new("group1").save(&pool).await.unwrap(); - let group2 = Group::new("group2").save(&pool).await.unwrap(); - - let mut transaction = pool.begin().await.unwrap(); - - network - .set_allowed_groups( - &mut transaction, - vec![group1.name.clone(), group2.name.clone()], - ) - .await - .unwrap(); - - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user1, None) - .await - .unwrap(); - assert_eq!(events.len(), 0); - - user1.add_to_group(&pool, &group1).await.unwrap(); - user2.add_to_group(&pool, &group1).await.unwrap(); - user3.add_to_group(&pool, &group2).await.unwrap(); - - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user1, None) - .await - .unwrap(); - assert_eq!(events.len(), 1); - match &events[0] { - GatewayEvent::DeviceCreated(info) => { - assert_eq!(info.device.id, device1.id); - } - _ => panic!("Expected DeviceCreated event"), - } - - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user2, None) - .await - .unwrap(); - assert_eq!(events.len(), 1); - match &events[0] { - GatewayEvent::DeviceCreated(info) => { - assert_eq!(info.device.id, device2.id); - } - _ => panic!("Expected DeviceCreated event"), - } - - let events = network - .sync_allowed_devices_for_user(&mut transaction, &user3, None) - .await - .unwrap(); - assert_eq!(events.len(), 1); - match &events[0] { - GatewayEvent::DeviceCreated(info) => { - assert_eq!(info.device.id, device3.id); - } - _ => panic!("Expected DeviceCreated event"), - } - - transaction.commit().await.unwrap(); - } - #[sqlx::test] async fn test_can_assign_ips(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; @@ -2324,167 +1760,4 @@ mod test { Err(NetworkAddressError::IsBroadcastAddress(..)) ); } - - #[sqlx::test] - async fn test_get_peers_service_location_modes(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let user = User::new( - "testuser", - Some("password123"), - "Test", - "User", - "test@example.com", - None, - ) - .save(&pool) - .await - .unwrap(); - - let device1 = Device::new( - "device1".into(), - "pubkey1".into(), - user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - let device2 = Device::new( - "device2".into(), - "pubkey2".into(), - user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - // Normal location (service_location_mode = Disabled) should return peers - let mut network_normal = WireguardNetwork { - name: "normal-location".to_string(), - service_location_mode: ServiceLocationMode::Disabled, - location_mfa_mode: LocationMfaMode::Disabled, - ..Default::default() - }; - network_normal.try_set_address("10.1.1.1/24").unwrap(); - let network_normal = network_normal.save(&pool).await.unwrap(); - - WireguardNetworkDevice::new( - network_normal.id, - device1.id, - vec![IpAddr::from_str("10.1.1.2").unwrap()], - ) - .insert(&pool) - .await - .unwrap(); - - let peers_normal = network_normal.get_peers(&pool).await.unwrap(); - assert_eq!(peers_normal.len(), 1, "Normal location should return peers"); - assert_eq!(peers_normal[0].pubkey, "pubkey1"); - - // Service location with PreLogon mode returns peers when enterprise is enabled (test env default) - let mut network_prelogon = WireguardNetwork { - name: "prelogon-service-location".to_string(), - service_location_mode: ServiceLocationMode::PreLogon, - location_mfa_mode: LocationMfaMode::Disabled, - ..Default::default() - }; - network_prelogon.try_set_address("10.2.1.1/24").unwrap(); - let network_prelogon = network_prelogon.save(&pool).await.unwrap(); - - WireguardNetworkDevice::new( - network_prelogon.id, - device2.id, - vec![IpAddr::from_str("10.2.1.2").unwrap()], - ) - .insert(&pool) - .await - .unwrap(); - - // PreLogon service location should return peers when enterprise is enabled - let peers_prelogon = network_prelogon.get_peers(&pool).await.unwrap(); - assert_eq!( - peers_prelogon.len(), - 1, - "PreLogon service location should return peers when enterprise is enabled" - ); - assert_eq!(peers_prelogon[0].pubkey, "pubkey2"); - - // Service location with AlwaysOn mode also returns peers when enterprise is enabled - let mut network_alwayson = WireguardNetwork { - name: "alwayson-service-location".to_string(), - service_location_mode: ServiceLocationMode::AlwaysOn, - location_mfa_mode: LocationMfaMode::Disabled, - ..Default::default() - }; - network_alwayson.try_set_address("10.3.1.1/24").unwrap(); - let network_alwayson = network_alwayson.save(&pool).await.unwrap(); - - let device3 = Device::new( - "device3".into(), - "pubkey3".into(), - user.id, - DeviceType::User, - None, - true, - ) - .save(&pool) - .await - .unwrap(); - - WireguardNetworkDevice::new( - network_alwayson.id, - device3.id, - vec![IpAddr::from_str("10.3.1.2").unwrap()], - ) - .insert(&pool) - .await - .unwrap(); - - // AlwaysOn service location should return peers when enterprise is enabled - let peers_alwayson = network_alwayson.get_peers(&pool).await.unwrap(); - assert_eq!( - peers_alwayson.len(), - 1, - "AlwaysOn service location should return peers when enterprise is enabled" - ); - assert_eq!(peers_alwayson[0].pubkey, "pubkey3"); - - // Now test the negative case: service locations with enterprise disabled - // Exceed the enterprise limits to disable enterprise features - use crate::enterprise::limits::{Counts, DEFAULT_LOCATIONS_LIMIT, set_counts}; - let over_limit_counts = Counts::new(1, 1, DEFAULT_LOCATIONS_LIMIT + 1, 0); - set_counts(over_limit_counts); - - // Test that normal location still returns peers even without enterprise - let peers_normal_no_ent = network_normal.get_peers(&pool).await.unwrap(); - assert_eq!( - peers_normal_no_ent.len(), - 1, - "Normal location should still return peers without enterprise" - ); - - // Test that PreLogon service location returns NO peers without enterprise - let peers_prelogon_no_ent = network_prelogon.get_peers(&pool).await.unwrap(); - assert!( - peers_prelogon_no_ent.is_empty(), - "PreLogon service location should return NO peers when enterprise is disabled" - ); - - // Test that AlwaysOn service location returns NO peers without enterprise - let peers_alwayson_no_ent = network_alwayson.get_peers(&pool).await.unwrap(); - assert!( - peers_alwayson_no_ent.is_empty(), - "AlwaysOn service location should return NO peers when enterprise is disabled" - ); - - let normal_counts = Counts::new(0, 0, 0, 0); - set_counts(normal_counts); - } } diff --git a/crates/defguard_core/src/db/models/wireguard_peer_stats.rs b/crates/defguard_common/src/db/models/wireguard_peer_stats.rs similarity index 96% rename from crates/defguard_core/src/db/models/wireguard_peer_stats.rs rename to crates/defguard_common/src/db/models/wireguard_peer_stats.rs index 350f826fdd..c0eb5a04f3 100644 --- a/crates/defguard_core/src/db/models/wireguard_peer_stats.rs +++ b/crates/defguard_common/src/db/models/wireguard_peer_stats.rs @@ -1,11 +1,13 @@ use std::time::Duration; +use crate::db::{Id, NoId}; use chrono::{DateTime, NaiveDateTime, TimeDelta, Utc}; -use defguard_common::db::{Id, NoId}; use humantime::format_duration; use ipnetwork::IpNetwork; use model_derive::Model; +use serde::{Deserialize, Serialize}; use sqlx::{PgExecutor, PgPool, query, query_as, query_scalar}; +use tracing::{debug, info}; #[derive(Debug, Deserialize, Model, Serialize)] #[table(wireguard_peer_stats)] @@ -30,7 +32,7 @@ impl WireguardPeerStats { /// This is done to prevent unnecessary table growth. /// At least one record is retained for each device and network combination, /// even when older than set threshold. - pub(crate) async fn purge_old_stats( + pub async fn purge_old_stats( pool: &PgPool, stats_purge_threshold: Duration, ) -> Result<(), sqlx::Error> { @@ -110,7 +112,7 @@ impl WireguardPeerStats { } impl WireguardPeerStats { - pub(crate) async fn fetch_latest( + pub async fn fetch_latest( conn: &PgPool, device_id: Id, network_id: Id, @@ -135,7 +137,7 @@ impl WireguardPeerStats { /// Remove port part from `endpoint`. /// IPv4: a.b.c.d:p -> a.b.c.d /// IPv6: [x::y:z]:p -> x::y:z - pub(crate) fn endpoint_without_port(&self) -> Option { + pub fn endpoint_without_port(&self) -> Option { self.endpoint.as_ref().and_then(|endpoint| { let mut addr = endpoint.rsplit_once(':')?.0; // Strip square brackets. @@ -149,7 +151,7 @@ impl WireguardPeerStats { /// Returns a `Vec` of `allowed_ips` without a CIDR mask. /// Non-parsable addresses are omitted. - pub(crate) fn trim_allowed_ips(&self) -> Vec { + pub fn trim_allowed_ips(&self) -> Vec { let Some(allowed_ips) = &self.allowed_ips else { return Vec::new(); }; diff --git a/crates/defguard_core/src/db/models/yubikey.rs b/crates/defguard_common/src/db/models/yubikey.rs similarity index 94% rename from crates/defguard_core/src/db/models/yubikey.rs rename to crates/defguard_common/src/db/models/yubikey.rs index b12f4b8548..5eec85d52a 100644 --- a/crates/defguard_core/src/db/models/yubikey.rs +++ b/crates/defguard_common/src/db/models/yubikey.rs @@ -1,5 +1,6 @@ -use defguard_common::db::{Id, NoId}; +use crate::db::{Id, NoId}; use model_derive::Model; +use serde::{Deserialize, Serialize}; use sqlx::{PgExecutor, query, query_as}; #[derive(Deserialize, Model, Serialize)] diff --git a/crates/defguard_common/src/lib.rs b/crates/defguard_common/src/lib.rs index c326bfc3c3..fb2351648d 100644 --- a/crates/defguard_common/src/lib.rs +++ b/crates/defguard_common/src/lib.rs @@ -6,6 +6,11 @@ pub mod globals; pub mod hex; pub mod random; pub mod secret; +pub mod types; +pub mod utils; pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_GIT_SHA")); pub const CARGO_VERSION: &str = env!("CARGO_PKG_VERSION"); + +// WireGuard key length in bytes. +pub const KEY_LENGTH: usize = 32; diff --git a/crates/defguard_common/src/types/group_diff.rs b/crates/defguard_common/src/types/group_diff.rs new file mode 100644 index 0000000000..053937fea1 --- /dev/null +++ b/crates/defguard_common/src/types/group_diff.rs @@ -0,0 +1,14 @@ +use std::collections::HashSet; + +#[derive(Debug, Default)] +pub struct GroupDiff { + pub added: HashSet, + pub removed: HashSet, +} + +impl GroupDiff { + #[must_use] + pub fn changed(&self) -> bool { + !self.added.is_empty() || !self.removed.is_empty() + } +} diff --git a/crates/defguard_common/src/types/mod.rs b/crates/defguard_common/src/types/mod.rs new file mode 100644 index 0000000000..ff2bbead02 --- /dev/null +++ b/crates/defguard_common/src/types/mod.rs @@ -0,0 +1,2 @@ +pub mod group_diff; +pub mod user_info; diff --git a/crates/defguard_common/src/types/user_info.rs b/crates/defguard_common/src/types/user_info.rs new file mode 100644 index 0000000000..9609d5d005 --- /dev/null +++ b/crates/defguard_common/src/types/user_info.rs @@ -0,0 +1,200 @@ +use crate::{ + db::{ + Id, + models::{MFAMethod, group::Group, user::User}, + }, + types::group_diff::GroupDiff, +}; +use serde::{Deserialize, Serialize}; +use sqlx::{Error as SqlxError, PgConnection, PgPool}; +use utoipa::ToSchema; + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub struct OAuth2AuthorizedAppInfo { + pub oauth2client_id: Id, + pub user_id: Id, + pub oauth2client_name: String, +} + +// Basic user info used in user list, etc. +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub struct UserInfo { + pub id: Id, + pub username: String, + pub last_name: String, + pub first_name: String, + pub email: String, + pub phone: Option, + pub mfa_enabled: bool, + pub totp_enabled: bool, + pub email_mfa_enabled: bool, + pub groups: Vec, + pub mfa_method: MFAMethod, + pub authorized_apps: Vec, + pub is_active: bool, + pub enrolled: bool, + pub is_admin: bool, + pub ldap_pass_requires_change: bool, +} + +impl UserInfo { + pub async fn from_user(pool: &PgPool, user: &User) -> Result { + let groups = user.member_of_names(pool).await?; + let authorized_apps = user.oauth2authorizedapps(pool).await?; + + Ok(Self { + id: user.id, + username: user.username.clone(), + last_name: user.last_name.clone(), + first_name: user.first_name.clone(), + email: user.email.clone(), + phone: user.phone.clone(), + mfa_enabled: user.mfa_enabled, + totp_enabled: user.totp_enabled, + email_mfa_enabled: user.email_mfa_enabled, + groups, + mfa_method: user.mfa_method.clone(), + authorized_apps, + is_active: user.is_active, + enrolled: user.is_enrolled(), + is_admin: user.is_admin(pool).await?, + ldap_pass_requires_change: user.ldap_pass_randomized, + }) + } + + /// Copy status to [`User`]. This function should be used by administrators. + /// + /// Return `true` if status was changed, `false` otherwise. + /// If status was changed to inactive, all user sessions will be invalidated. + pub async fn handle_status_change( + &self, + transaction: &mut PgConnection, + user: &mut User, + ) -> Result { + if self.is_active == user.is_active { + Ok(false) + } else { + if !self.is_active { + user.logout_all_sessions(&mut *transaction).await?; + } + user.is_active = self.is_active; + user.save(&mut *transaction).await?; + Ok(true) + } + } + + /// Copy groups to [`User`]. This function should be used by administrators. + /// + /// Return `true` if groups were changed, `false` otherwise. + pub async fn handle_user_groups( + &self, + transaction: &mut PgConnection, + user: &mut User, + ) -> Result { + // initialize return value + let mut group_diff = GroupDiff::default(); + + // handle groups + let mut present_groups = user.member_of(&mut *transaction).await?; + + // add to groups if not already a member + for groupname in &self.groups { + match present_groups + .iter() + .position(|group| &group.name == groupname) + { + Some(index) => { + present_groups.swap_remove(index); + } + None => { + if let Some(group) = Group::find_by_name(&mut *transaction, groupname).await? { + user.add_to_group(&mut *transaction, &group).await?; + group_diff.added.insert(group.name); + } + } + } + } + + // remove from remaining groups + for group in present_groups { + user.remove_from_group(&mut *transaction, &group).await?; + group_diff.removed.insert(group.name); + } + + Ok(group_diff) + } + + /// Copy fields to [`User`]. This function is safe to call by a non-admin user. + pub fn into_user_safe_fields(self, user: &mut User) -> Result<(), SqlxError> { + user.phone = self.phone; + user.mfa_method = self.mfa_method; + + Ok(()) + } + + /// Copy fields to [`User`]. This function should be used by administrators. + pub fn into_user_all_fields(self, user: &mut User) -> Result<(), SqlxError> { + user.phone = self.phone; + user.username = self.username; + user.last_name = self.last_name; + user.first_name = self.first_name; + user.email = self.email; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use crate::db::setup_pool; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + + use super::*; + + #[sqlx::test] + async fn test_user_info(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user = User::new( + "hpotter", + Some("pass123"), + "Potter", + "Harry", + "h.potter@hogwart.edu.uk", + None, + ) + .save(&pool) + .await + .unwrap(); + + let group1 = Group::new("Gryffindor").save(&pool).await.unwrap(); + let group2 = Group::new("Hufflepuff").save(&pool).await.unwrap(); + let group3 = Group::new("Ravenclaw").save(&pool).await.unwrap(); + let group4 = Group::new("Slytherin").save(&pool).await.unwrap(); + + user.add_to_group(&pool, &group1).await.unwrap(); + user.add_to_group(&pool, &group2).await.unwrap(); + + let mut user_info = UserInfo::from_user(&pool, &user).await.unwrap(); + assert_eq!(user_info.groups, ["Gryffindor", "Hufflepuff"]); + + user_info.groups = vec!["Gryffindor".into(), "Ravenclaw".into()]; + let mut user = User::find_by_username(&pool, "hpotter") + .await + .unwrap() + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + user_info + .handle_user_groups(&mut transaction, &mut user) + .await + .unwrap(); + user_info.into_user_all_fields(&mut user).unwrap(); + transaction.commit().await.unwrap(); + + assert_eq!(group1.member_usernames(&pool).await.unwrap(), ["hpotter"]); + assert_eq!(group3.member_usernames(&pool).await.unwrap(), ["hpotter"]); + assert!(group2.member_usernames(&pool).await.unwrap().is_empty()); + assert!(group4.member_usernames(&pool).await.unwrap().is_empty()); + } +} diff --git a/crates/defguard_common/src/utils.rs b/crates/defguard_common/src/utils.rs new file mode 100644 index 0000000000..5dc46c1911 --- /dev/null +++ b/crates/defguard_common/src/utils.rs @@ -0,0 +1,23 @@ +use ipnetwork::IpNetwork; + +/// Parse a string with comma-separated IP addresses. +/// Invalid addresses will be silently ignored. +pub fn parse_address_list(ips: &str) -> Vec { + ips.split(',') + .filter_map(|ip| ip.trim().parse().ok()) + .collect() +} + +/// Parse a string with comma-separated IP network addresses. +/// Host bits will be stripped. +/// Invalid addresses will be silently ignored. +pub fn parse_network_address_list(ips: &str) -> Vec { + ips.split(',') + .filter_map(|ip| ip.trim().parse().ok()) + .filter_map(|ip: IpNetwork| { + let network_address = ip.network(); + let network_mask = ip.mask(); + IpNetwork::with_netmask(network_address, network_mask).ok() + }) + .collect() +} diff --git a/crates/defguard_core/src/appstate.rs b/crates/defguard_core/src/appstate.rs index 7483b0b725..10387ade90 100644 --- a/crates/defguard_core/src/appstate.rs +++ b/crates/defguard_core/src/appstate.rs @@ -19,10 +19,10 @@ use webauthn_rs::prelude::*; use crate::{ auth::failed_login::FailedLoginMap, - db::{AppEvent, GatewayEvent, WebHook}, + db::{AppEvent, WebHook}, error::WebError, events::ApiEvent, - grpc::gateway::{send_multiple_wireguard_events, send_wireguard_event}, + grpc::gateway::{events::GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, version::IncompatibleComponents, }; diff --git a/crates/defguard_core/src/auth/mod.rs b/crates/defguard_core/src/auth/mod.rs index 462f904fa6..08384886f4 100644 --- a/crates/defguard_core/src/auth/mod.rs +++ b/crates/defguard_core/src/auth/mod.rs @@ -10,24 +10,26 @@ use axum_extra::{ extract::cookie::CookieJar, headers::{Authorization, authorization::Bearer}, }; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{ + OAuth2Token, Session, SessionState, + group::{Group, Permission}, + oauth2client::OAuth2Client, + user::User, + }, +}; use crate::{ appstate::AppState, - db::{ - Group, OAuth2Token, Session, SessionState, User, - models::{group::Permission, oauth2client::OAuth2Client}, - }, enterprise::{db::models::api_tokens::ApiToken, is_enterprise_enabled}, error::WebError, handlers::SESSION_COOKIE_NAME, }; -pub const TOTP_CODE_VALIDITY_PERIOD: u64 = 30; -pub const EMAIL_CODE_DIGITS: u32 = 6; -pub const TOTP_CODE_DIGITS: u32 = 6; +pub struct SessionExtractor(pub Session); -impl FromRequestParts for Session +impl FromRequestParts for SessionExtractor where S: Send + Sync, AppState: FromRef, @@ -59,12 +61,12 @@ where error!("Failed to get client IP: {err:?}"); WebError::ClientIpError })?; - Ok(Session::new( + Ok(Self(Session::new( api_token.user_id, SessionState::ApiTokenVerified, ip_address.0.to_string(), None, - )) + ))) } Ok(None) => Err(WebError::Authorization("Invalid API token".into())), Err(err) => Err(err.into()), @@ -81,7 +83,7 @@ where let _result = session.delete(&appstate.pool).await; Err(WebError::Authorization("Session expired".into())) } else { - Ok(session) + Ok(Self(session)) } } Ok(None) => Err(WebError::Authorization("Session not found".into())), @@ -130,7 +132,7 @@ where type Rejection = WebError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let session = Session::from_request_parts(parts, state).await?; + let session = SessionExtractor::from_request_parts(parts, state).await?.0; let appstate = AppState::from_ref(state); let user = User::find_by_id(&appstate.pool, session.user_id).await?; diff --git a/crates/defguard_core/src/db/mod.rs b/crates/defguard_core/src/db/mod.rs index 37bcc961f0..7fe635229c 100644 --- a/crates/defguard_core/src/db/mod.rs +++ b/crates/defguard_core/src/db/mod.rs @@ -1,15 +1,3 @@ pub mod models; -pub use models::{ - MFAInfo, UserDetails, UserInfo, - device::{AddDevice, Device}, - group::Group, - oauth2authorizedapp::OAuth2AuthorizedApp, - oauth2token::OAuth2Token, - session::{Session, SessionState}, - user::User, - webauthn::WebAuthn, - webhook::{AppEvent, HWKeyUserData, WebHook}, - wireguard::{GatewayEvent, WireguardNetwork}, - yubikey::YubiKey, -}; +pub use models::webhook::{AppEvent, HWKeyUserData, WebHook}; diff --git a/crates/defguard_core/src/db/models/activity_log/metadata.rs b/crates/defguard_core/src/db/models/activity_log/metadata.rs index e358560fe2..582897cc30 100644 --- a/crates/defguard_core/src/db/models/activity_log/metadata.rs +++ b/crates/defguard_core/src/db/models/activity_log/metadata.rs @@ -2,16 +2,17 @@ use chrono::NaiveDateTime; use defguard_common::db::{ Id, models::{ - AuthenticationKey, AuthenticationKeyType, MFAMethod, Settings, + AuthenticationKey, AuthenticationKeyType, Device, MFAMethod, Settings, WebAuthn, + WireguardNetwork, + group::Group, + oauth2client::OAuth2Client, settings::{LdapSyncStatus, OpenidUsernameHandling, SmtpEncryption}, + user::User, }, }; use crate::{ - db::{ - Device, Group, User, WebAuthn, WebHook, WireguardNetwork, - models::oauth2client::OAuth2Client, - }, + db::WebHook, enterprise::db::models::{ activity_log_stream::{ActivityLogStream, ActivityLogStreamType}, api_tokens::ApiToken, diff --git a/crates/defguard_core/src/db/models/enrollment.rs b/crates/defguard_core/src/db/models/enrollment.rs index e0e0824d2e..de33f97f51 100644 --- a/crates/defguard_core/src/db/models/enrollment.rs +++ b/crates/defguard_core/src/db/models/enrollment.rs @@ -2,28 +2,21 @@ use chrono::{NaiveDateTime, TimeDelta, Utc}; use defguard_common::{ VERSION, config::server_config, - db::{Id, models::Settings}, + db::{ + Id, + models::{Settings, user::User}, + }, random::gen_alphanumeric, }; -use defguard_mail::{ - Mail, - templates::{self, TemplateError, safe_tera}, -}; -use reqwest::Url; +use defguard_mail::templates::{self, TemplateError, safe_tera}; use sqlx::{Error as SqlxError, PgConnection, PgExecutor, PgPool, query, query_as}; use tera::Context; use thiserror::Error; -use tokio::sync::mpsc::UnboundedSender; use tonic::{Code, Status}; -use super::User; - pub static ENROLLMENT_TOKEN_TYPE: &str = "ENROLLMENT"; pub static PASSWORD_RESET_TOKEN_TYPE: &str = "PASSWORD_RESET"; -static ENROLLMENT_START_MAIL_SUBJECT: &str = "Defguard user enrollment"; -static DESKTOP_START_MAIL_SUBJECT: &str = "Defguard desktop client configuration"; - #[derive(Error, Debug)] pub enum TokenError { #[error(transparent)] @@ -319,7 +312,7 @@ impl Token { /// - admin_last_name /// - admin_email /// - admin_phone - async fn get_welcome_message_context( + pub async fn get_welcome_message_context( &self, transaction: &mut PgConnection, ) -> Result { @@ -389,240 +382,6 @@ impl Token { } } -impl User { - /// Start user enrollment process - /// This creates a new enrollment token valid for 24h - /// and optionally sends enrollment email notification to user - pub async fn start_enrollment( - &mut self, - transaction: &mut PgConnection, - admin: &User, - email: Option, - token_timeout_seconds: u64, - enrollment_service_url: Url, - send_user_notification: bool, - mail_tx: UnboundedSender, - ) -> Result { - info!( - "User {} started a new enrollment process for user {}.", - admin.username, self.username - ); - debug!( - "Notify user by mail about the enrollment process: {}", - send_user_notification - ); - debug!("Check if {} has a password.", self.username); - if self.has_password() { - debug!( - "User {} that you want to start enrollment process for already has a password.", - self.username - ); - return Err(TokenError::AlreadyActive); - } - - debug!("Verify that {} is an active user.", self.username); - if !self.is_active { - warn!( - "Can't create enrollment token for disabled user {}", - self.username - ); - return Err(TokenError::UserDisabled); - } - - self.clear_unused_enrollment_tokens(&mut *transaction) - .await?; - - debug!("Create a new enrollment token for user {}.", self.username); - let enrollment = Token::new( - self.id, - Some(admin.id), - email.clone(), - token_timeout_seconds, - Some(ENROLLMENT_TOKEN_TYPE.to_string()), - ); - debug!("Saving a new enrollment token..."); - enrollment.save(&mut *transaction).await?; - debug!( - "Saved a new enrollment token with id {} for user {}.", - enrollment.id, self.username - ); - - // Mark the user with enrollment-pending flag. - // https://github.com/DefGuard/client/issues/647 - self.enrollment_pending = true; - self.save(&mut *transaction).await?; - - if send_user_notification { - if let Some(email) = email { - debug!( - "Sending an enrollment mail for user {} to {email}.", - self.username - ); - let base_message_context = enrollment - .get_welcome_message_context(&mut *transaction) - .await?; - let mail = Mail { - to: email.clone(), - subject: ENROLLMENT_START_MAIL_SUBJECT.to_string(), - content: templates::enrollment_start_mail( - base_message_context, - enrollment_service_url, - &enrollment.id, - ) - .map_err(|err| { - debug!( - "Cannot send an email to the user {} due to the error {}.", - self.username, - err.to_string() - ); - TokenError::NotificationError(err.to_string()) - })?, - attachments: Vec::new(), - result_tx: None, - }; - match mail_tx.send(mail) { - Ok(()) => { - info!( - "Sent enrollment start mail for user {} to {email}", - self.username - ); - } - Err(err) => { - error!("Error sending mail: {err}"); - return Err(TokenError::NotificationError(err.to_string())); - } - } - } - } - info!( - "New enrollment token has been generated for {}.", - self.username - ); - - Ok(enrollment.id) - } - - /// Start user remote desktop configuration process - /// This creates a new enrollment token valid for 24h - /// and optionally sends email notification to user - pub async fn start_remote_desktop_configuration( - &self, - transaction: &mut PgConnection, - admin: &User, - email: Option, - token_timeout_seconds: u64, - enrollment_service_url: Url, - send_user_notification: bool, - mail_tx: UnboundedSender, - // Whether to attach some device to the token. It allows for a partial initialization of - // the device before the desktop configuration has taken place. - device_id: Option, - ) -> Result { - info!( - "User {} starting a new desktop activation for user {}", - admin.username, self.username - ); - debug!( - "Notify {} by mail about the enrollment process: {}", - self.username, send_user_notification - ); - - debug!("Verify that {} is an active user.", self.username); - if !self.is_active { - warn!( - "Can't create desktop activation token for disabled user {}.", - self.username - ); - return Err(TokenError::UserDisabled); - } - - self.clear_unused_enrollment_tokens(&mut *transaction) - .await?; - debug!("Cleared unused tokens for {}.", self.username); - - debug!( - "Create a new desktop activation token for user {}.", - self.username - ); - let mut desktop_configuration = Token::new( - self.id, - Some(admin.id), - email.clone(), - token_timeout_seconds, - Some(ENROLLMENT_TOKEN_TYPE.to_string()), - ); - if let Some(device_id) = device_id { - desktop_configuration.device_id = Some(device_id); - } - debug!("Saving a new desktop configuration token..."); - desktop_configuration.save(&mut *transaction).await?; - debug!( - "Saved a new desktop activation token with id {} for user {}.", - desktop_configuration.id, self.username - ); - - if send_user_notification { - if let Some(email) = email { - debug!( - "Sending a desktop configuration mail for user {} to {email}", - self.username - ); - let base_message_context = desktop_configuration - .get_welcome_message_context(&mut *transaction) - .await?; - let mail = Mail { - to: email.clone(), - subject: DESKTOP_START_MAIL_SUBJECT.to_string(), - content: templates::desktop_start_mail( - base_message_context, - &enrollment_service_url, - &desktop_configuration.id, - ) - .map_err(|err| { - debug!( - "Cannot send an email to the user {} due to the error {}.", - self.username, - err.to_string() - ); - TokenError::NotificationError(err.to_string()) - })?, - attachments: Vec::new(), - result_tx: None, - }; - match mail_tx.send(mail) { - Ok(()) => { - info!( - "Sent desktop configuration start mail for user {} to {email}", - self.username - ); - } - Err(err) => { - error!("Error sending mail: {err}"); - } - } - } - } - info!( - "New desktop activation token has been generated for {}.", - self.username - ); - - Ok(desktop_configuration.id) - } - - // Remove unused tokens when triggering user enrollment - pub(crate) async fn clear_unused_enrollment_tokens<'e, E>( - &self, - executor: E, - ) -> Result<(), TokenError> - where - E: PgExecutor<'e>, - { - info!("Removing unused tokens for user {}.", self.username); - Token::delete_unused_user_tokens(executor, self.id).await - } -} - pub fn enrollment_welcome_message(settings: &Settings) -> Result { settings.enrollment_welcome_message.clone().ok_or_else(|| { error!("Enrollment welcome message not configured"); diff --git a/crates/defguard_core/src/db/models/mod.rs b/crates/defguard_core/src/db/models/mod.rs index df2faac41b..b6efe1b578 100644 --- a/crates/defguard_core/src/db/models/mod.rs +++ b/crates/defguard_core/src/db/models/mod.rs @@ -1,326 +1,3 @@ pub mod activity_log; -pub mod device; pub mod enrollment; -pub mod group; -pub mod oauth2authorizedapp; -pub mod oauth2client; -pub mod oauth2token; -pub mod polling_token; -pub mod session; -pub mod user; -pub mod webauthn; pub mod webhook; -pub mod wireguard; -pub mod wireguard_peer_stats; -pub mod yubikey; - -use std::collections::HashSet; - -use defguard_common::db::{ - Id, - models::{BiometricAuth, MFAMethod}, -}; -use sqlx::{Error as SqlxError, PgConnection, PgPool, query_as}; -use utoipa::ToSchema; - -use self::{device::UserDevice, user::User}; -use super::Group; - -#[derive(Deserialize, Serialize)] -pub struct NewOpenIDClient { - pub name: String, - pub redirect_uri: Vec, - pub scope: Vec, - pub enabled: bool, -} - -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] -pub struct OAuth2AuthorizedAppInfo { - pub oauth2client_id: Id, - pub user_id: Id, - pub oauth2client_name: String, -} - -/// Only `id` and `name` from [`WebAuthn`]. -#[derive(Debug, Deserialize, Serialize, ToSchema)] -pub struct SecurityKey { - pub id: Id, - pub name: String, -} - -// Basic user info used in user list, etc. -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] -pub struct UserInfo { - pub id: Id, - pub username: String, - pub last_name: String, - pub first_name: String, - pub email: String, - pub phone: Option, - pub mfa_enabled: bool, - pub totp_enabled: bool, - pub email_mfa_enabled: bool, - pub groups: Vec, - pub mfa_method: MFAMethod, - pub authorized_apps: Vec, - pub is_active: bool, - pub enrolled: bool, - pub is_admin: bool, - pub ldap_pass_requires_change: bool, -} - -#[derive(Debug, Default)] -pub struct GroupDiff { - pub added: HashSet, - pub removed: HashSet, -} - -impl GroupDiff { - #[must_use] - pub fn changed(&self) -> bool { - !self.added.is_empty() || !self.removed.is_empty() - } -} - -impl UserInfo { - pub async fn from_user(pool: &PgPool, user: &User) -> Result { - let groups = user.member_of_names(pool).await?; - let authorized_apps = user.oauth2authorizedapps(pool).await?; - - Ok(Self { - id: user.id, - username: user.username.clone(), - last_name: user.last_name.clone(), - first_name: user.first_name.clone(), - email: user.email.clone(), - phone: user.phone.clone(), - mfa_enabled: user.mfa_enabled, - totp_enabled: user.totp_enabled, - email_mfa_enabled: user.email_mfa_enabled, - groups, - mfa_method: user.mfa_method.clone(), - authorized_apps, - is_active: user.is_active, - enrolled: user.is_enrolled(), - is_admin: user.is_admin(pool).await?, - ldap_pass_requires_change: user.ldap_pass_randomized, - }) - } - - /// Copy status to [`User`]. This function should be used by administrators. - /// - /// Return `true` if status was changed, `false` otherwise. - /// If status was changed to inactive, all user sessions will be invalidated. - pub(crate) async fn handle_status_change( - &self, - transaction: &mut PgConnection, - user: &mut User, - ) -> Result { - if self.is_active == user.is_active { - Ok(false) - } else { - if !self.is_active { - user.logout_all_sessions(&mut *transaction).await?; - } - user.is_active = self.is_active; - user.save(&mut *transaction).await?; - Ok(true) - } - } - - /// Copy groups to [`User`]. This function should be used by administrators. - /// - /// Return `true` if groups were changed, `false` otherwise. - pub(crate) async fn handle_user_groups( - &self, - transaction: &mut PgConnection, - user: &mut User, - ) -> Result { - // initialize return value - let mut group_diff = GroupDiff::default(); - - // handle groups - let mut present_groups = user.member_of(&mut *transaction).await?; - - // add to groups if not already a member - for groupname in &self.groups { - match present_groups - .iter() - .position(|group| &group.name == groupname) - { - Some(index) => { - present_groups.swap_remove(index); - } - None => { - if let Some(group) = Group::find_by_name(&mut *transaction, groupname).await? { - user.add_to_group(&mut *transaction, &group).await?; - group_diff.added.insert(group.name); - } - } - } - } - - // remove from remaining groups - for group in present_groups { - user.remove_from_group(&mut *transaction, &group).await?; - group_diff.removed.insert(group.name); - } - - Ok(group_diff) - } - - /// Copy fields to [`User`]. This function is safe to call by a non-admin user. - pub fn into_user_safe_fields(self, user: &mut User) -> Result<(), SqlxError> { - user.phone = self.phone; - user.mfa_method = self.mfa_method; - - Ok(()) - } - - /// Copy fields to [`User`]. This function should be used by administrators. - pub fn into_user_all_fields(self, user: &mut User) -> Result<(), SqlxError> { - user.phone = self.phone; - user.username = self.username; - user.last_name = self.last_name; - user.first_name = self.first_name; - user.email = self.email; - - Ok(()) - } -} - -// Full user info with related objects -#[derive(Deserialize, Serialize, Debug, ToSchema)] -pub struct UserDetails { - pub user: UserInfo, - #[serde(default)] - pub devices: Vec, - pub biometric_enabled_devices: Vec, - #[serde(default)] - pub security_keys: Vec, -} - -impl UserDetails { - pub async fn from_user(pool: &PgPool, user: &User) -> Result { - let devices = user.user_devices(pool).await?; - let security_keys = user.security_keys(pool).await?; - let biometric_enabled_devices = BiometricAuth::find_by_user_id(pool, user.id) - .await? - .iter() - .map(|a| a.device_id) - .collect::>(); - Ok(Self { - user: UserInfo::from_user(pool, user).await?, - devices, - security_keys, - biometric_enabled_devices, - }) - } -} - -#[derive(Deserialize, Serialize)] -pub struct MFAInfo { - mfa_method: MFAMethod, - totp_available: bool, - webauthn_available: bool, - email_available: bool, -} - -impl MFAInfo { - pub async fn for_user(pool: &PgPool, user: &User) -> Result, SqlxError> { - query_as!( - Self, - "SELECT mfa_method \"mfa_method: _\", totp_enabled totp_available, \ - email_mfa_enabled email_available, \ - (SELECT count(*) > 0 FROM webauthn WHERE user_id = $1) \"webauthn_available!\" \ - FROM \"user\" WHERE \"user\".id = $1", - user.id - ) - .fetch_optional(pool) - .await - } - - #[must_use] - pub fn mfa_available(&self) -> bool { - self.webauthn_available || self.totp_available || self.email_available - } - - #[must_use] - pub fn current_mfa_method(&self) -> &MFAMethod { - &self.mfa_method - } - - #[must_use] - pub fn list_available_methods(&self) -> Option> { - if !self.mfa_available() { - return None; - } - - let mut methods = Vec::new(); - if self.webauthn_available { - methods.push(MFAMethod::Webauthn); - } - if self.totp_available { - methods.push(MFAMethod::OneTimePassword); - } - if self.email_available { - methods.push(MFAMethod::Email); - } - Some(methods) - } -} - -#[cfg(test)] -mod test { - use defguard_common::db::setup_pool; - use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; - - use super::*; - - #[sqlx::test] - async fn test_user_info(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let group1 = Group::new("Gryffindor").save(&pool).await.unwrap(); - let group2 = Group::new("Hufflepuff").save(&pool).await.unwrap(); - let group3 = Group::new("Ravenclaw").save(&pool).await.unwrap(); - let group4 = Group::new("Slytherin").save(&pool).await.unwrap(); - - user.add_to_group(&pool, &group1).await.unwrap(); - user.add_to_group(&pool, &group2).await.unwrap(); - - let mut user_info = UserInfo::from_user(&pool, &user).await.unwrap(); - assert_eq!(user_info.groups, ["Gryffindor", "Hufflepuff"]); - - user_info.groups = vec!["Gryffindor".into(), "Ravenclaw".into()]; - let mut user = User::find_by_username(&pool, "hpotter") - .await - .unwrap() - .unwrap(); - - let mut transaction = pool.begin().await.unwrap(); - user_info - .handle_user_groups(&mut transaction, &mut user) - .await - .unwrap(); - user_info.into_user_all_fields(&mut user).unwrap(); - transaction.commit().await.unwrap(); - - assert_eq!(group1.member_usernames(&pool).await.unwrap(), ["hpotter"]); - assert_eq!(group3.member_usernames(&pool).await.unwrap(), ["hpotter"]); - assert!(group2.member_usernames(&pool).await.unwrap().is_empty()); - assert!(group4.member_usernames(&pool).await.unwrap().is_empty()); - } -} diff --git a/crates/defguard_core/src/db/models/user.rs b/crates/defguard_core/src/db/models/user.rs deleted file mode 100644 index f545ed850d..0000000000 --- a/crates/defguard_core/src/db/models/user.rs +++ /dev/null @@ -1,1701 +0,0 @@ -use std::{collections::HashSet, fmt, time::SystemTime}; - -use argon2::{ - Argon2, - password_hash::{ - PasswordHash, PasswordHasher, PasswordVerifier, SaltString, errors::Error as HashError, - rand_core::OsRng, - }, -}; -use axum::http::StatusCode; -use defguard_common::{ - config::server_config, - db::{Id, NoId, models::MFAMethod}, - random::{gen_alphanumeric, gen_totp_secret}, -}; -use defguard_mail::templates::UserContext; -use model_derive::Model; -#[cfg(test)] -use rand::{ - Rng, - distributions::{Alphanumeric, DistString, Standard}, - prelude::Distribution, -}; -use serde::Serialize; -use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, query, query_as, query_scalar, -}; -use tokio::sync::broadcast::Sender; -use totp_lite::{Sha1, totp_custom}; - -use super::{ - MFAInfo, OAuth2AuthorizedAppInfo, SecurityKey, - device::{Device, DeviceInfo, DeviceType, UserDevice}, - group::Group, - webauthn::WebAuthn, -}; -use crate::{ - auth::{EMAIL_CODE_DIGITS, TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD}, - db::{GatewayEvent, Session, WireguardNetwork, models::group::Permission}, - enterprise::limits::update_counts, - error::WebError, - grpc::gateway::{send_multiple_wireguard_events, send_wireguard_event}, -}; - -const RECOVERY_CODES_COUNT: usize = 8; - -// User information ready to be sent as part of diagnostic data. -#[derive(Serialize)] -pub struct UserDiagnostic { - pub id: Id, - pub mfa_enabled: bool, - pub totp_enabled: bool, - pub email_mfa_enabled: bool, - pub mfa_method: MFAMethod, - pub is_active: bool, - pub enrolled: bool, -} - -#[derive(Clone, Model, PartialEq, Eq, Hash, Serialize, FromRow)] -pub struct User { - pub id: I, - pub username: String, - pub(crate) password_hash: Option, - pub last_name: String, - pub first_name: String, - pub email: String, - pub phone: Option, - pub mfa_enabled: bool, - pub is_active: bool, - /// Indicates whether the user has been created via the LDAP integration. - pub from_ldap: bool, - /// Indicates whether a user has a random password set in LDAP, if so, the user - /// will be prompted to change it on their profile page. - /// - /// The random password is set if we are creating a new user in LDAP from a Defguard user - /// and we don't have access to the plain text password, e.g. during Defguard -> LDAP user import. - pub ldap_pass_randomized: bool, - /// The user's LDAP RDN value. This is the first part of the DN. - /// For example, if the DN is `cn=John Doe,ou=users,dc=example,dc=com`, - /// the RDN is `cn=John Doe`. - /// This is used to identify the user in LDAP as we sometimes can't use the Defguard's username - /// since the RDN may contain spaces or other special characters and the username may not. - pub ldap_rdn: Option, - /// Rest of the user's DN - pub ldap_user_path: Option, - /// The user's sub claim returned by the OpenID provider. Also indicates whether the user has - /// used OpenID to log in. - // FIXME: must be unique - pub openid_sub: Option, - // secret has been verified and TOTP can be used - pub(crate) totp_enabled: bool, - pub(crate) email_mfa_enabled: bool, - pub(crate) totp_secret: Option>, - pub(crate) email_mfa_secret: Option>, - #[model(enum)] - pub(crate) mfa_method: MFAMethod, - #[model(ref)] - pub(crate) recovery_codes: Vec, - /// Indicates that an administrator has requested an enrollment token for this user. - /// Uninitialized clients should then guide the user through enrollment process. - /// Related issue: https://github.com/DefGuard/client/issues/647. - pub enrollment_pending: bool, -} - -// TODO: Refactor the user struct to use SecretStringWrapper instead of this -impl fmt::Debug for User { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { - id, - username, - password_hash: _, - last_name, - first_name, - email, - phone, - mfa_enabled, - is_active, - from_ldap, - ldap_pass_randomized, - ldap_rdn, - ldap_user_path, - openid_sub, - totp_enabled, - email_mfa_enabled, - totp_secret: _, - email_mfa_secret: _, - mfa_method, - recovery_codes, - enrollment_pending, - } = self; - - f.debug_struct("User") - .field("id", id) - .field("username", username) - .field("last_name", last_name) - .field("first_name", first_name) - .field("email", email) - .field("phone", phone) - .field("mfa_enabled", mfa_enabled) - .field("is_active", is_active) - .field("from_ldap", from_ldap) - .field("ldap_pass_randomized", ldap_pass_randomized) - .field("ldap_rdn", ldap_rdn) - .field("ldap_user_path", ldap_user_path) // sensitive data - .field("openid_sub", openid_sub) - .field("totp_enabled", totp_enabled) - .field("email_mfa_enabled", email_mfa_enabled) - .field("mfa_method", mfa_method) - .field( - "recovery_codes", - &format_args!("{} items", recovery_codes.len()), - ) - .field("password_hash", &"***") - .field("totp_secret", &"***") - .field("email_mfa_secret", &"***") - .field("enrollment_pending", enrollment_pending) - .finish() - } -} - -fn hash_password(password: &str) -> Result { - let salt = SaltString::generate(&mut OsRng); - Ok(Argon2::default() - .hash_password(password.as_bytes(), &salt)? - .to_string()) -} - -impl From> for UserContext { - fn from(value: User) -> Self { - Self { - last_name: value.last_name, - first_name: value.first_name, - } - } -} - -impl User { - #[must_use] - pub fn new>( - username: S, - password: Option<&str>, - last_name: S, - first_name: S, - email: S, - phone: Option, - ) -> Self { - let password_hash = password.and_then(|password_hash| hash_password(password_hash).ok()); - let username: String = username.into(); - Self { - id: NoId, - username: username.clone(), - password_hash, - last_name: last_name.into(), - first_name: first_name.into(), - email: email.into(), - phone, - mfa_enabled: false, - totp_enabled: false, - email_mfa_enabled: false, - totp_secret: None, - email_mfa_secret: None, - mfa_method: MFAMethod::None, - recovery_codes: Vec::new(), - is_active: true, - openid_sub: None, - from_ldap: false, - ldap_pass_randomized: false, - ldap_rdn: Some(username.clone()), - ldap_user_path: None, - enrollment_pending: false, - } - } -} - -impl fmt::Display for User { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.username) - } -} - -impl User { - pub fn set_password(&mut self, password: &str) { - self.password_hash = hash_password(password).ok(); - } - - pub(crate) fn verify_password(&self, password: &str) -> Result<(), HashError> { - debug!("Checking if password matches for user {}", self.username); - if let Some(hash) = &self.password_hash { - let parsed_hash = PasswordHash::new(hash)?; - Argon2::default().verify_password(password.as_bytes(), &parsed_hash) - } else { - info!("User {} has no password set", self.username); - Err(HashError::Password) - } - } - - #[must_use] - pub(crate) fn has_password(&self) -> bool { - self.password_hash.is_some() - } - - #[must_use] - pub(crate) fn name(&self) -> String { - format!("{} {}", self.first_name, self.last_name) - } - - /// Determines whether the user is considered enrolled. - /// - /// A user is treated as enrolled if: - /// - The `enrollment_pending` flag is **not** set, i.e. enrollment was not requested by an - /// administrator (https://github.com/DefGuard/client/issues/647). - /// - They either have a password configured, have authenticated via an external OIDC provider - /// or were synced from LDAP. - #[must_use] - pub fn is_enrolled(&self) -> bool { - !self.enrollment_pending - && (self.password_hash.is_some() || self.openid_sub.is_some() || self.from_ldap) - } - - #[must_use] - pub(crate) fn ldap_rdn_value(&self) -> &str { - if let Some(ldap_rdn) = &self.ldap_rdn { - ldap_rdn - } else { - warn!( - "LDAP RDN is not set for user {}. Using username as a fallback.", - self.username - ); - &self.username - } - } -} - -impl User { - /// Generate new TOTP secret, save it, then return it as RFC 4648 base32-encoded string. - pub async fn new_totp_secret<'e, E>(&mut self, executor: E) -> Result - where - E: PgExecutor<'e>, - { - let secret = gen_totp_secret(); - query!( - "UPDATE \"user\" SET totp_secret = $1 WHERE id = $2", - secret, - self.id - ) - .execute(executor) - .await?; - - let secret_base32 = base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &secret); - self.totp_secret = Some(secret); - Ok(secret_base32) - } - - /// Generate new email secret, similar to TOTP secret above, but don't return generated value. - pub async fn new_email_secret<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - let email_secret = gen_totp_secret(); - query!( - "UPDATE \"user\" SET email_mfa_secret = $1 WHERE id = $2", - email_secret, - self.id - ) - .execute(executor) - .await?; - - self.email_mfa_secret = Some(email_secret); - - Ok(()) - } - - pub async fn set_mfa_method<'e, E>( - &mut self, - executor: E, - mfa_method: MFAMethod, - ) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - info!( - "Setting MFA method for user {} to {mfa_method:?}", - self.username - ); - query!( - "UPDATE \"user\" SET mfa_method = $2 WHERE id = $1", - self.id, - &mfa_method as &MFAMethod - ) - .execute(executor) - .await?; - self.mfa_method = mfa_method; - - Ok(()) - } - - /// Check if any of the multi-factor authentication methods is on. - /// - TOTP is enabled - /// - a security key for Webauthn - async fn check_mfa_enabled<'e, E>(&self, executor: E) -> Result - where - E: PgExecutor<'e>, - { - // short-cut - if self.totp_enabled || self.email_mfa_enabled { - return Ok(true); - } - - query_scalar!( - "SELECT totp_enabled OR email_mfa_enabled \ - OR count(webauthn.id) > 0 \"bool!\" FROM \"user\" \ - LEFT JOIN webauthn ON webauthn.user_id = \"user\".id \ - WHERE \"user\".id = $1 GROUP BY totp_enabled, email_mfa_enabled;", - self.id - ) - .fetch_one(executor) - .await - } - - /// Verify the state of MFA flags are correct. - /// Recovers from invalid mfa_method - /// Use this function after removing any of the authentication factors. - pub async fn verify_mfa_state(&mut self, pool: &PgPool) -> Result<(), WebError> { - if let Some(info) = MFAInfo::for_user(pool, self).await? { - let factors_present = info.mfa_available(); - if self.mfa_enabled != factors_present { - // store correct value for MFA flag in the DB - if self.mfa_enabled { - // last factor was removed so we have to disable MFA - self.disable_mfa(pool).await?; - } else { - // first factor was added so MFA needs to be enabled - query!( - "UPDATE \"user\" SET mfa_enabled = $2 WHERE id = $1", - self.id, - factors_present - ) - .execute(pool) - .await?; - } - - if !factors_present && self.mfa_method != MFAMethod::None { - debug!( - "MFA for user {} disabled, updating MFA method to None", - self.username - ); - self.set_mfa_method(pool, MFAMethod::None).await?; - } - - self.mfa_enabled = factors_present; - } - - // set correct value for default method - if factors_present { - match info.list_available_methods() { - None => { - error!("Incorrect MFA info state for user {}", self.username); - return Err(WebError::Http(StatusCode::INTERNAL_SERVER_ERROR)); - } - Some(methods) => { - info!( - "Checking if {:?} in in available methods {methods:?}, {}", - info.mfa_method, - methods.contains(&info.mfa_method) - ); - if !methods.contains(&info.mfa_method) { - // FIXME: do not panic - self.set_mfa_method(pool, methods.into_iter().next().unwrap()) - .await?; - } - } - } - } - } - Ok(()) - } - - /// Disable user, log out all his sessions and update gateways state. - pub async fn disable( - &mut self, - conn: &mut PgConnection, - wg_tx: &Sender, - ) -> Result<(), WebError> { - self.is_active = false; - self.save(&mut *conn).await?; - self.logout_all_sessions(&mut *conn).await?; - self.sync_allowed_devices(conn, wg_tx).await?; - Ok(()) - } - - /// Update gateway state based on this user device access rights - pub async fn sync_allowed_devices( - &self, - conn: &mut PgConnection, - wg_tx: &Sender, - ) -> Result<(), WebError> { - debug!("Syncing allowed devices of user {}", self.username); - let networks = WireguardNetwork::all(&mut *conn).await?; - for network in networks { - let gateway_events = network - .sync_allowed_devices_for_user(&mut *conn, self, None) - .await?; - - // check if any peers were updated - if !gateway_events.is_empty() { - // send peer update events - send_multiple_wireguard_events(gateway_events, wg_tx); - } - - // send firewall config update if ACLs & enterprise features are enabled - if let Some(firewall_config) = network.try_get_firewall_config(&mut *conn).await? { - send_wireguard_event( - GatewayEvent::FirewallConfigChanged(network.id, firewall_config), - wg_tx, - ); - } - } - info!("Allowed devices of user {} synced", self.username); - Ok(()) - } - - /// Deletes the user and cleans up his devices from gateways - pub async fn delete_and_cleanup( - self, - conn: &mut PgConnection, - wg_tx: &Sender, - ) -> Result<(), WebError> { - let username = self.username.clone(); - debug!("Deleting user {username}, removing his devices from gateways and updating ldap...",); - let devices = self.devices(&mut *conn).await?; - let mut events = Vec::new(); - - // get all locations affected by devices being deleted - let mut affected_location_ids = HashSet::new(); - - for device in devices { - let device_info = DeviceInfo::from_device(&mut *conn, device).await?; - for network_info in &device_info.network_info { - affected_location_ids.insert(network_info.network_id); - } - events.push(GatewayEvent::DeviceDeleted(device_info)); - } - - self.delete(&mut *conn).await?; - update_counts(&mut *conn).await?; - - // send firewall config updates to affected locations - // if they have ACL enabled & enterprise features are active - for location_id in affected_location_ids { - if let Some(location) = WireguardNetwork::find_by_id(&mut *conn, location_id).await? { - if let Some(firewall_config) = location.try_get_firewall_config(&mut *conn).await? { - debug!( - "Sending firewall config update for location {location} affected by deleting user {username} devices" - ); - events.push(GatewayEvent::FirewallConfigChanged( - location_id, - firewall_config, - )); - } - } - } - - send_multiple_wireguard_events(events, wg_tx); - info!( - "The user {} has been deleted and his devices removed from gateways.", - &username - ); - Ok(()) - } - - /// Enable MFA. At least one of the authenticator factors must be configured. - pub async fn enable_mfa(&mut self, pool: &PgPool) -> Result<(), WebError> { - if !self.mfa_enabled { - self.verify_mfa_state(pool).await?; - } - Ok(()) - } - - /// Get recovery codes. If recovery codes exist, this function returns `None`. - /// That way recovery codes are returned only once - when MFA is turned on. - pub async fn get_recovery_codes<'e, E>( - &mut self, - executor: E, - ) -> Result>, SqlxError> - where - E: PgExecutor<'e>, - { - if !self.recovery_codes.is_empty() { - return Ok(None); - } - - for _ in 0..RECOVERY_CODES_COUNT { - let code = gen_alphanumeric(16); - self.recovery_codes.push(code); - } - query!( - "UPDATE \"user\" SET recovery_codes = $2 WHERE id = $1", - self.id, - &self.recovery_codes - ) - .execute(executor) - .await?; - - Ok(Some(self.recovery_codes.clone())) - } - - /// Disable MFA; discard recovery codes, TOTP secret, and security keys. - pub async fn disable_mfa(&mut self, pool: &PgPool) -> Result<(), SqlxError> { - query!( - "UPDATE \"user\" SET mfa_enabled = FALSE, mfa_method = 'none', totp_enabled = FALSE, email_mfa_enabled = FALSE, \ - totp_secret = NULL, email_mfa_secret = NULL, recovery_codes = '{}' WHERE id = $1", - self.id - ) - .execute(pool) - .await?; - WebAuthn::delete_all_for_user(pool, self.id).await?; - - self.totp_secret = None; - self.email_mfa_secret = None; - self.totp_enabled = false; - self.email_mfa_enabled = false; - self.mfa_method = MFAMethod::None; - self.recovery_codes.clear(); - - Ok(()) - } - - /// Enable TOTP - pub async fn enable_totp<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - if !self.totp_enabled { - query!( - "UPDATE \"user\" SET totp_enabled = TRUE WHERE id = $1", - self.id - ) - .execute(executor) - .await?; - self.totp_enabled = true; - } - - Ok(()) - } - - /// Disable TOTP; discard the secret. - pub async fn disable_totp(&mut self, pool: &PgPool) -> Result<(), SqlxError> { - if self.totp_enabled { - // FIXME: check if this flag is set correctly when TOTP is the only method - self.mfa_enabled = self.check_mfa_enabled(pool).await?; - self.totp_enabled = false; - self.totp_secret = None; - - query!( - "UPDATE \"user\" SET mfa_enabled = $2, totp_enabled = $3 AND totp_secret = $4 \ - WHERE id = $1", - self.id, - self.mfa_enabled, - self.totp_enabled, - self.totp_secret, - ) - .execute(pool) - .await?; - } - - Ok(()) - } - - /// Enable email MFA - pub async fn enable_email_mfa<'e, E>(&mut self, executor: E) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - if !self.email_mfa_enabled { - query!( - "UPDATE \"user\" SET email_mfa_enabled = TRUE WHERE id = $1", - self.id - ) - .execute(executor) - .await?; - - self.email_mfa_enabled = true; - } - - Ok(()) - } - - /// Disable email MFA; discard the secret. - pub async fn disable_email_mfa(&mut self, pool: &PgPool) -> Result<(), SqlxError> { - if self.email_mfa_enabled { - self.mfa_enabled = self.check_mfa_enabled(pool).await?; - self.email_mfa_enabled = false; - self.email_mfa_secret = None; - - query!( - "UPDATE \"user\" SET mfa_enabled = $2, email_mfa_enabled = $3 AND email_mfa_secret = $4 \ - WHERE id = $1", - self.id, - self.mfa_enabled, - self.email_mfa_enabled, - self.email_mfa_secret, - ) - .execute(pool) - .await?; - } - - Ok(()) - } - - /// Select all users without sensitive data. - // FIXME: Remove it when Model macro will support SecretString - pub async fn all_without_sensitive_data( - pool: &PgPool, - ) -> Result, SqlxError> { - let users = query!( - "SELECT id, mfa_enabled, totp_enabled, email_mfa_enabled, \ - mfa_method \"mfa_method: MFAMethod\", password_hash, is_active, openid_sub, \ - from_ldap, ldap_pass_randomized, ldap_rdn \ - FROM \"user\"" - ) - .fetch_all(pool) - .await?; - let res: Vec = users - .iter() - .map(|u| UserDiagnostic { - mfa_method: u.mfa_method.clone(), - totp_enabled: u.totp_enabled, - email_mfa_enabled: u.email_mfa_enabled, - mfa_enabled: u.mfa_enabled, - id: u.id, - is_active: u.is_active, - enrolled: u.password_hash.is_some() || u.openid_sub.is_some() || u.from_ldap, - }) - .collect(); - - Ok(res) - } - - /// Return all members of group - pub async fn find_by_group_name( - pool: &PgPool, - group_name: &str, - ) -> Result>, SqlxError> { - let users = query_as!( - Self, - "SELECT \"user\".id, username, password_hash, last_name, first_name, email, \ - phone, mfa_enabled, totp_enabled, totp_secret, \ - email_mfa_enabled, email_mfa_secret, \ - mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ - from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" \ - INNER JOIN \"group_user\" ON \"user\".id = \"group_user\".user_id \ - INNER JOIN \"group\" ON \"group_user\".group_id = \"group\".id \ - WHERE \"group\".name = $1", - group_name - ) - .fetch_all(pool) - .await?; - - Ok(users) - } - - /// Check if TOTP `code` is valid. - #[must_use] - pub fn verify_totp_code(&self, code: &str) -> bool { - if let Some(totp_secret) = &self.totp_secret { - if let Ok(timestamp) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { - let expected_code = totp_custom::( - TOTP_CODE_VALIDITY_PERIOD, - TOTP_CODE_DIGITS, - totp_secret, - timestamp.as_secs(), - ); - return code == expected_code; - } - } - - false - } - - /// Generate MFA code for email verification. - /// - /// NOTE: This code will be valid for two time frames. See comment for verify_email_mfa_code(). - pub fn generate_email_mfa_code(&self) -> Result { - if let Some(email_mfa_secret) = &self.email_mfa_secret { - let timeout = &server_config().mfa_code_timeout; - if let Ok(timestamp) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { - let code = totp_custom::( - timeout.as_secs(), - EMAIL_CODE_DIGITS, - email_mfa_secret, - timestamp.as_secs(), - ); - Ok(code) - } else { - Err(WebError::EmailMfa("SystemTime before UNIX epoch".into())) - } - } else { - Err(WebError::EmailMfa(format!( - "Email MFA secret not configured for user {}", - self.username - ))) - } - } - - /// Check if email MFA `code` is valid. - /// - /// IMPORTANT: because current implementation uses TOTP for email verification, - /// allow the code for the previous time frame. This approach pretends the code is valid - /// for a certain *period of time* (as opposed to a TOTP code which is valid for a certain time *frame*). - /// - /// ```text - /// |<---- frame #0 ---->|<---- frame #1 ---->|<---- frame #2 ---->| - /// |................[*]email sent.................................| - /// |......................[*]email code verified..................| - /// ``` - #[must_use] - pub fn verify_email_mfa_code(&self, code: &str) -> bool { - if let Some(email_mfa_secret) = &self.email_mfa_secret { - let timeout = server_config().mfa_code_timeout.as_secs(); - if let Ok(timestamp) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { - let expected_code = totp_custom::( - timeout, - EMAIL_CODE_DIGITS, - email_mfa_secret, - timestamp.as_secs(), - ); - if code == expected_code { - return true; - } - debug!( - "Email MFA verification TOTP code for user {} doesn't fit current time \ - frame, checking the previous one. \ - Expected: {expected_code}, got: {code}", - self.username - ); - - let previous_code = totp_custom::( - timeout, - EMAIL_CODE_DIGITS, - email_mfa_secret, - timestamp.as_secs() - timeout, - ); - - if code == previous_code { - return true; - } - debug!( - "Email MFA verification TOTP code for user {} doesn't fit previous time frame, \ - expected: {previous_code}, got: {code}", - self.username - ); - return false; - } - debug!( - "Couldn't calculate current timestamp when verifying email MFA code for user {}", - self.username - ); - } else { - debug!("Email MFA secret not configured for user {}", self.username); - } - false - } - - /// Verify recovery code. If it is valid, consume it, so it can't be used again. - pub(crate) async fn verify_recovery_code( - &mut self, - pool: &PgPool, - code: &str, - ) -> Result { - if let Some(index) = self.recovery_codes.iter().position(|c| c == code) { - // Note: swap_remove() should be faster than remove(). - self.recovery_codes.swap_remove(index); - - query!( - "UPDATE \"user\" SET recovery_codes = $2 WHERE id = $1", - self.id, - &self.recovery_codes - ) - .execute(pool) - .await?; - - Ok(true) - } else { - Ok(false) - } - } - - pub async fn find_by_username<'e, E>( - executor: E, - username: &str, - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Self, - "SELECT id, username, password_hash, last_name, first_name, email, phone, mfa_enabled, \ - totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ - mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ - from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" WHERE username = $1", - username - ) - .fetch_optional(executor) - .await - } - - pub(crate) async fn find_by_email<'e, E>( - executor: E, - email: &str, - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Self, - "SELECT id, username, password_hash, last_name, first_name, email, phone, mfa_enabled, \ - totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ - mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, from_ldap, \ - ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" WHERE email ILIKE $1", - email - ) - .fetch_optional(executor) - .await - } - - /// Attempts to find user by username and then by email, if none is initially found. - pub async fn find_by_username_or_email( - conn: &mut PgConnection, - username_or_email: &str, - ) -> Result, SqlxError> { - let maybe_user = Self::find_by_username(&mut *conn, username_or_email).await?; - if let Some(user) = maybe_user { - Ok(Some(user)) - } else { - debug!( - "Failed to find user by username {username_or_email}. Attempting to find by email" - ); - Ok(Self::find_by_email(&mut *conn, username_or_email).await?) - } - } - - pub(crate) async fn find_many_by_emails<'e, E>( - executor: E, - emails: &[&str], - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as( - "SELECT id, username, password_hash, last_name, first_name, email, phone, \ - mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ - mfa_method, recovery_codes, is_active, openid_sub, from_ldap, ldap_pass_randomized, \ - ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" WHERE email = ANY($1)", - ) - .bind(emails) - .fetch_all(executor) - .await - } - - pub(crate) async fn find_by_sub<'e, E>( - executor: E, - sub: &str, - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Self, - "SELECT id, username, password_hash, last_name, first_name, email, phone, \ - mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ - mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ - from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" WHERE openid_sub = $1", - sub - ) - .fetch_optional(executor) - .await - } - - pub(crate) async fn member_of_names<'e, E>(&self, executor: E) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_scalar!( - "SELECT \"group\".name FROM \"group\" JOIN group_user ON \"group\".id = group_user.group_id \ - WHERE group_user.user_id = $1", - self.id - ) - .fetch_all(executor) - .await - } - - pub(crate) async fn member_of<'e, E>(&self, executor: E) -> Result>, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Group, - "SELECT id, name, is_admin FROM \"group\" JOIN group_user ON \"group\".id = group_user.group_id \ - WHERE group_user.user_id = $1", - self.id - ) - .fetch_all(executor) - .await - } - - /// Returns a vector of [`UserDevice`]s (hence the name). - /// [`UserDevice`] is a struct containing additional network info about a device. - /// If you only need [`Device`]s, use [`User::devices()`] instead. - pub(crate) async fn user_devices(&self, pool: &PgPool) -> Result, SqlxError> { - let devices = self.devices(pool).await?; - let mut user_devices = Vec::new(); - for device in devices { - if let Some(user_device) = UserDevice::from_device(pool, device).await? { - user_devices.push(user_device); - } - } - - Ok(user_devices) - } - - /// Returns a vector of [`Device`]s related to a user. If you want to get [`UserDevice`]s (which contain additional network info), - /// use [`User::user_devices()`] instead. - pub(crate) async fn devices<'e, E>(&self, executor: E) -> Result>, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Device, - "SELECT device.id, name, wireguard_pubkey, user_id, created, description, \ - device_type \"device_type: DeviceType\", configured \ - FROM device WHERE user_id = $1 and device_type = 'user'::device_type \ - ORDER BY id", - self.id - ) - .fetch_all(executor) - .await - } - - pub(crate) async fn oauth2authorizedapps<'e, E>( - &self, - executor: E, - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - OAuth2AuthorizedAppInfo, - "SELECT oauth2client.id \"oauth2client_id!\", oauth2client.name \"oauth2client_name\", \ - oauth2authorizedapp.user_id \"user_id\" \ - FROM oauth2authorizedapp \ - JOIN oauth2client ON oauth2client.id = oauth2authorizedapp.oauth2client_id \ - WHERE oauth2authorizedapp.user_id = $1", - self.id - ) - .fetch_all(executor) - .await - } - - pub(crate) async fn security_keys(&self, pool: &PgPool) -> Result, SqlxError> { - query_as!( - SecurityKey, - "SELECT id \"id!\", name FROM webauthn WHERE user_id = $1", - self.id - ) - .fetch_all(pool) - .await - } - - pub async fn add_to_group<'e, E>(&self, executor: E, group: &Group) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - query!( - "INSERT INTO group_user (group_id, user_id) VALUES ($1, $2) \ - ON CONFLICT DO NOTHING", - group.id, - self.id - ) - .execute(executor) - .await?; - Ok(()) - } - - pub(crate) async fn remove_from_group<'e, E>( - &self, - executor: E, - group: &Group, - ) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - query!( - "DELETE FROM group_user WHERE group_id = $1 AND user_id = $2", - group.id, - self.id - ) - .execute(executor) - .await?; - Ok(()) - } - - /// Remove authorized apps by their client id's from user - pub(crate) async fn remove_oauth2_authorized_apps<'e, E>( - &self, - executor: E, - app_client_ids: &[i64], - ) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - query!( - "DELETE FROM oauth2authorizedapp WHERE user_id = $1 AND oauth2client_id = ANY($2)", - self.id, - app_client_ids - ) - .execute(executor) - .await?; - - Ok(()) - } - - /// Create admin user if one doesn't exist yet - pub async fn init_admin_user( - pool: &PgPool, - default_admin_pass: &str, - ) -> Result<(), anyhow::Error> { - debug!("Checking if some admin user already exists and creating one if not..."); - let admins = User::find_admins(pool).await?; - if admins.is_empty() { - let admin_groups = Group::find_by_permission(pool, Permission::IsAdmin).await?; - if admin_groups.is_empty() { - return Err(anyhow::anyhow!( - "No admin group and users found, or they are all disabled. \ - You'll need to create and assign the admin group manually, \ - as there must be at least one active admin user." - )); - } - - // create admin user - let password_hash = hash_password(default_admin_pass)?; - let result = query_scalar!( - "INSERT INTO \"user\" (username, password_hash, last_name, first_name, email, ldap_rdn) \ - VALUES ('admin', $1, 'Administrator', 'DefGuard', 'admin@defguard', 'admin') \ - ON CONFLICT DO NOTHING \ - RETURNING id", - password_hash - ) - .fetch_optional(pool) - .await?; - - // if new user was created add them to admin group, first one you find - // the groups are sorted by ID desceding, so it will often be the 1st one = the default admin group - if let Some(new_user_id) = result { - let admin_group_id = admin_groups - .first() - .ok_or(anyhow::anyhow!( - "No admin group found, can't create admin user" - ))? - .id; - info!("New admin user has been created, adding to Admin group..."); - query("INSERT INTO group_user (group_id, user_id) VALUES ($1, $2)") - .bind(admin_group_id) - .bind(new_user_id) - .execute(pool) - .await?; - info!("Admin user has been created as there was no other admin user"); - } else { - return Err(anyhow::anyhow!( - "A conflict occurred while trying to add a missing admin. \ - There is already a user with username 'admin' but he is not an admin or he is disabled. \ - You will need to assign someone the admin group manually or enable this admin user, \ - as there must be at least one active admin." - )); - } - } else { - debug!("Admin users already exists, skipping creation of the default admin user"); - } - Ok(()) - } - - pub async fn logout_all_sessions<'e, E>(&self, executor: E) -> Result<(), SqlxError> - where - E: PgExecutor<'e>, - { - Session::delete_all_for_user(executor, self.id).await?; - Ok(()) - } - - pub async fn find_by_device_id<'e, E>( - executor: E, - device_id: Id, - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Self, - "SELECT u.id, u.username, u.password_hash, u.last_name, u.first_name, u.email, \ - u.phone, u.mfa_enabled, u.totp_enabled, u.email_mfa_enabled, \ - u.totp_secret, u.email_mfa_secret, u.mfa_method \"mfa_method: _\", u.recovery_codes, \ - u.is_active, u.openid_sub, from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, \ - enrollment_pending \ - FROM \"user\" u \ - JOIN \"device\" d ON u.id = d.user_id \ - WHERE d.id = $1", - device_id - ) - .fetch_optional(executor) - .await - } - - /// Find users which emails are NOT in `user_emails`. - pub(crate) async fn exclude<'e, E>( - executor: E, - user_emails: &[&str], - ) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - // This can't be a macro since sqlx can't handle an array of slices in a macro. - query_as( - "SELECT id, username, password_hash, last_name, first_name, email, phone, \ - mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ - mfa_method, recovery_codes, is_active, openid_sub, from_ldap, ldap_pass_randomized, \ - ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" WHERE email NOT IN (SELECT * FROM UNNEST($1::TEXT[]))", - ) - .bind(user_emails) - .fetch_all(executor) - .await - } - - pub(crate) async fn is_admin<'e, E>(&self, executor: E) -> Result - where - E: PgExecutor<'e>, - { - query_scalar!("SELECT EXISTS (SELECT 1 FROM group_user gu LEFT JOIN \"group\" g ON gu.group_id = g.id \ - WHERE is_admin = true AND user_id = $1) \"bool!\"", self.id) - .fetch_one(executor) - .await - } - - /// Find all users that are admins and are active. - pub(crate) async fn find_admins<'e, E>(executor: E) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - query_as!( - Self, - " - SELECT u.id, u.username, u.password_hash, u.last_name, u.first_name, u.email, \ - u.phone, u.mfa_enabled, u.totp_enabled, u.email_mfa_enabled, \ - u.totp_secret, u.email_mfa_secret, u.mfa_method \"mfa_method: _\", u.recovery_codes, u.is_active, u.openid_sub, \ - from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ - FROM \"user\" u \ - WHERE EXISTS (SELECT 1 FROM group_user gu LEFT JOIN \"group\" g ON gu.group_id = g.id \ - WHERE is_admin = true AND user_id = u.id) AND u.is_active = true" - ) - .fetch_all(executor) - .await - } -} - -#[cfg(test)] -impl Distribution> for Standard { - fn sample(&self, rng: &mut R) -> User { - User { - id: rng.r#gen(), - username: Alphanumeric.sample_string(rng, 8), - password_hash: rng - .r#gen::() - .then_some(Alphanumeric.sample_string(rng, 8)), - last_name: Alphanumeric.sample_string(rng, 8), - first_name: Alphanumeric.sample_string(rng, 8), - email: format!("{}@defguard.net", Alphanumeric.sample_string(rng, 6)), - // FIXME: generate an actual phone number - phone: rng - .r#gen::() - .then_some(Alphanumeric.sample_string(rng, 9)), - mfa_enabled: rng.r#gen(), - is_active: true, - openid_sub: rng - .r#gen::() - .then_some(Alphanumeric.sample_string(rng, 8)), - totp_enabled: rng.r#gen(), - email_mfa_enabled: rng.r#gen(), - totp_secret: (0..20).map(|_| rng.r#gen()).collect(), - email_mfa_secret: (0..20).map(|_| rng.r#gen()).collect(), - mfa_method: match rng.r#gen_range(0..4) { - 0 => MFAMethod::None, - 1 => MFAMethod::Webauthn, - 2 => MFAMethod::OneTimePassword, - _ => MFAMethod::Email, - }, - recovery_codes: (0..3).map(|_| Alphanumeric.sample_string(rng, 6)).collect(), - from_ldap: false, - ldap_pass_randomized: false, - ldap_rdn: None, - ldap_user_path: None, - enrollment_pending: false, - } - } -} - -#[cfg(test)] -impl Distribution> for Standard { - fn sample(&self, rng: &mut R) -> User { - User { - id: NoId, - username: Alphanumeric.sample_string(rng, 8), - password_hash: rng - .r#gen::() - .then_some(Alphanumeric.sample_string(rng, 8)), - last_name: Alphanumeric.sample_string(rng, 8), - first_name: Alphanumeric.sample_string(rng, 8), - email: format!("{}@defguard.net", Alphanumeric.sample_string(rng, 6)), - // FIXME: generate an actual phone number - phone: rng - .r#gen::() - .then_some(Alphanumeric.sample_string(rng, 9)), - mfa_enabled: rng.r#gen(), - is_active: true, - openid_sub: rng - .r#gen::() - .then_some(Alphanumeric.sample_string(rng, 8)), - totp_enabled: rng.r#gen(), - email_mfa_enabled: rng.r#gen(), - totp_secret: (0..20).map(|_| rng.r#gen()).collect(), - email_mfa_secret: (0..20).map(|_| rng.r#gen()).collect(), - mfa_method: match rng.r#gen_range(0..4) { - 0 => MFAMethod::None, - 1 => MFAMethod::Webauthn, - 2 => MFAMethod::OneTimePassword, - _ => MFAMethod::Email, - }, - recovery_codes: (0..3).map(|_| Alphanumeric.sample_string(rng, 6)).collect(), - from_ldap: false, - ldap_pass_randomized: false, - ldap_rdn: None, - ldap_user_path: None, - enrollment_pending: false, - } - } -} - -#[cfg(test)] -mod test { - use defguard_common::{ - config::{DefGuardConfig, SERVER_CONFIG}, - db::{models::settings::initialize_current_settings, setup_pool}, - }; - use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; - - use super::*; - - #[sqlx::test] - async fn test_mfa_code(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let config = DefGuardConfig::new_test_config(); - let _ = SERVER_CONFIG.set(config.clone()); - initialize_current_settings(&pool).await.unwrap(); - - let mut user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - user.new_email_secret(&pool).await.unwrap(); - assert!(user.email_mfa_secret.is_some()); - let code = user.generate_email_mfa_code().unwrap(); - assert!( - user.verify_email_mfa_code(&code), - "code={code}, secret={:?}", - user.email_mfa_secret.unwrap() - ); - } - - #[sqlx::test] - async fn test_user(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let mut user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let fetched_user = User::find_by_username(&pool, "hpotter").await.unwrap(); - assert!(fetched_user.is_some()); - assert_eq!(fetched_user.unwrap().email, "h.potter@hogwart.edu.uk"); - - user.email = "harry.potter@hogwart.edu.uk".into(); - user.save(&pool).await.unwrap(); - - let fetched_user = User::find_by_username(&pool, "hpotter").await.unwrap(); - assert!(fetched_user.is_some()); - assert_eq!(fetched_user.unwrap().email, "harry.potter@hogwart.edu.uk"); - - assert!(user.verify_password("pass123").is_ok()); - - let fetched_user = User::find_by_username(&pool, "rweasley").await.unwrap(); - assert!(fetched_user.is_none()); - } - - #[sqlx::test] - async fn test_all_users(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let albus = User::new( - "adumbledore", - Some("magic!"), - "Dumbledore", - "Albus", - "a.dumbledore@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let users = User::all(&pool).await.unwrap(); - assert_eq!(users.len(), 2); - - albus.delete(&pool).await.unwrap(); - - let users = User::all(&pool).await.unwrap(); - assert_eq!(users.len(), 1); - } - - #[sqlx::test] - async fn test_recovery_codes(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let mut harry = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - harry.get_recovery_codes(&pool).await.unwrap(); - assert_eq!(harry.recovery_codes.len(), RECOVERY_CODES_COUNT); - - let fetched_user = User::find_by_username(&pool, "hpotter").await.unwrap(); - assert!(fetched_user.is_some()); - - let mut user = fetched_user.unwrap(); - assert_eq!(user.recovery_codes.len(), RECOVERY_CODES_COUNT); - assert!( - !user - .verify_recovery_code(&pool, "invalid code") - .await - .unwrap() - ); - let codes = user.recovery_codes.clone(); - for code in &codes { - assert!(user.verify_recovery_code(&pool, code).await.unwrap()); - } - assert_eq!(user.recovery_codes.len(), 0); - } - - #[sqlx::test] - async fn test_email_case_insensitivity(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let harry = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ); - assert!(harry.save(&pool).await.is_ok()); - - let henry = User::new( - "h.potter", - Some("pass123"), - "Potter", - "Henry", - "h.potter@hogwart.edu.uk", - None, - ); - assert!(henry.save(&pool).await.is_err()); - } - - #[sqlx::test] - async fn test_is_admin(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let config = DefGuardConfig::new_test_config(); - let _ = SERVER_CONFIG.set(config.clone()); - - let user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let is_admin = user.is_admin(&pool).await.unwrap(); - - assert!(!is_admin); - - query!( - "INSERT INTO group_user (group_id, user_id) VALUES (1, $1)", - user.id - ) - .execute(&pool) - .await - .unwrap(); - - let is_admin = user.is_admin(&pool).await.unwrap(); - - assert!(is_admin); - } - - #[sqlx::test] - async fn test_find_admins(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let config = DefGuardConfig::new_test_config(); - let _ = SERVER_CONFIG.set(config.clone()); - - let user = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user2 = User::new( - "hpotter2", - Some("pass123"), - "Potter", - "Harry", - "h.potter2@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - User::new( - "hpotter3", - Some("pass123"), - "Potter", - "Harry", - "h.potter3@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - query!( - "INSERT INTO group_user (group_id, user_id) VALUES (1, $1), (1, $2)", - user.id, - user2.id, - ) - .execute(&pool) - .await - .unwrap(); - - let admins = User::find_admins(&pool).await.unwrap(); - assert_eq!(admins.len(), 2); - assert!(admins.iter().any(|u| u.id == user.id)); - assert!(admins.iter().any(|u| u.id == user2.id)); - } - - #[sqlx::test] - async fn test_get_missing(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let user1 = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - let user2 = User::new( - "hpotter2", - Some("pass1234"), - "Potter2", - "Harry2", - "h.potter2@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - let albus = User::new( - "adumbledore", - Some("magic!"), - "Dumbledore", - "Albus", - "a.dumbledore@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user_emails = vec![user1.email.as_str(), albus.email.as_str()]; - let users = User::exclude(&pool, &user_emails).await.unwrap(); - assert_eq!(users.len(), 1); - assert_eq!(users[0].id, user2.id); - } - - #[sqlx::test] - async fn test_find_many_by_emails(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - - let user1 = User::new( - "hpotter", - Some("pass123"), - "Potter", - "Harry", - "h.potter@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - User::new( - "hpotter2", - Some("pass1234"), - "Potter2", - "Harry2", - "h.potter2@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - let albus = User::new( - "adumbledore", - Some("magic!"), - "Dumbledore", - "Albus", - "a.dumbledore@hogwart.edu.uk", - None, - ) - .save(&pool) - .await - .unwrap(); - - let user_emails = vec![user1.email.as_str(), albus.email.as_str()]; - let users = User::find_many_by_emails(&pool, &user_emails) - .await - .unwrap(); - assert_eq!(users.len(), 2); - assert_eq!(users[0].id, user1.id); - assert_eq!(users[1].id, albus.id); - } - - #[sqlx::test] - async fn test_user_is_enrolled(_: PgPoolOptions, options: PgConnectOptions) { - let pool = setup_pool(options).await; - let user = User::new( - "test", - Some("31071980"), - "harry", - "potter", - "harry@hogwart.edu.uk", - None, - ); - let mut user = user.save(&pool).await.unwrap(); - - user.enrollment_pending = false; - user.password_hash = Some(hash_password("31071980").unwrap()); - user.openid_sub = Some("sub".to_string()); - user.from_ldap = true; - user.save(&pool).await.unwrap(); - assert!(user.is_enrolled()); - - user.enrollment_pending = false; - user.password_hash = None; - user.openid_sub = Some("sub".to_string()); - user.from_ldap = true; - user.save(&pool).await.unwrap(); - assert!(user.is_enrolled()); - - user.enrollment_pending = false; - user.password_hash = None; - user.openid_sub = None; - user.from_ldap = true; - user.save(&pool).await.unwrap(); - assert!(user.is_enrolled()); - - user.enrollment_pending = false; - user.password_hash = None; - user.openid_sub = None; - user.from_ldap = false; - user.save(&pool).await.unwrap(); - assert!(!user.is_enrolled()); - - user.enrollment_pending = true; - user.password_hash = None; - user.openid_sub = None; - user.from_ldap = false; - user.save(&pool).await.unwrap(); - assert!(!user.is_enrolled()); - - user.enrollment_pending = true; - user.password_hash = Some(hash_password("31071980").unwrap()); - user.openid_sub = Some("sub".to_string()); - user.from_ldap = true; - user.save(&pool).await.unwrap(); - assert!(!user.is_enrolled()); - } -} diff --git a/crates/defguard_core/src/db/models/webhook.rs b/crates/defguard_core/src/db/models/webhook.rs index 8b2715c46a..91086edbce 100644 --- a/crates/defguard_core/src/db/models/webhook.rs +++ b/crates/defguard_core/src/db/models/webhook.rs @@ -1,9 +1,10 @@ -use defguard_common::db::{Id, NoId}; +use defguard_common::{ + db::{Id, NoId}, + types::user_info::UserInfo, +}; use model_derive::Model; use sqlx::{Error as SqlxError, FromRow, PgPool, query_as}; -use super::UserInfo; - /// App events which triggers webhook action #[derive(Debug)] pub enum AppEvent { diff --git a/crates/defguard_core/src/enrollment_management.rs b/crates/defguard_core/src/enrollment_management.rs new file mode 100644 index 0000000000..821cff6da2 --- /dev/null +++ b/crates/defguard_core/src/enrollment_management.rs @@ -0,0 +1,240 @@ +use defguard_common::db::{Id, models::user::User}; +use defguard_mail::{Mail, templates}; +use reqwest::Url; +use sqlx::{PgConnection, PgExecutor}; +use tokio::sync::mpsc::UnboundedSender; + +use crate::db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token, TokenError}; + +static ENROLLMENT_START_MAIL_SUBJECT: &str = "Defguard user enrollment"; +static DESKTOP_START_MAIL_SUBJECT: &str = "Defguard desktop client configuration"; + +/// Start user enrollment process +/// This creates a new enrollment token valid for 24h +/// and optionally sends enrollment email notification to user +pub async fn start_user_enrollment( + user: &mut User, + transaction: &mut PgConnection, + admin: &User, + email: Option, + token_timeout_seconds: u64, + enrollment_service_url: Url, + send_user_notification: bool, + mail_tx: UnboundedSender, +) -> Result { + info!( + "User {} started a new enrollment process for user {}.", + admin.username, user.username + ); + debug!( + "Notify user by mail about the enrollment process: {}", + send_user_notification + ); + debug!("Check if {} has a password.", user.username); + if user.has_password() { + debug!( + "User {} that you want to start enrollment process for already has a password.", + user.username + ); + return Err(TokenError::AlreadyActive); + } + + debug!("Verify that {} is an active user.", user.username); + if !user.is_active { + warn!( + "Can't create enrollment token for disabled user {}", + user.username + ); + return Err(TokenError::UserDisabled); + } + + clear_unused_enrollment_tokens(user, &mut *transaction).await?; + + debug!("Create a new enrollment token for user {}.", user.username); + let enrollment = Token::new( + user.id, + Some(admin.id), + email.clone(), + token_timeout_seconds, + Some(ENROLLMENT_TOKEN_TYPE.to_string()), + ); + debug!("Saving a new enrollment token..."); + enrollment.save(&mut *transaction).await?; + debug!( + "Saved a new enrollment token with id {} for user {}.", + enrollment.id, user.username + ); + + // Mark the user with enrollment-pending flag. + // https://github.com/DefGuard/client/issues/647 + user.enrollment_pending = true; + user.save(&mut *transaction).await?; + + if send_user_notification { + if let Some(email) = email { + debug!( + "Sending an enrollment mail for user {} to {email}.", + user.username + ); + let base_message_context = enrollment + .get_welcome_message_context(&mut *transaction) + .await?; + let mail = Mail { + to: email.clone(), + subject: ENROLLMENT_START_MAIL_SUBJECT.to_string(), + content: templates::enrollment_start_mail( + base_message_context, + enrollment_service_url, + &enrollment.id, + ) + .map_err(|err| { + debug!( + "Cannot send an email to the user {} due to the error {}.", + user.username, + err.to_string() + ); + TokenError::NotificationError(err.to_string()) + })?, + attachments: Vec::new(), + result_tx: None, + }; + match mail_tx.send(mail) { + Ok(()) => { + info!( + "Sent enrollment start mail for user {} to {email}", + user.username + ); + } + Err(err) => { + error!("Error sending mail: {err}"); + return Err(TokenError::NotificationError(err.to_string())); + } + } + } + } + info!( + "New enrollment token has been generated for {}.", + user.username + ); + + Ok(enrollment.id) +} + +/// Start user remote desktop configuration process +/// This creates a new enrollment token valid for 24h +/// and optionally sends email notification to user +pub async fn start_desktop_configuration( + user: &User, + transaction: &mut PgConnection, + admin: &User, + email: Option, + token_timeout_seconds: u64, + enrollment_service_url: Url, + send_user_notification: bool, + mail_tx: UnboundedSender, + // Whether to attach some device to the token. It allows for a partial initialization of + // the device before the desktop configuration has taken place. + device_id: Option, +) -> Result { + info!( + "User {} starting a new desktop activation for user {}", + admin.username, user.username + ); + debug!( + "Notify {} by mail about the enrollment process: {}", + user.username, send_user_notification + ); + + debug!("Verify that {} is an active user.", user.username); + if !user.is_active { + warn!( + "Can't create desktop activation token for disabled user {}.", + user.username + ); + return Err(TokenError::UserDisabled); + } + + clear_unused_enrollment_tokens(user, &mut *transaction).await?; + debug!("Cleared unused tokens for {}.", user.username); + + debug!( + "Create a new desktop activation token for user {}.", + user.username + ); + let mut desktop_configuration = Token::new( + user.id, + Some(admin.id), + email.clone(), + token_timeout_seconds, + Some(ENROLLMENT_TOKEN_TYPE.to_string()), + ); + if let Some(device_id) = device_id { + desktop_configuration.device_id = Some(device_id); + } + debug!("Saving a new desktop configuration token..."); + desktop_configuration.save(&mut *transaction).await?; + debug!( + "Saved a new desktop activation token with id {} for user {}.", + desktop_configuration.id, user.username + ); + + if send_user_notification { + if let Some(email) = email { + debug!( + "Sending a desktop configuration mail for user {} to {email}", + user.username + ); + let base_message_context = desktop_configuration + .get_welcome_message_context(&mut *transaction) + .await?; + let mail = Mail { + to: email.clone(), + subject: DESKTOP_START_MAIL_SUBJECT.to_string(), + content: templates::desktop_start_mail( + base_message_context, + &enrollment_service_url, + &desktop_configuration.id, + ) + .map_err(|err| { + debug!( + "Cannot send an email to the user {} due to the error {}.", + user.username, + err.to_string() + ); + TokenError::NotificationError(err.to_string()) + })?, + attachments: Vec::new(), + result_tx: None, + }; + match mail_tx.send(mail) { + Ok(()) => { + info!( + "Sent desktop configuration start mail for user {} to {email}", + user.username + ); + } + Err(err) => { + error!("Error sending mail: {err}"); + } + } + } + } + info!( + "New desktop activation token has been generated for {}.", + user.username + ); + + Ok(desktop_configuration.id) +} + +// Remove unused tokens when triggering user enrollment +pub async fn clear_unused_enrollment_tokens<'e, E>( + user: &User, + executor: E, +) -> Result<(), TokenError> +where + E: PgExecutor<'e>, +{ + info!("Removing unused tokens for user {}.", user.username); + Token::delete_unused_user_tokens(executor, user.id).await +} diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index 069a9ebe81..605f147828 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -6,7 +6,15 @@ use std::{ }; use chrono::NaiveDateTime; -use defguard_common::db::{Id, NoId}; +use defguard_common::db::{ + Id, NoId, + models::{ + Device, DeviceType, WireguardNetwork, + group::Group, + user::User, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, +}; use ipnetwork::{IpNetwork, IpNetworkError}; use model_derive::Model; use sqlx::{ @@ -16,16 +24,12 @@ use sqlx::{ use thiserror::Error; use crate::{ - DeviceType, appstate::AppState, - db::{ - Device, GatewayEvent, Group, User, WireguardNetwork, - models::wireguard::{LocationMfaMode, ServiceLocationMode}, - }, enterprise::{ - firewall::FirewallError, + firewall::{FirewallError, try_get_location_firewall_config}, handlers::acl::{ApiAclAlias, ApiAclRule, EditAclAlias, EditAclRule}, }, + grpc::gateway::events::GatewayEvent, }; #[derive(Debug, Error)] @@ -479,7 +483,7 @@ impl AclRule { ); for location in affected_locations { - match location.try_get_firewall_config(&mut transaction).await? { + match try_get_location_firewall_config(&location, &mut transaction).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( @@ -1677,7 +1681,7 @@ impl AclAlias { ); for location in affected_locations { - match location.try_get_firewall_config(&mut transaction).await? { + match try_get_location_firewall_config(&location, &mut transaction).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( diff --git a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs index 85cdd42d82..00f598db6e 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl/tests.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl/tests.rs @@ -1,11 +1,10 @@ use std::ops::Bound; -use defguard_common::db::setup_pool; +use defguard_common::{db::setup_pool, utils::parse_address_list}; use rand::{Rng, thread_rng}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; -use crate::handlers::wireguard::parse_address_list; #[sqlx::test] async fn test_alias(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index b37fccba56..e46eb61081 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -4,7 +4,10 @@ use std::{ time::Duration, }; -use defguard_common::db::{Id, models::Settings}; +use defguard_common::db::{ + Id, + models::{Settings, group::Group, user::User}, +}; use paste::paste; use reqwest::header::AUTHORIZATION; use sqlx::{PgConnection, PgPool, error::Error as SqlxError}; @@ -18,13 +21,17 @@ use super::{ ldap::utils::ldap_update_users_state, }; use crate::{ - db::{GatewayEvent, Group, User}, enterprise::{ db::models::openid_provider::DirectorySyncUserBehavior, handlers::openid_login::prune_username, - ldap::utils::{ldap_add_users_to_groups, ldap_delete_users, ldap_remove_users_from_groups}, + ldap::{ + model::ldap_sync_allowed_for_user, + utils::{ldap_add_users_to_groups, ldap_delete_users, ldap_remove_users_from_groups}, + }, }, + grpc::gateway::events::GatewayEvent, handlers::user::check_username, + user_management::{delete_user_and_cleanup_devices, disable_user, sync_allowed_user_devices}, }; const REQUEST_TIMEOUT: Duration = Duration::from_secs(10); @@ -358,7 +365,7 @@ async fn sync_user_groups( } } - user.sync_allowed_devices(&mut transaction, wg_tx) + sync_allowed_user_devices(user, &mut transaction, wg_tx) .await .map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( @@ -561,7 +568,7 @@ async fn sync_all_users_groups( create_and_add_to_group(&user, group, pool).await?; } - user.sync_allowed_devices(&mut transaction, wg_tx).await.map_err(|err| { + sync_allowed_user_devices(&user, &mut transaction, wg_tx).await.map_err(|err| { DirectorySyncError::NetworkUpdateError(format!( "Failed to sync allowed devices for user {} during directory synchronization: {err}", user.email @@ -752,7 +759,7 @@ async fn sync_all_users_state( the admin behavior setting is set to disable", user.email ); - user.disable(&mut transaction, wg_tx).await.map_err(|err| { + disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable admin {} during directory synchronization: {err}", user.email @@ -779,10 +786,10 @@ async fn sync_all_users_state( "Deleting admin {} because they are not present in the directory", user.email ); - if user.ldap_sync_allowed(&mut *transaction).await? { + if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - user.delete_and_cleanup(&mut transaction, wg_tx) + delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -806,7 +813,7 @@ async fn sync_all_users_state( "Disabling user {} because they are not present in the directory and the user behavior setting is set to disable", user.email ); - user.disable(&mut transaction, wg_tx).await.map_err(|err| { + disable_user(&mut user, &mut transaction, wg_tx).await.map_err(|err| { DirectorySyncError::UserUpdateError(format!( "Failed to disable user {} during directory synchronization: {err}", user.email @@ -825,10 +832,10 @@ async fn sync_all_users_state( "Deleting user {} because they are not present in the directory", user.email ); - if user.ldap_sync_allowed(&mut *transaction).await? { + if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { deleted_users.push(user.clone().as_noid()); } - user.delete_and_cleanup(&mut transaction, wg_tx) + delete_user_and_cleanup_devices(user, &mut transaction, wg_tx) .await .map_err(|err| { DirectorySyncError::UserUpdateError(format!( @@ -890,12 +897,14 @@ async fn sync_inactive_directory_users( "Disabling user {} because they are disabled in the directory", user.email ); - user.disable(transaction, wg_tx).await.map_err(|err| { - DirectorySyncError::UserUpdateError(format!( - "Failed to disable user {} during directory synchronization: {err}", - user.email - )) - })?; + disable_user(&mut user, transaction, wg_tx) + .await + .map_err(|err| { + DirectorySyncError::UserUpdateError(format!( + "Failed to disable user {} during directory synchronization: {err}", + user.email + )) + })?; modified_users.push(user); } else { debug!("User {} is already disabled, skipping", user.email); diff --git a/crates/defguard_core/src/enterprise/directory_sync/tests.rs b/crates/defguard_core/src/enterprise/directory_sync/tests.rs index 9bac17d281..e79849e2da 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/tests.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/tests.rs @@ -5,7 +5,11 @@ mod test { use defguard_common::{ config::{DefGuardConfig, SERVER_CONFIG}, db::{ - models::{Settings, settings::initialize_current_settings}, + models::{ + Device, DeviceType, Session, SessionState, Settings, WireguardNetwork, + settings::initialize_current_settings, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, setup_pool, }, }; @@ -15,16 +19,7 @@ mod test { use tokio::sync::broadcast; use super::super::*; - use crate::{ - db::{ - Device, Session, SessionState, WireguardNetwork, - models::{ - device::DeviceType, - wireguard::{LocationMfaMode, ServiceLocationMode}, - }, - }, - enterprise::db::models::openid_provider::DirectorySyncTarget, - }; + use crate::enterprise::db::models::openid_provider::DirectorySyncTarget; async fn get_test_network(pool: &PgPool) -> WireguardNetwork { WireguardNetwork::find_by_name(pool, "test") @@ -657,7 +652,7 @@ mod test { let user2 = get_test_user(&pool, "user2").await; assert!(user2.is_none()); let mut transaction = pool.begin().await.unwrap(); - user.sync_allowed_devices(&mut transaction, &wg_tx) + sync_allowed_user_devices(&user, &mut transaction, &wg_tx) .await .unwrap(); transaction.commit().await.unwrap(); diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index 5e2b7e8d97..fc446d7a41 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -3,7 +3,10 @@ use std::{ ops::RangeInclusive, }; -use defguard_common::db::{Id, models::ModelError}; +use defguard_common::db::{ + Id, + models::{Device, ModelError, WireguardNetwork, user::User}, +}; use defguard_proto::enterprise::firewall::{ FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, PortRange as PortRangeProto, SnatBinding as SnatBindingProto, ip_address::Address, @@ -19,12 +22,9 @@ use super::{ }, utils::merge_ranges, }; -use crate::{ - db::{Device, User, WireguardNetwork}, - enterprise::{ - db::models::{acl::AliasKind, snat::UserSnatBinding}, - is_enterprise_enabled, - }, +use crate::enterprise::{ + db::models::{acl::AliasKind, snat::UserSnatBinding}, + is_enterprise_enabled, }; #[derive(Debug, thiserror::Error)] @@ -863,16 +863,15 @@ async fn generate_user_snat_bindings_for_location( Ok(bindings) } -impl WireguardNetwork { - /// Fetches all active ACL rules for a given location. - /// Filters out rules which are disabled, expired or have not been deployed yet. - pub(crate) async fn get_active_acl_rules( - &self, - conn: &mut PgConnection, - ) -> Result>, SqlxError> { - debug!("Fetching active ACL rules for location {self}"); - let rules: Vec> = query_as( - "SELECT DISTINCT ON (a.id) a.id, name, allow_all_users, deny_all_users, all_networks, \ +/// Fetches all active ACL rules for a given location. +/// Filters out rules which are disabled, expired or have not been deployed yet. +pub(crate) async fn get_location_active_acl_rules( + location: &WireguardNetwork, + conn: &mut PgConnection, +) -> Result>, SqlxError> { + debug!("Fetching active ACL rules for location {location}"); + let rules: Vec> = query_as( + "SELECT DISTINCT ON (a.id) a.id, name, allow_all_users, deny_all_users, all_networks, \ allow_all_network_devices, deny_all_network_devices, destination, ports, protocols, \ expires, enabled, parent_id, state \ FROM aclrule a \ @@ -881,65 +880,67 @@ impl WireguardNetwork { WHERE (an.network_id = $1 OR a.all_networks = true) AND enabled = true \ AND state = 'applied'::aclrule_state \ AND (expires IS NULL OR expires > NOW())", - ) - .bind(self.id) - .fetch_all(&mut *conn) - .await?; - debug!("Found {} active ACL rules for location {self}", rules.len()); - - // convert to `AclRuleInfo` - let mut rules_info = Vec::new(); - for rule in rules { - let rule_info = rule.to_info(&mut *conn).await?; - rules_info.push(rule_info); - } - Ok(rules_info) + ) + .bind(location.id) + .fetch_all(&mut *conn) + .await?; + debug!( + "Found {} active ACL rules for location {location}", + rules.len() + ); + + // convert to `AclRuleInfo` + let mut rules_info = Vec::new(); + for rule in rules { + let rule_info = rule.to_info(&mut *conn).await?; + rules_info.push(rule_info); } + Ok(rules_info) +} - /// Prepares firewall configuration for a gateway based on location config and ACLs - /// Returns `None` if firewall management is disabled for a given location. - pub async fn try_get_firewall_config( - &self, - conn: &mut PgConnection, - ) -> Result, FirewallError> { - // do a license check - if !is_enterprise_enabled() { - debug!( - "Enterprise features are disabled, skipping generating firewall config for \ - location {self}" - ); - return Ok(None); - } +/// Prepares firewall configuration for a gateway based on location config and ACLs +/// Returns `None` if firewall management is disabled for a given location. +pub async fn try_get_location_firewall_config( + location: &WireguardNetwork, + conn: &mut PgConnection, +) -> Result, FirewallError> { + // do a license check + if !is_enterprise_enabled() { + debug!( + "Enterprise features are disabled, skipping generating firewall config for \ + location {location}" + ); + return Ok(None); + } - // check if ACLs are enabled - if !self.acl_enabled { - debug!( - "ACL rules are disabled for location {self}, skipping generating firewall config" - ); - return Ok(None); - } + // check if ACLs are enabled + if !location.acl_enabled { + debug!( + "ACL rules are disabled for location {location}, skipping generating firewall config" + ); + return Ok(None); + } - info!("Generating firewall config for location {self}"); - // fetch all active ACLs for location - let location_acls = self.get_active_acl_rules(&mut *conn).await?; + info!("Generating firewall config for location {location}"); + // fetch all active ACLs for location + let location_acls = get_location_active_acl_rules(location, &mut *conn).await?; - let default_policy = if self.acl_default_allow { - FirewallPolicy::Allow - } else { - FirewallPolicy::Deny - }; - let firewall_rules = - generate_firewall_rules_from_acls(self.id, location_acls, &mut *conn).await?; - let snat_bindings = generate_user_snat_bindings_for_location(self.id, &mut *conn).await?; - let firewall_config = FirewallConfig { - default_policy: default_policy.into(), - rules: firewall_rules, - snat_bindings, - }; + let default_policy = if location.acl_default_allow { + FirewallPolicy::Allow + } else { + FirewallPolicy::Deny + }; + let firewall_rules = + generate_firewall_rules_from_acls(location.id, location_acls, &mut *conn).await?; + let snat_bindings = generate_user_snat_bindings_for_location(location.id, &mut *conn).await?; + let firewall_config = FirewallConfig { + default_policy: default_policy.into(), + rules: firewall_rules, + snat_bindings, + }; - debug!("Firewall config generated for location {self}: {firewall_config:?}"); - Ok(Some(firewall_config)) - } + debug!("Firewall config generated for location {location}: {firewall_config:?}"); + Ok(Some(firewall_config)) } #[cfg(test)] diff --git a/crates/defguard_core/src/enterprise/firewall/tests.rs b/crates/defguard_core/src/enterprise/firewall/tests.rs index abad1a6910..7f4b9e1bf8 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests.rs @@ -1,7 +1,14 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use chrono::{DateTime, NaiveDateTime}; -use defguard_common::db::{Id, NoId, setup_pool}; +use defguard_common::db::{ + Id, NoId, + models::{ + Device, DeviceType, WireguardNetwork, device::WireguardNetworkDevice, group::Group, + user::User, + }, + setup_pool, +}; use defguard_proto::enterprise::firewall::{ FirewallPolicy, IpAddress, IpRange, IpVersion, Port, PortRange as PortRangeProto, Protocol, ip_address::Address, port::Port as PortInner, @@ -18,18 +25,12 @@ use super::{ find_largest_subnet_in_range, get_last_ip_in_v6_subnet, get_source_users, merge_addrs, merge_port_ranges, process_destination_addrs, }; -use crate::{ - db::{ - Device, Group, User, WireguardNetwork, - models::device::{DeviceType, WireguardNetworkDevice}, - }, - enterprise::{ - db::models::acl::{ - AclAlias, AclRule, AclRuleAlias, AclRuleDestinationRange, AclRuleDevice, AclRuleGroup, - AclRuleInfo, AclRuleNetwork, AclRuleUser, AliasKind, PortRange, RuleState, - }, - firewall::{get_source_addrs, get_source_network_devices}, +use crate::enterprise::{ + db::models::acl::{ + AclAlias, AclRule, AclRuleAlias, AclRuleDestinationRange, AclRuleDevice, AclRuleGroup, + AclRuleInfo, AclRuleNetwork, AclRuleUser, AliasKind, PortRange, RuleState, }, + firewall::{get_source_addrs, get_source_network_devices, try_get_location_firewall_config}, }; impl Default for AclRuleDestinationRange { @@ -1175,14 +1176,15 @@ async fn test_generate_firewall_rules_ipv4(_: PgPoolOptions, options: PgConnectO // try to generate firewall config with ACL disabled location.acl_enabled = false; - let generated_firewall_config = location.try_get_firewall_config(&mut conn).await.unwrap(); + let generated_firewall_config = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap(); assert!(generated_firewall_config.is_none()); // generate firewall config with default policy Allow location.acl_enabled = true; location.acl_default_allow = true; - let generated_firewall_config = location - .try_get_firewall_config(&mut conn) + let generated_firewall_config = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap(); @@ -1596,14 +1598,15 @@ async fn test_generate_firewall_rules_ipv6(_: PgPoolOptions, options: PgConnectO // try to generate firewall config with ACL disabled location.acl_enabled = false; - let generated_firewall_config = location.try_get_firewall_config(&mut conn).await.unwrap(); + let generated_firewall_config = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap(); assert!(generated_firewall_config.is_none()); // generate firewall config with default policy Allow location.acl_enabled = true; location.acl_default_allow = true; - let generated_firewall_config = location - .try_get_firewall_config(&mut conn) + let generated_firewall_config = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap(); @@ -2062,14 +2065,15 @@ async fn test_generate_firewall_rules_ipv4_and_ipv6(_: PgPoolOptions, options: P // try to generate firewall config with ACL disabled location.acl_enabled = false; - let generated_firewall_config = location.try_get_firewall_config(&mut conn).await.unwrap(); + let generated_firewall_config = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap(); assert!(generated_firewall_config.is_none()); // generate firewall config with default policy Allow location.acl_enabled = true; location.acl_default_allow = true; - let generated_firewall_config = location - .try_get_firewall_config(&mut conn) + let generated_firewall_config = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap(); @@ -2460,8 +2464,7 @@ async fn test_expired_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptions } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2477,8 +2480,7 @@ async fn test_expired_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptions acl_rule_2.expires = Some(NaiveDateTime::MAX); acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2531,8 +2533,7 @@ async fn test_expired_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptions } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2548,8 +2549,7 @@ async fn test_expired_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptions acl_rule_2.expires = Some(NaiveDateTime::MAX); acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2605,8 +2605,7 @@ async fn test_expired_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConne } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2622,8 +2621,7 @@ async fn test_expired_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConne acl_rule_2.expires = Some(NaiveDateTime::MAX); acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2675,8 +2673,7 @@ async fn test_disabled_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOption } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2692,8 +2689,7 @@ async fn test_disabled_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOption acl_rule_2.enabled = true; acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2746,8 +2742,7 @@ async fn test_disabled_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOption } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2763,8 +2758,7 @@ async fn test_disabled_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOption acl_rule_2.enabled = true; acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2820,8 +2814,7 @@ async fn test_disabled_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConn } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2837,8 +2830,7 @@ async fn test_disabled_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgConn acl_rule_2.enabled = true; acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2890,8 +2882,7 @@ async fn test_unapplied_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptio } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2907,8 +2898,7 @@ async fn test_unapplied_acl_rules_ipv4(_: PgPoolOptions, options: PgConnectOptio acl_rule_2.state = RuleState::Applied; acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2961,8 +2951,7 @@ async fn test_unapplied_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptio } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -2978,8 +2967,7 @@ async fn test_unapplied_acl_rules_ipv6(_: PgPoolOptions, options: PgConnectOptio acl_rule_2.state = RuleState::Applied; acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -3035,8 +3023,7 @@ async fn test_unapplied_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgCon } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -3052,8 +3039,7 @@ async fn test_unapplied_acl_rules_ipv4_and_ipv6(_: PgPoolOptions, options: PgCon acl_rule_2.state = RuleState::Applied; acl_rule_2.save(&pool).await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -3191,8 +3177,7 @@ async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectO } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location_1 - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location_1, &mut conn) .await .unwrap() .unwrap() @@ -3201,8 +3186,7 @@ async fn test_acl_rules_all_locations_ipv4(_: PgPoolOptions, options: PgConnectO // both rules were assigned to this location assert_eq!(generated_firewall_rules.len(), 4); - let generated_firewall_rules = location_2 - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location_2, &mut conn) .await .unwrap() .unwrap() @@ -3353,8 +3337,7 @@ async fn test_acl_rules_all_locations_ipv6(_: PgPoolOptions, options: PgConnectO } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location_1 - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location_1, &mut conn) .await .unwrap() .unwrap() @@ -3363,8 +3346,7 @@ async fn test_acl_rules_all_locations_ipv6(_: PgPoolOptions, options: PgConnectO // both rules were assigned to this location assert_eq!(generated_firewall_rules.len(), 4); - let generated_firewall_rules = location_2 - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location_2, &mut conn) .await .unwrap() .unwrap() @@ -3529,8 +3511,7 @@ async fn test_acl_rules_all_locations_ipv4_and_ipv6(_: PgPoolOptions, options: P } let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location_1 - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location_1, &mut conn) .await .unwrap() .unwrap() @@ -3539,8 +3520,7 @@ async fn test_acl_rules_all_locations_ipv4_and_ipv6(_: PgPoolOptions, options: P // both rules were assigned to this location assert_eq!(generated_firewall_rules.len(), 8); - let generated_firewall_rules = location_2 - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location_2, &mut conn) .await .unwrap() .unwrap() @@ -3658,8 +3638,7 @@ async fn test_alias_kinds(_: PgPoolOptions, options: PgConnectOptions) { obj.save(&pool).await.unwrap(); let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() @@ -3854,8 +3833,7 @@ async fn test_destination_alias_only_acl(_: PgPoolOptions, options: PgConnectOpt obj.save(&pool).await.unwrap(); let mut conn = pool.acquire().await.unwrap(); - let generated_firewall_rules = location - .try_get_firewall_config(&mut conn) + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) .await .unwrap() .unwrap() diff --git a/crates/defguard_core/src/enterprise/grpc/polling.rs b/crates/defguard_core/src/enterprise/grpc/polling.rs index 8e04e9a411..fc8960d5d0 100644 --- a/crates/defguard_core/src/enterprise/grpc/polling.rs +++ b/crates/defguard_core/src/enterprise/grpc/polling.rs @@ -1,13 +1,12 @@ -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{Device, polling_token::PollingToken, user::User}, +}; use defguard_proto::proxy::{DeviceInfo, InstanceInfoRequest, InstanceInfoResponse}; use sqlx::PgPool; use tonic::Status; -use crate::{ - db::{Device, User, models::polling_token::PollingToken}, - enterprise::is_enterprise_enabled, - grpc::utils::build_device_config_response, -}; +use crate::{enterprise::is_enterprise_enabled, grpc::utils::build_device_config_response}; pub struct PollingServer { pool: PgPool, diff --git a/crates/defguard_core/src/enterprise/handlers/api_tokens.rs b/crates/defguard_core/src/enterprise/handlers/api_tokens.rs index 15a40ecfda..5a1d5923c7 100644 --- a/crates/defguard_core/src/enterprise/handlers/api_tokens.rs +++ b/crates/defguard_core/src/enterprise/handlers/api_tokens.rs @@ -4,14 +4,13 @@ use axum::{ http::StatusCode, }; use chrono::Utc; -use defguard_common::random::gen_alphanumeric; +use defguard_common::{db::models::user::User, random::gen_alphanumeric}; use serde_json::json; use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::User, enterprise::db::models::api_tokens::{ApiToken, ApiTokenInfo}, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, diff --git a/crates/defguard_core/src/enterprise/handlers/openid_login.rs b/crates/defguard_core/src/enterprise/handlers/openid_login.rs index 6b506f5223..b17860caa2 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_login.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_login.rs @@ -13,7 +13,7 @@ use defguard_common::{ config::server_config, db::{ Id, - models::{Settings, settings::OpenidUsernameHandling}, + models::{Settings, settings::OpenidUsernameHandling, user::User}, }, }; use openidconnect::{ @@ -36,7 +36,6 @@ pub(crate) const SELECT_ACCOUNT_SUPPORTED_PROVIDERS: &[&str] = &["Google"]; use super::LicenseInfo; use crate::{ appstate::AppState, - db::User, enterprise::{ db::models::openid_provider::OpenIdProvider, directory_sync::sync_user_groups_if_configured, ldap::utils::ldap_update_user_state, diff --git a/crates/defguard_core/src/enterprise/handlers/openid_providers.rs b/crates/defguard_core/src/enterprise/handlers/openid_providers.rs index 01cef376e6..c236209b62 100644 --- a/crates/defguard_core/src/enterprise/handlers/openid_providers.rs +++ b/crates/defguard_core/src/enterprise/handlers/openid_providers.rs @@ -4,8 +4,9 @@ use axum::{ http::StatusCode, }; use defguard_common::db::models::{ - Settings, + Settings, WireguardNetwork, settings::{OpenidUsernameHandling, update_current_settings}, + wireguard::LocationMfaMode, }; use rsa::{RsaPrivateKey, pkcs8::DecodePrivateKey}; use serde_json::json; @@ -14,7 +15,6 @@ use super::LicenseInfo; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::{WireguardNetwork, models::wireguard::LocationMfaMode}, enterprise::{ db::models::openid_provider::OpenIdProvider, directory_sync::test_directory_sync_connection, }, diff --git a/crates/defguard_core/src/enterprise/ldap/client.rs b/crates/defguard_core/src/enterprise/ldap/client.rs index 0524f3b649..35567bfd66 100644 --- a/crates/defguard_core/src/enterprise/ldap/client.rs +++ b/crates/defguard_core/src/enterprise/ldap/client.rs @@ -4,14 +4,14 @@ use std::{ time::Duration, }; -use defguard_common::db::models::Settings; +use defguard_common::db::models::{Settings, User}; use ldap3::{ LdapConnAsync, LdapConnSettings, Mod, Scope, SearchEntry, adapters::PagedResults, drive, ldap_escape, }; use super::error::LdapError; -use crate::{db::User, enterprise::ldap::model::extract_rdn_value}; +use crate::enterprise::ldap::model::extract_rdn_value; impl super::LDAPConnection { pub(crate) async fn create() -> Result { diff --git a/crates/defguard_core/src/enterprise/ldap/mod.rs b/crates/defguard_core/src/enterprise/ldap/mod.rs index 17152ee73d..5fefec490e 100644 --- a/crates/defguard_core/src/enterprise/ldap/mod.rs +++ b/crates/defguard_core/src/enterprise/ldap/mod.rs @@ -3,7 +3,8 @@ use std::{collections::HashSet, future::Future}; use defguard_common::db::{ Id, models::{ - Settings, + Settings, User, + group::Group, settings::{LdapSyncStatus, update_current_settings}, }, }; @@ -16,9 +17,13 @@ use sqlx::PgPool; use sync::{get_ldap_sync_status, is_ldap_desynced, set_ldap_sync_status}; use self::error::LdapError; -use crate::{ - db::{self, User}, - enterprise::{is_enterprise_enabled, ldap::model::extract_dn_path, limits::update_counts}, +use crate::enterprise::{ + is_enterprise_enabled, + ldap::model::{ + extract_dn_path, ldap_sync_allowed_for_user, user_as_ldap_attrs, user_as_ldap_mod, + user_from_searchentry, + }, + limits::update_counts, }; #[cfg(not(test))] @@ -357,7 +362,7 @@ impl LDAPConnection { debug!("Updating users state in LDAP"); for user in users { - let user_sync_allowed = user.ldap_sync_allowed(pool).await?; + let user_sync_allowed = ldap_sync_allowed_for_user(user, pool).await?; let user_exists_in_ldap = self.user_exists(user).await?; let user_groups = user.member_of_names(pool).await?; let user_in_sync_groups = self.user_in_ldap_sync_groups(user).await?; @@ -405,7 +410,7 @@ impl LDAPConnection { /// Checks if user belongs to one of the defined sync groups in the LDAP server. /// Returns true if no sync groups are defined (sync all users) or if user is in at least one sync group. async fn user_in_ldap_sync_groups(&mut self, user: &User) -> Result { - debug!("Checking if user {} is in LDAP sync groups", user.username); + debug!("Checking if user {user} is in LDAP sync groups"); // Sync groups empty, we should sync all users if self.config.ldap_sync_groups.is_empty() { @@ -539,7 +544,7 @@ impl LDAPConnection { if let Some(entry) = entries.pop() { info!("Performed LDAP user search: {username}"); self.test_bind_user(&entry.dn, password).await?; - User::from_searchentry(&entry, username, Some(password)) + user_from_searchentry(&entry, username, Some(password)) } else { Err(LdapError::ObjectNotFound(format!( "User {username} not found", @@ -549,9 +554,8 @@ impl LDAPConnection { /// Retrieves user from LDAP by username. /// Returns an error if multiple users are found or if the user doesn't exist. - pub async fn get_user_by_username(&mut self, username: &User) -> Result { + pub async fn get_user_by_username(&mut self, username: &str) -> Result { debug!("Performing LDAP user search by username: {username}"); - let username = &username.username; let username_escape = ldap_escape(username); let mut entries = self .search_users(&format!( @@ -564,7 +568,7 @@ impl LDAPConnection { } if let Some(entry) = entries.pop() { info!("Performed LDAP user search by username: {username}"); - User::from_searchentry(&entry, username, None) + user_from_searchentry(&entry, username, None) } else { Err(LdapError::ObjectNotFound(format!( "User {username} not found", @@ -580,7 +584,7 @@ impl LDAPConnection { match self.get(&dn).await? { Some(entry) => { info!("Found LDAP user with DN: {}", dn); - User::from_searchentry(&entry, &user.username, None) + user_from_searchentry(&entry, &user.username, None) } None => Err(LdapError::ObjectNotFound(format!("User {dn} not found",))), } @@ -596,26 +600,23 @@ impl LDAPConnection { password: Option<&str>, pool: &PgPool, ) -> Result<(), LdapError> { - debug!("Adding LDAP user {}", user.username); + debug!("Adding LDAP user {user}"); let user_dn = self.config.user_dn_from_user(user); let password_is_random = password.is_none(); let password = if let Some(password) = password { - debug!("Using provided password for user {}", user.username); + debug!("Using provided password for user {user}"); password.to_string() } else { // ldap may not accept no password, this is a workaround when we don't have access to the // user's password - debug!( - "Generating random password for user {}, as no password has been specified", - user.username - ); + debug!("Generating random password for user {user}, as no password has been specified",); let random_password = rand::thread_rng() .sample_iter(&rand::distributions::Alphanumeric) .take(32) .map(char::from) .collect::(); - debug!("Generated random password for user {}", user.username); + debug!("Generated random password for user {user}"); random_password }; let ssha_password = hash::salted_sha1_hash(&password); @@ -633,7 +634,8 @@ impl LDAPConnection { } self.add( &user_dn, - user.as_ldap_attrs( + user_as_ldap_attrs( + user, &ssha_password, &nt_password, user_obj_classes.iter().map(String::as_str).collect(), @@ -652,7 +654,7 @@ impl LDAPConnection { user.ldap_pass_randomized = true; } user.save(pool).await?; - info!("Added LDAP user {}", user.username); + info!("Added LDAP user {user}"); Ok(()) } @@ -688,7 +690,7 @@ impl LDAPConnection { let old_dn = self.config.user_dn(old_rdn, user_dn_path); let new_dn = self.config.user_dn(new_rdn, user_dn_path); let config = self.config.clone(); - let mods = user.as_ldap_mod(&config); + let mods = user_as_ldap_mod(user, &config); self.modify(&old_dn, &new_dn, mods).await?; info!("Modified user {old_username} in LDAP"); @@ -861,7 +863,7 @@ impl LDAPConnection { pub async fn modify_group( &mut self, groupname: &str, - group: &db::Group, + group: &Group, ) -> Result<(), LdapError> { debug!("Modifying LDAP group {groupname}"); let old_dn = self.config.group_dn(groupname); diff --git a/crates/defguard_core/src/enterprise/ldap/model.rs b/crates/defguard_core/src/enterprise/ldap/model.rs index 0da8afc1c7..a99d5709ee 100644 --- a/crates/defguard_core/src/enterprise/ldap/model.rs +++ b/crates/defguard_core/src/enterprise/ldap/model.rs @@ -1,11 +1,14 @@ use std::collections::HashSet; -use defguard_common::db::{Id, models::Settings}; +use defguard_common::db::{ + Id, + models::{Settings, User}, +}; use ldap3::{Mod, SearchEntry}; use sqlx::{Error as SqlxError, PgExecutor}; use super::{LDAPConfig, error::LdapError}; -use crate::{db::User, handlers::user::check_username, hashset}; +use crate::{handlers::user::check_username, hashset}; pub(crate) enum UserObjectClass { SambaSamAccount, @@ -51,246 +54,245 @@ impl PartialEq for &str { } } -impl User { - pub fn from_searchentry( - entry: &SearchEntry, - username: &str, - password: Option<&str>, - ) -> Result { - let mut user = Self::new( - username.into(), - password, - get_value_or_error(entry, "sn")?, - get_value_or_error(entry, "givenName")?, - get_value_or_error(entry, "mail")?, - get_value(entry, "mobile"), - ); - user.from_ldap = true; - if let Some(rdn) = extract_rdn_value(&entry.dn) { - user.ldap_rdn = Some(rdn); - } else { - return Err(LdapError::InvalidDN(entry.dn.clone())); - } - if let Some(dn_path) = extract_dn_path(&entry.dn) { - user.ldap_user_path = Some(dn_path); - } else { - return Err(LdapError::InvalidDN(entry.dn.clone())); - } - // Print the warning only if everything else checks out - if check_username(username).is_err() { - warn!( - "LDAP User \"{username}\" has username that cannot be used in Defguard, \ +pub(crate) fn user_from_searchentry( + entry: &SearchEntry, + username: &str, + password: Option<&str>, +) -> Result { + let mut user = User::new( + username.into(), + password, + get_value_or_error(entry, "sn")?, + get_value_or_error(entry, "givenName")?, + get_value_or_error(entry, "mail")?, + get_value(entry, "mobile"), + ); + user.from_ldap = true; + if let Some(rdn) = extract_rdn_value(&entry.dn) { + user.ldap_rdn = Some(rdn); + } else { + return Err(LdapError::InvalidDN(entry.dn.clone())); + } + if let Some(dn_path) = extract_dn_path(&entry.dn) { + user.ldap_user_path = Some(dn_path); + } else { + return Err(LdapError::InvalidDN(entry.dn.clone())); + } + // Print the warning only if everything else checks out + if check_username(username).is_err() { + warn!( + "LDAP User \"{username}\" has username that cannot be used in Defguard, \ change the LDAP username attribute or change the username in LDAP to a valid one", - ); - return Err(LdapError::InvalidUsername(username.to_string())); - } - Ok(user) + ); + return Err(LdapError::InvalidUsername(username.to_string())); } + Ok(user) } -impl User { - pub(crate) fn update_from_ldap_user(&mut self, ldap_user: &User, config: &LDAPConfig) { - self.last_name.clone_from(&ldap_user.last_name); - self.first_name.clone_from(&ldap_user.first_name); - self.email.clone_from(&ldap_user.email); - self.phone.clone_from(&ldap_user.phone); - // It should be ok to update the username if we are not using it in the DN (not as RDN) - if config.using_username_as_rdn() { - debug!( - "Not updating username {} from LDAP because it is used as RDN", - self.username - ); - } else { - self.username.clone_from(&ldap_user.username); - } +pub(crate) fn update_from_ldap_user(user: &mut User, ldap_user: &User, config: &LDAPConfig) { + user.last_name.clone_from(&ldap_user.last_name); + user.first_name.clone_from(&ldap_user.first_name); + user.email.clone_from(&ldap_user.email); + user.phone.clone_from(&ldap_user.phone); + // It should be ok to update the username if we are not using it in the DN (not as RDN) + if config.using_username_as_rdn() { + debug!( + "Not updating username {} from LDAP because it is used as RDN", + user.username + ); + } else { + user.username.clone_from(&ldap_user.username); } +} - #[must_use] - pub fn as_ldap_mod<'a>(&'a self, config: &'a LDAPConfig) -> Vec> { - let obj_classes = config.get_all_user_obj_classes(); - let mut changes = vec![]; - if obj_classes.contains(&UserObjectClass::InetOrgPerson.into()) - || obj_classes.contains(&UserObjectClass::User.into()) - { - changes.extend_from_slice(&[ - Mod::Replace("sn", hashset![self.last_name.as_str()]), - Mod::Replace("givenName", hashset![self.first_name.as_str()]), - Mod::Replace("mail", hashset![self.email.as_str()]), - ]); - - // Allow renaming the user if the CN is not a part of the RDN - if !config.get_rdn_attr().eq_ignore_ascii_case("cn") { - changes.push(Mod::Replace("cn", hashset![self.username.as_str()])); - } - - if !config.ldap_username_attr.eq_ignore_ascii_case("uid") - && !config - .ldap_user_rdn_attr - .as_ref() - .is_some_and(|rdn_attr| rdn_attr.eq_ignore_ascii_case("uid")) - { - changes.push(Mod::Replace("uid", hashset![self.username.as_str()])); - } - - if let Some(phone) = &self.phone { - if phone.is_empty() { - changes.push(Mod::Replace("mobile", HashSet::new())); - } else { - changes.push(Mod::Replace("mobile", hashset![phone.as_str()])); - } - } - } else { - warn!( - "No user object class found for user {}, can't generate mods", - self.username - ); - } - - if config.ldap_uses_ad && !config.get_rdn_attr().eq_ignore_ascii_case("sAMAccountName") { - changes.push(Mod::Replace( - "sAMAccountName", - hashset![self.username.as_str()], - )); +#[must_use] +pub fn user_as_ldap_mod<'a, I>(user: &'a User, config: &'a LDAPConfig) -> Vec> { + let obj_classes = config.get_all_user_obj_classes(); + let mut changes = vec![]; + if obj_classes.contains(&UserObjectClass::InetOrgPerson.into()) + || obj_classes.contains(&UserObjectClass::User.into()) + { + changes.extend_from_slice(&[ + Mod::Replace("sn", hashset![user.last_name.as_str()]), + Mod::Replace("givenName", hashset![user.first_name.as_str()]), + Mod::Replace("mail", hashset![user.email.as_str()]), + ]); + + // Allow renaming the user if the CN is not a part of the RDN + if !config.get_rdn_attr().eq_ignore_ascii_case("cn") { + changes.push(Mod::Replace("cn", hashset![user.username.as_str()])); } - let username_attr = config.ldap_username_attr.as_str(); - // Add anything the user provided, if we haven't already added it AND it's not the same as - // the RDN. - if !username_attr.eq_ignore_ascii_case("sAMAccountName") - && !username_attr.eq_ignore_ascii_case("cn") + if !config.ldap_username_attr.eq_ignore_ascii_case("uid") && !config .ldap_user_rdn_attr .as_ref() - .is_some_and(|rdn_attr| rdn_attr.eq_ignore_ascii_case(username_attr)) + .is_some_and(|rdn_attr| rdn_attr.eq_ignore_ascii_case("uid")) { - changes.push(Mod::Replace( - username_attr, - hashset![self.username.as_str()], - )); + changes.push(Mod::Replace("uid", hashset![user.username.as_str()])); } - changes + if let Some(phone) = &user.phone { + if phone.is_empty() { + changes.push(Mod::Replace("mobile", HashSet::new())); + } else { + changes.push(Mod::Replace("mobile", hashset![phone.as_str()])); + } + } + } else { + warn!( + "No user object class found for user {}, can't generate mods", + user.username + ); } - // check if key is already in attrs, if not return false - #[cfg(test)] - pub(crate) fn in_attrs<'a>(attrs: &'a Vec<(&'a str, HashSet<&'a str>)>, key: &str) -> bool { - attrs.iter().any(|(k, _)| k.eq_ignore_ascii_case(key)) + if config.ldap_uses_ad && !config.get_rdn_attr().eq_ignore_ascii_case("sAMAccountName") { + changes.push(Mod::Replace( + "sAMAccountName", + hashset![user.username.as_str()], + )); } - #[cfg(not(test))] - fn in_attrs<'a>(attrs: &'a Vec<(&'a str, HashSet<&'a str>)>, key: &str) -> bool { - attrs.iter().any(|(k, _)| k.eq_ignore_ascii_case(key)) + let username_attr = config.ldap_username_attr.as_str(); + // Add anything the user provided, if we haven't already added it AND it's not the same as + // the RDN. + if !username_attr.eq_ignore_ascii_case("sAMAccountName") + && !username_attr.eq_ignore_ascii_case("cn") + && !config + .ldap_user_rdn_attr + .as_ref() + .is_some_and(|rdn_attr| rdn_attr.eq_ignore_ascii_case(username_attr)) + { + changes.push(Mod::Replace( + username_attr, + hashset![user.username.as_str()], + )); } - #[must_use] - pub fn as_ldap_attrs<'a>( - &'a self, - ssha_password: &'a str, - nt_password: &'a str, - object_classes: HashSet<&'a str>, - uses_ad: bool, - username_attr: &'a str, - rdn_attr: &'a str, - ) -> Vec<(&'a str, HashSet<&'a str>)> { - let mut attrs = vec![]; - attrs.push((rdn_attr, hashset![self.ldap_rdn_value()])); - if object_classes.contains(UserObjectClass::InetOrgPerson.into()) - || object_classes.contains(UserObjectClass::User.into()) - { - attrs.extend_from_slice(&[ - ("sn", hashset![self.last_name.as_str()]), - ("givenName", hashset![self.first_name.as_str()]), - ("mail", hashset![self.email.as_str()]), - ]); - - if !Self::in_attrs(&attrs, "cn") { - attrs.push(("cn", hashset![self.username.as_str()])); - } + changes +} - if !Self::in_attrs(&attrs, "uid") { - attrs.push(("uid", hashset![self.username.as_str()])); - } +// check if key is already in attrs, if not return false +#[cfg(test)] +pub(crate) fn in_attrs<'a>(attrs: &'a Vec<(&'a str, HashSet<&'a str>)>, key: &str) -> bool { + attrs.iter().any(|(k, _)| k.eq_ignore_ascii_case(key)) +} - if let Some(phone) = &self.phone { - if !phone.is_empty() { - attrs.push(("mobile", hashset![phone.as_str()])); - } - } - } - if object_classes.contains(UserObjectClass::SimpleSecurityObject.into()) { - // simpleSecurityObject - attrs.push(("userPassword", hashset![ssha_password])); - } - if object_classes.contains(UserObjectClass::SambaSamAccount.into()) { - // sambaSamAccount - attrs.push(("sambaSID", hashset!["0"])); - attrs.push(("sambaNTPassword", hashset![nt_password])); +#[cfg(not(test))] +fn in_attrs<'a>(attrs: &'a Vec<(&'a str, HashSet<&'a str>)>, key: &str) -> bool { + attrs.iter().any(|(k, _)| k.eq_ignore_ascii_case(key)) +} + +#[must_use] +pub fn user_as_ldap_attrs<'a, I>( + user: &'a User, + ssha_password: &'a str, + nt_password: &'a str, + object_classes: HashSet<&'a str>, + uses_ad: bool, + username_attr: &'a str, + rdn_attr: &'a str, +) -> Vec<(&'a str, HashSet<&'a str>)> { + let mut attrs = vec![]; + attrs.push((rdn_attr, hashset![user.ldap_rdn_value()])); + if object_classes.contains(UserObjectClass::InetOrgPerson.into()) + || object_classes.contains(UserObjectClass::User.into()) + { + attrs.extend_from_slice(&[ + ("sn", hashset![user.last_name.as_str()]), + ("givenName", hashset![user.first_name.as_str()]), + ("mail", hashset![user.email.as_str()]), + ]); + + if !in_attrs(&attrs, "cn") { + attrs.push(("cn", hashset![user.username.as_str()])); } - if uses_ad { - attrs.push(("sAMAccountName", hashset![self.username.as_str()])); + + if !in_attrs(&attrs, "uid") { + attrs.push(("uid", hashset![user.username.as_str()])); } - // Add the username attr and RDN if we haven't already added it - if !Self::in_attrs(&attrs, username_attr) { - attrs.push((username_attr, hashset![self.username.as_str()])); + if let Some(phone) = &user.phone { + if !phone.is_empty() { + attrs.push(("mobile", hashset![phone.as_str()])); + } } + } + if object_classes.contains(UserObjectClass::SimpleSecurityObject.into()) { + // simpleSecurityObject + attrs.push(("userPassword", hashset![ssha_password])); + } + if object_classes.contains(UserObjectClass::SambaSamAccount.into()) { + // sambaSamAccount + attrs.push(("sambaSID", hashset!["0"])); + attrs.push(("sambaNTPassword", hashset![nt_password])); + } + if uses_ad { + attrs.push(("sAMAccountName", hashset![user.username.as_str()])); + } - attrs.push(("objectClass", object_classes)); + // Add the username attr and RDN if we haven't already added it + if !in_attrs(&attrs, username_attr) { + attrs.push((username_attr, hashset![user.username.as_str()])); + } - debug!("Generated LDAP attributes: {attrs:?}"); + attrs.push(("objectClass", object_classes)); - attrs - } + debug!("Generated LDAP attributes: {attrs:?}"); - /// Updates the LDAP RDN value of the user in Defguard, if Defguard uses the usernames as RDN. - pub(crate) fn maybe_update_rdn(&mut self) { - debug!("Updating RDN for user {} in Defguard", self.username); - let settings = Settings::get_current_settings(); - if settings.ldap_using_username_as_rdn() { - debug!("The user's username is being used as the RDN, setting it to username"); - self.ldap_rdn = Some(self.username.clone()); - } else { - debug!("The user's username is NOT being used as the RDN, skipping update"); - } - } + attrs } -impl User { - /// User is syncable with LDAP if: - /// - he is in a group that is allowed to be synced or no such groups are configured - /// - he is active (not disabled) - /// - he is enrolled - pub(crate) async fn ldap_sync_allowed<'e, E>(&self, executor: E) -> Result - where - E: PgExecutor<'e>, - { - let sync_groups = Settings::get_current_settings().ldap_sync_groups; - let my_groups = self.member_of(executor).await?; - Ok( - (sync_groups.is_empty() || my_groups.iter().any(|g| sync_groups.contains(&g.name))) - && self.is_active - && self.is_enrolled(), - ) +/// Updates the LDAP RDN value of the user in Defguard, if Defguard uses the usernames as RDN. +pub(crate) fn maybe_update_rdn(user: &mut User) { + debug!("Updating RDN for user {} in Defguard", user.username); + let settings = Settings::get_current_settings(); + if settings.ldap_using_username_as_rdn() { + debug!("The user's username is being used as the RDN, setting it to username"); + user.ldap_rdn = Some(user.username.clone()); + } else { + debug!("The user's username is NOT being used as the RDN, skipping update"); } +} - pub(super) async fn get_without_ldap_path<'e, E>(executor: E) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - sqlx::query_as!( - Self, - " +/// User is syncable with LDAP if: +/// - he is in a group that is allowed to be synced or no such groups are configured +/// - he is active (not disabled) +/// - he is enrolled +pub(crate) async fn ldap_sync_allowed_for_user<'e, E>( + user: &User, + executor: E, +) -> Result +where + E: PgExecutor<'e>, +{ + let sync_groups = Settings::get_current_settings().ldap_sync_groups; + let my_groups = user.member_of(executor).await?; + Ok( + (sync_groups.is_empty() || my_groups.iter().any(|g| sync_groups.contains(&g.name))) + && user.is_active + && user.is_enrolled(), + ) +} + +pub(super) async fn get_users_without_ldap_path<'e, E>( + executor: E, +) -> Result>, SqlxError> +where + E: PgExecutor<'e>, +{ + sqlx::query_as!( + User, + " SELECT id, username, password_hash, last_name, first_name, email, phone, \ mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub, \ from_ldap, ldap_pass_randomized, ldap_rdn, ldap_user_path, enrollment_pending \ FROM \"user\" WHERE ldap_user_path IS NULL ", - ) - .fetch_all(executor) - .await - } + ) + .fetch_all(executor) + .await } fn get_value_or_error(entry: &SearchEntry, key: &str) -> Result { @@ -346,42 +348,42 @@ mod tests { ]; // Test exact case match - assert!(User::<()>::in_attrs(&attrs, "cn")); - assert!(User::<()>::in_attrs(&attrs, "Mail")); - assert!(User::<()>::in_attrs(&attrs, "PHONE")); - assert!(User::<()>::in_attrs(&attrs, "givenName")); + assert!(in_attrs(&attrs, "cn")); + assert!(in_attrs(&attrs, "Mail")); + assert!(in_attrs(&attrs, "PHONE")); + assert!(in_attrs(&attrs, "givenName")); // Test case-insensitive matching - assert!(User::<()>::in_attrs(&attrs, "CN")); - assert!(User::<()>::in_attrs(&attrs, "cn")); - assert!(User::<()>::in_attrs(&attrs, "mail")); - assert!(User::<()>::in_attrs(&attrs, "MAIL")); - assert!(User::<()>::in_attrs(&attrs, "phone")); - assert!(User::<()>::in_attrs(&attrs, "Phone")); - assert!(User::<()>::in_attrs(&attrs, "GIVENNAME")); - assert!(User::<()>::in_attrs(&attrs, "givenname")); + assert!(in_attrs(&attrs, "CN")); + assert!(in_attrs(&attrs, "cn")); + assert!(in_attrs(&attrs, "mail")); + assert!(in_attrs(&attrs, "MAIL")); + assert!(in_attrs(&attrs, "phone")); + assert!(in_attrs(&attrs, "Phone")); + assert!(in_attrs(&attrs, "GIVENNAME")); + assert!(in_attrs(&attrs, "givenname")); // Test non-existent attributes - assert!(!User::<()>::in_attrs(&attrs, "nonexistent")); - assert!(!User::<()>::in_attrs(&attrs, "sn")); - assert!(!User::<()>::in_attrs(&attrs, "uid")); + assert!(!in_attrs(&attrs, "nonexistent")); + assert!(!in_attrs(&attrs, "sn")); + assert!(!in_attrs(&attrs, "uid")); // Test empty attributes vector let empty_attrs = vec![]; - assert!(!User::<()>::in_attrs(&empty_attrs, "cn")); - assert!(!User::<()>::in_attrs(&empty_attrs, "any")); + assert!(!in_attrs(&empty_attrs, "cn")); + assert!(!in_attrs(&empty_attrs, "any")); // Test with empty string key - assert!(!User::<()>::in_attrs(&attrs, "")); + assert!(!in_attrs(&attrs, "")); // Test with attributes that have empty values (should still match on key) let attrs_with_empty_values = vec![ ("cn", HashSet::new()), ("mail", hashset!["test@example.com"]), ]; - assert!(User::<()>::in_attrs(&attrs_with_empty_values, "cn")); - assert!(User::<()>::in_attrs(&attrs_with_empty_values, "CN")); - assert!(User::<()>::in_attrs(&attrs_with_empty_values, "mail")); - assert!(!User::<()>::in_attrs(&attrs_with_empty_values, "phone")); + assert!(in_attrs(&attrs_with_empty_values, "cn")); + assert!(in_attrs(&attrs_with_empty_values, "CN")); + assert!(in_attrs(&attrs_with_empty_values, "mail")); + assert!(!in_attrs(&attrs_with_empty_values, "phone")); } } diff --git a/crates/defguard_core/src/enterprise/ldap/sync.rs b/crates/defguard_core/src/enterprise/ldap/sync.rs index f1dbeb27e1..45416ec148 100644 --- a/crates/defguard_core/src/enterprise/ldap/sync.rs +++ b/crates/defguard_core/src/enterprise/ldap/sync.rs @@ -57,7 +57,8 @@ use std::collections::{HashMap, HashSet}; use defguard_common::db::{ Id, models::{ - Settings, + Settings, User, + group::Group, settings::{LdapSyncStatus, update_current_settings}, }, }; @@ -65,7 +66,10 @@ use sqlx::{PgConnection, PgPool}; use super::{LDAPConfig, error::LdapError}; use crate::{ - db::{Group, User}, + enterprise::ldap::model::{ + get_users_without_ldap_path, ldap_sync_allowed_for_user, update_from_ldap_user, + user_from_searchentry, + }, hashset, }; @@ -455,7 +459,7 @@ impl super::LDAPConnection { match authority { Authority::LDAP => { debug!("Applying LDAP user attributes to Defguard user"); - defguard_user.update_from_ldap_user(ldap_user, &self.config); + update_from_ldap_user(defguard_user, ldap_user, &self.config); defguard_user.save(&mut *transaction).await?; } Authority::Defguard => { @@ -546,11 +550,11 @@ impl super::LDAPConnection { debug!("Fixing missing user path in LDAP"); let mut transaction = pool.begin().await?; - let users = User::get_without_ldap_path(&mut *transaction).await?; + let users = get_users_without_ldap_path(&mut *transaction).await?; let mut filtered_users = Vec::new(); for user in users { - if user.ldap_sync_allowed(&mut *transaction).await? { + if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { filtered_users.push(user); } } @@ -558,7 +562,7 @@ impl super::LDAPConnection { for mut defguard_user in users { if defguard_user.ldap_user_path.is_none() { - match self.get_user_by_username(&defguard_user).await { + match self.get_user_by_username(&defguard_user.username).await { Ok(ldap_user) => { debug!( "Found LDAP user {} with missing path in Defguard, fixing their path", @@ -650,7 +654,7 @@ impl super::LDAPConnection { // Filter out users that should be ignored from sync let mut filtered_users = Vec::new(); for user in all_defguard_users { - if user.ldap_sync_allowed(pool).await? { + if ldap_sync_allowed_for_user(&user, pool).await? { filtered_users.push(user); } } @@ -678,7 +682,7 @@ impl super::LDAPConnection { for group in defguard_groups { let mut members = HashSet::new(); for member in group.members(pool).await? { - if member.ldap_sync_allowed(pool).await? { + if ldap_sync_allowed_for_user(&member, pool).await? { members.insert(member); } } @@ -879,7 +883,7 @@ impl super::LDAPConnection { LdapError::ObjectNotFound(format!("No {username_attr} attribute found")) })?; - match User::from_searchentry(&entry, username, None) { + match user_from_searchentry(&entry, username, None) { Ok(user) => all_users.push(user), Err(err) => { warn!( diff --git a/crates/defguard_core/src/enterprise/ldap/test_client.rs b/crates/defguard_core/src/enterprise/ldap/test_client.rs index 05f2748645..d82132b3b2 100644 --- a/crates/defguard_core/src/enterprise/ldap/test_client.rs +++ b/crates/defguard_core/src/enterprise/ldap/test_client.rs @@ -4,13 +4,11 @@ use std::{ vec::Vec, }; +use defguard_common::db::models::{User, group::Group}; use ldap3::{Mod, SearchEntry}; use super::error::LdapError; -use crate::{ - db::{Group, User}, - enterprise::ldap::model::extract_rdn_value, -}; +use crate::enterprise::ldap::model::{extract_rdn_value, user_as_ldap_attrs}; /// Extract attribute value from LDAP filter /// @@ -141,8 +139,7 @@ impl Object { match self { Object::User(user) => SearchEntry { dn: dn.to_string(), - attrs: user - .to_test_attrs(None, config) + attrs: user_to_test_attrs(user, None, config) .into_iter() .map(|(k, v)| (k, v.into_iter().collect())) .collect(), @@ -150,8 +147,7 @@ impl Object { }, Object::Group(group) => SearchEntry { dn: dn.to_string(), - attrs: group - .to_test_attrs(config, None) + attrs: group_to_test_attrs(group, config, None) .into_iter() .map(|(k, v)| (k, v.into_iter().collect())) .collect(), @@ -489,7 +485,8 @@ impl super::LDAPConnection { .ldap_user_rdn_attr .clone() .unwrap_or(config.ldap_username_attr.clone()); - let attrs = user.as_ldap_attrs( + let attrs = user_as_ldap_attrs( + user, "", "", classes.iter().map(|s| s.as_str()).collect(), @@ -540,75 +537,72 @@ impl super::LDAPConnection { } #[cfg(test)] -impl User { - pub(super) fn to_test_attrs( - &self, - password: Option<&str>, - config: &super::LDAPConfig, - ) -> Vec<(String, HashSet)> { - let rdn_attr = config - .ldap_user_rdn_attr - .clone() - .unwrap_or(config.ldap_username_attr.clone()); - let classes = config.get_all_user_obj_classes(); - let ssha_password = if let Some(password) = &password { - super::hash::salted_sha1_hash(password) - } else { - String::new() - }; - let nt_password = if let Some(password) = &password { - super::hash::nthash(password) - } else { - String::new() - }; - self.as_ldap_attrs( - &ssha_password, - &nt_password, - classes.iter().map(|s| s.as_str()).collect(), - false, - &config.ldap_username_attr, - &rdn_attr, - ) - .into_iter() - .map(|(k, v)| (k.to_string(), v.iter().map(|s| s.to_string()).collect())) - .collect() - } +pub(super) fn user_to_test_attrs( + user: &User, + password: Option<&str>, + config: &super::LDAPConfig, +) -> Vec<(String, HashSet)> { + let rdn_attr = config + .ldap_user_rdn_attr + .clone() + .unwrap_or(config.ldap_username_attr.clone()); + let classes = config.get_all_user_obj_classes(); + let ssha_password = if let Some(password) = &password { + super::hash::salted_sha1_hash(password) + } else { + String::new() + }; + let nt_password = if let Some(password) = &password { + super::hash::nthash(password) + } else { + String::new() + }; + user_as_ldap_attrs( + user, + &ssha_password, + &nt_password, + classes.iter().map(|s| s.as_str()).collect(), + false, + &config.ldap_username_attr, + &rdn_attr, + ) + .into_iter() + .map(|(k, v)| (k.to_string(), v.iter().map(|s| s.to_string()).collect())) + .collect() } #[cfg(test)] -impl Group { - pub(super) fn to_test_attrs( - &self, - config: &super::LDAPConfig, - members: Option<&Vec<&User>>, - ) -> Vec<(String, HashSet)> { - use crate::hashset; - - let mut attrs = vec![ - ( - config.ldap_groupname_attr.clone(), - hashset![self.name.clone()], - ), - ( - "objectClass".to_string(), - hashset![config.ldap_group_obj_class.clone()], - ), - ]; - - if let Some(members) = members { - for user in members { - let user_dn = config.user_dn_from_user(user); - attrs.push((config.ldap_group_member_attr.clone(), hashset![user_dn])); - } +pub(super) fn group_to_test_attrs( + group: &Group, + config: &super::LDAPConfig, + members: Option<&Vec<&User>>, +) -> Vec<(String, HashSet)> { + use crate::hashset; + + let mut attrs = vec![ + ( + config.ldap_groupname_attr.clone(), + hashset![group.name.clone()], + ), + ( + "objectClass".to_string(), + hashset![config.ldap_group_obj_class.clone()], + ), + ]; + + if let Some(members) = members { + for user in members { + let user_dn = config.user_dn_from_user(user); + attrs.push((config.ldap_group_member_attr.clone(), hashset![user_dn])); } - - attrs } + + attrs } #[cfg(test)] mod tests { - use crate::db::User; + use defguard_common::db::models::User; #[tokio::test] async fn test_search_users_by_username() { diff --git a/crates/defguard_core/src/enterprise/ldap/tests.rs b/crates/defguard_core/src/enterprise/ldap/tests.rs index 580573c416..d2980b9498 100644 --- a/crates/defguard_core/src/enterprise/ldap/tests.rs +++ b/crates/defguard_core/src/enterprise/ldap/tests.rs @@ -5,16 +5,13 @@ use ldap3::SearchEntry; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; -use crate::{ - db::{Group, User}, - enterprise::ldap::{ - model::extract_rdn_value, - sync::{ - Authority, compute_group_sync_changes, compute_user_sync_changes, - extract_intersecting_users, - }, - test_client::LdapEvent, +use crate::enterprise::ldap::{ + model::{extract_rdn_value, get_users_without_ldap_path, user_from_searchentry}, + sync::{ + Authority, compute_group_sync_changes, compute_user_sync_changes, + extract_intersecting_users, }, + test_client::{LdapEvent, group_to_test_attrs, user_to_test_attrs}, }; const PASSWORD: &str = "test_password"; @@ -307,7 +304,11 @@ async fn test_update_users_state(_: PgPoolOptions, options: PgConnectOptions) { &[ LdapEvent::ObjectAdded { dn: ldap_conn.config.user_dn_from_user(&active_user_not_in_ldap), - attrs: active_user_not_in_ldap.to_test_attrs(Some(PASSWORD), &ldap_conn.config), + attrs: user_to_test_attrs( + &active_user_not_in_ldap, + Some(PASSWORD), + &ldap_conn.config + ), }, LdapEvent::ObjectDeleted { dn: ldap_conn.config.user_dn_from_user(&inactive_user_in_ldap), @@ -334,7 +335,11 @@ async fn test_update_users_state(_: PgPoolOptions, options: PgConnectOptions) { assert!(ldap_conn.test_client.events_match( &[LdapEvent::ObjectAdded { dn: ldap_conn.config.group_dn(&group.name), - attrs: group.to_test_attrs(&ldap_conn.config, Some(&vec![&active_user_in_ldap])), + attrs: group_to_test_attrs( + &group, + &ldap_conn.config, + Some(&vec![&active_user_in_ldap]) + ), }], false )); @@ -484,7 +489,10 @@ async fn test_get_user() { }; // By username - let result = ldap_conn.get_user_by_username(&test_user).await.unwrap(); + let result = ldap_conn + .get_user_by_username(&test_user.username) + .await + .unwrap(); check(result); // By DN @@ -493,7 +501,9 @@ async fn test_get_user() { // Non-existent user let non_existent_user = make_test_user("nonexistent", None, None); - let result = ldap_conn.get_user_by_username(&non_existent_user).await; + let result = ldap_conn + .get_user_by_username(&non_existent_user.username) + .await; assert!(result.is_err()); } @@ -2422,7 +2432,7 @@ async fn test_get_empty_user_path(_: PgPoolOptions, options: PgConnectOptions) { let user = make_test_user("testuser", None, None); let user = user.save(&pool).await.unwrap(); - let mut users = User::::get_without_ldap_path(&pool).await.unwrap(); + let mut users = get_users_without_ldap_path(&pool).await.unwrap(); let user_found = users.pop().unwrap(); assert_eq!(user_found.username, user.username); } @@ -2466,7 +2476,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let user = User::from_searchentry(&entry, "user1", Some("password123")).unwrap(); + let user = user_from_searchentry(&entry, "user1", Some("password123")).unwrap(); assert_eq!(user.username, "user1"); assert_eq!(user.last_name, "lastname1"); @@ -2489,7 +2499,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let user = User::from_searchentry(&entry, "user1", None).unwrap(); + let user = user_from_searchentry(&entry, "user1", None).unwrap(); assert_eq!(user.username, "user1"); assert_eq!(user.last_name, "lastname1"); @@ -2511,7 +2521,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let result = User::from_searchentry(&entry, "user1", None); + let result = user_from_searchentry(&entry, "user1", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2531,7 +2541,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let result = User::from_searchentry(&entry, "user1", None); + let result = user_from_searchentry(&entry, "user1", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2551,7 +2561,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let result = User::from_searchentry(&entry, "user1", None); + let result = user_from_searchentry(&entry, "user1", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2572,7 +2582,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let result = User::from_searchentry(&entry, "user1", None); + let result = user_from_searchentry(&entry, "user1", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2593,7 +2603,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let result = User::from_searchentry(&entry, "user1", None); + let result = user_from_searchentry(&entry, "user1", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2611,7 +2621,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let result = User::from_searchentry(&entry, "user1", None); + let result = user_from_searchentry(&entry, "user1", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2633,7 +2643,7 @@ fn test_from_searchentry() { }; // Test with invalid username (contains special characters) - let result = User::from_searchentry(&entry, "user@#$%", None); + let result = user_from_searchentry(&entry, "user@#$%", None); assert!(result.is_err()); assert!(matches!( result.unwrap_err(), @@ -2655,7 +2665,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let user = User::from_searchentry(&entry, "user1", Some("password123")).unwrap(); + let user = user_from_searchentry(&entry, "user1", Some("password123")).unwrap(); assert_eq!(user.username, "user1"); assert_eq!(user.last_name, "lastname1"); @@ -2683,7 +2693,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let user = User::from_searchentry(&entry, "user1", Some("mypassword")).unwrap(); + let user = user_from_searchentry(&entry, "user1", Some("mypassword")).unwrap(); assert_eq!(user.username, "user1"); assert!(user.password_hash.is_some()); @@ -2719,7 +2729,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let user = User::from_searchentry(&entry, "user1", None).unwrap(); + let user = user_from_searchentry(&entry, "user1", None).unwrap(); // Should use the first value when multiple values are present assert_eq!(user.last_name, "lastname1"); @@ -2742,7 +2752,7 @@ fn test_from_searchentry() { bin_attrs: HashMap::new(), }; - let user = User::from_searchentry(&entry, "testuser", None).unwrap(); + let user = user_from_searchentry(&entry, "testuser", None).unwrap(); // Verify LDAP-specific fields are properly set assert!(user.from_ldap); @@ -2766,7 +2776,8 @@ fn test_as_ldap_attrs() { ); // Basic test with InetOrgPerson - let attrs = user.as_ldap_attrs( + let attrs = user_as_ldap_attrs( + &user, "{SSHA}hashedpw", "NT_HASH", hashset![UserObjectClass::InetOrgPerson.into()], @@ -2783,7 +2794,8 @@ fn test_as_ldap_attrs() { assert!(attrs.contains(&("objectClass", hashset!["inetOrgPerson"]))); // Test with ActiveDirectory - let attrs = user.as_ldap_attrs( + let attrs = user_as_ldap_attrs( + &user, "{SSHA}hashedpw", "NT_HASH", hashset![UserObjectClass::User.into()], @@ -2795,7 +2807,8 @@ fn test_as_ldap_attrs() { assert!(attrs.contains(&("sAMAccountName", hashset!["testuser"]))); // Test with SimpleSecurityObject and SambaSamAccount - let attrs = user.as_ldap_attrs( + let attrs = user_as_ldap_attrs( + &user, "{SSHA}hashedpw", "NT_HASH", hashset![ @@ -2812,7 +2825,8 @@ fn test_as_ldap_attrs() { assert!(attrs.contains(&("sambaNTPassword", hashset!["NT_HASH"]))); // Test with custom RDN attribute - let attrs = user.as_ldap_attrs( + let attrs = user_as_ldap_attrs( + &user, "{SSHA}hashedpw", "NT_HASH", hashset![UserObjectClass::User.into()], @@ -2834,7 +2848,8 @@ fn test_as_ldap_attrs() { Some("".to_string()), ); - let attrs = user_no_phone.as_ldap_attrs( + let attrs = user_as_ldap_attrs( + &user_no_phone, "{SSHA}hashedpw", "NT_HASH", hashset![UserObjectClass::InetOrgPerson.into()], @@ -2867,7 +2882,7 @@ fn test_as_ldap_mod_inetorgperson() { ..Default::default() }; - let mods = user.as_ldap_mod(&config); + let mods = user_as_ldap_mod(&user, &config); assert!(mods.contains(&Mod::Replace("sn", hashset!["Smith"]))); assert!(mods.contains(&Mod::Replace("givenName", hashset!["John"]))); assert!(mods.contains(&Mod::Replace("mail", hashset!["john.smith@example.com"]))); @@ -2891,7 +2906,7 @@ fn test_as_ldap_mod_with_empty_phone() { ..Default::default() }; - let mods = user.as_ldap_mod(&config); + let mods = user_as_ldap_mod(&user, &config); assert!(mods.contains(&Mod::Replace("sn", hashset!["Smith"]))); assert!(mods.contains(&Mod::Replace("givenName", hashset!["John"]))); @@ -2918,7 +2933,7 @@ fn test_as_ldap_mod_with_active_directory() { ..Default::default() }; - let mods = user.as_ldap_mod(&config); + let mods = user_as_ldap_mod(&user, &config); assert!(mods.contains(&Mod::Replace("sn", hashset!["Smith"]))); assert!(mods.contains(&Mod::Replace("givenName", hashset!["John"]))); @@ -2944,7 +2959,7 @@ fn test_as_ldap_mod_with_custom_rdn() { ..Default::default() }; - let mods = user.as_ldap_mod(&config); + let mods = user_as_ldap_mod(&user, &config); assert!(mods.contains(&Mod::Replace("sn", hashset!["Smith"]))); assert!(mods.contains(&Mod::Replace("givenName", hashset!["John"]))); @@ -3017,7 +3032,7 @@ async fn test_ldap_sync_allowed_with_empty_sync_groups( user.password_hash = Some("hash".to_string()); let user = user.save(&pool).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(result); } @@ -3031,7 +3046,7 @@ async fn test_ldap_sync_allowed_with_inactive_user(_: PgPoolOptions, options: Pg user.password_hash = Some("hash".to_string()); let user = user.save(&pool).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(!result); } @@ -3047,7 +3062,7 @@ async fn test_ldap_sync_allowed_with_unenrolled_user(_: PgPoolOptions, options: user.from_ldap = false; let user = user.save(&pool).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(!result); } @@ -3071,7 +3086,7 @@ async fn test_ldap_sync_allowed_with_sync_groups_user_in_group( settings.ldap_sync_groups = vec!["ldap_sync_group".to_string()]; update_current_settings(&pool, settings).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(result); } @@ -3096,7 +3111,7 @@ async fn test_ldap_sync_allowed_with_sync_groups_user_not_in_group( settings.ldap_sync_groups = vec!["ldap_sync_group".to_string()]; update_current_settings(&pool, settings).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(!result); } @@ -3127,7 +3142,7 @@ async fn test_ldap_sync_allowed_with_multiple_sync_groups( ]; update_current_settings(&pool, settings).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(result); } @@ -3143,7 +3158,7 @@ async fn test_ldap_sync_allowed_enrolled_via_openid(_: PgPoolOptions, options: P user.from_ldap = false; let user = user.save(&pool).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(result); } @@ -3159,7 +3174,7 @@ async fn test_ldap_sync_allowed_enrolled_via_ldap(_: PgPoolOptions, options: PgC user.from_ldap = true; let user = user.save(&pool).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(result); } @@ -3181,6 +3196,6 @@ async fn test_ldap_sync_allowed_all_conditions_false(_: PgPoolOptions, options: settings.ldap_sync_groups = vec!["ldap_sync_group".to_string()]; update_current_settings(&pool, settings).await.unwrap(); - let result = user.ldap_sync_allowed(&pool).await.unwrap(); + let result = ldap_sync_allowed_for_user(&user, &pool).await.unwrap(); assert!(!result); } diff --git a/crates/defguard_core/src/enterprise/ldap/utils.rs b/crates/defguard_core/src/enterprise/ldap/utils.rs index 4ddc7bc8e6..066e0dbbae 100644 --- a/crates/defguard_core/src/enterprise/ldap/utils.rs +++ b/crates/defguard_core/src/enterprise/ldap/utils.rs @@ -3,14 +3,14 @@ use std::collections::{HashMap, HashSet}; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{User, group::Group}, +}; use sqlx::PgPool; use super::{LDAPConnection, error::LdapError}; -use crate::{ - db::{Group, User}, - enterprise::ldap::with_ldap_status, -}; +use crate::enterprise::ldap::{model::ldap_sync_allowed_for_user, with_ldap_status}; /// Retrieves a user from LDAP if they are in the configured LDAP sync groups. /// @@ -80,7 +80,7 @@ pub(crate) async fn ldap_update_users_state(users: Vec<&mut User>, pool: &Pg pub(crate) async fn ldap_add_user(user: &mut User, password: Option<&str>, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Creating user {user} in LDAP"); - if !user.ldap_sync_allowed(pool).await? { + if !ldap_sync_allowed_for_user(user, pool).await? { debug!( "User {user} is not allowed to be synced to LDAP as he is not in the specified \ sync groups, skipping" @@ -88,6 +88,7 @@ pub(crate) async fn ldap_add_user(user: &mut User, password: Option<&str>, p return Ok(()); } let mut ldap_connection = LDAPConnection::create().await?; + // convert to ldap module wrapper if ldap_connection.user_exists(user).await? { debug!("User {user} already exists in LDAP, skipping creation"); return Ok(()); @@ -216,7 +217,7 @@ pub(crate) async fn ldap_add_users_to_groups( let adding_to_sync_groups = groups .iter() .any(|group| sync_groups_lookup.contains(*group)); - if !user.ldap_sync_allowed(pool).await? && !adding_to_sync_groups { + if !ldap_sync_allowed_for_user(user, pool).await? && !adding_to_sync_groups { debug!( "User {user} is not allowed to be synced to LDAP as he is not in the \ specified sync groups, skipping" @@ -251,7 +252,7 @@ pub(crate) async fn ldap_remove_users_from_groups( let removing_from_sync_groups = groups .iter() .any(|group| sync_groups_lookup.contains(*group)); - if !user.ldap_sync_allowed(pool).await? && !removing_from_sync_groups { + if !ldap_sync_allowed_for_user(user, pool).await? && !removing_from_sync_groups { debug!( "User {user} is not allowed to be synced to LDAP as he is not in the specified sync groups, skipping" @@ -275,7 +276,7 @@ pub(crate) async fn ldap_remove_users_from_groups( pub(crate) async fn ldap_change_password(user: &mut User, password: &str, pool: &PgPool) { let _: Result<(), LdapError> = with_ldap_status(pool, async { debug!("Changing password for user {user} in LDAP"); - if !user.ldap_sync_allowed(pool).await? { + if !ldap_sync_allowed_for_user(user, pool).await? { debug!( "User {user} is not allowed to be synced to LDAP as he is not in the specified sync groups, skipping" diff --git a/crates/defguard_core/src/enterprise/snat/handlers.rs b/crates/defguard_core/src/enterprise/snat/handlers.rs index 3db0d81760..40558cb4cf 100644 --- a/crates/defguard_core/src/enterprise/snat/handlers.rs +++ b/crates/defguard_core/src/enterprise/snat/handlers.rs @@ -4,7 +4,10 @@ use axum::{ Json, extract::{Path, State}, }; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{User, WireguardNetwork}, +}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -13,12 +16,13 @@ use utoipa::ToSchema; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::{GatewayEvent, User, WireguardNetwork}, enterprise::{ - db::models::snat::UserSnatBinding, handlers::LicenseInfo, snat::error::UserSnatBindingError, + db::models::snat::UserSnatBinding, firewall::try_get_location_firewall_config, + handlers::LicenseInfo, snat::error::UserSnatBindingError, }, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, + grpc::gateway::events::GatewayEvent, handlers::{ApiResponse, ApiResult}, }; @@ -153,7 +157,9 @@ pub async fn create_snat_binding( // trigger firewall config update on relevant gateways let mut conn = appstate.pool.acquire().await?; if let Some(location) = WireguardNetwork::find_by_id(&appstate.pool, location.id).await? { - if let Some(firewall_config) = location.try_get_firewall_config(&mut conn).await? { + if let Some(firewall_config) = + try_get_location_firewall_config(&location, &mut conn).await? + { debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); @@ -255,7 +261,9 @@ pub async fn modify_snat_binding( // trigger firewall config update on relevant gateways let mut conn = appstate.pool.acquire().await?; if let Some(location) = WireguardNetwork::find_by_id(&appstate.pool, location_id).await? { - if let Some(firewall_config) = location.try_get_firewall_config(&mut conn).await? { + if let Some(firewall_config) = + try_get_location_firewall_config(&location, &mut conn).await? + { debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); @@ -341,7 +349,9 @@ pub async fn delete_snat_binding( // trigger firewall config update on relevant gateways let mut conn = appstate.pool.acquire().await?; if let Some(location) = WireguardNetwork::find_by_id(&appstate.pool, location_id).await? { - if let Some(firewall_config) = location.try_get_firewall_config(&mut conn).await? { + if let Some(firewall_config) = + try_get_location_firewall_config(&location, &mut conn).await? + { debug!( "Sending firewall config update for location {location} affected by adding new SNAT binding" ); diff --git a/crates/defguard_core/src/error.rs b/crates/defguard_core/src/error.rs index e58a2f8876..84aadb2ab3 100644 --- a/crates/defguard_core/src/error.rs +++ b/crates/defguard_core/src/error.rs @@ -1,5 +1,8 @@ use axum::http::StatusCode; -use defguard_common::db::models::{ModelError, settings::SettingsValidationError}; +use defguard_common::db::models::{ + DeviceError, ModelError, WireguardNetworkError, settings::SettingsValidationError, + user::UserError, +}; use defguard_mail::templates::TemplateError; use sqlx::error::Error as SqlxError; use thiserror::Error; @@ -8,13 +11,14 @@ use utoipa::ToSchema; use crate::{ auth::failed_login::FailedLoginError, - db::models::{device::DeviceError, enrollment::TokenError, wireguard::WireguardNetworkError}, + db::models::enrollment::TokenError, enterprise::{ activity_log_stream::error::ActivityLogStreamError, db::models::acl::AclError, firewall::FirewallError, ldap::error::LdapError, license::LicenseError, }, events::ApiEvent, grpc::gateway::map::GatewayMapError, + location_management::LocationManagementError, }; /// Represents kinds of error that occurred @@ -149,7 +153,6 @@ impl From for WebError { | WireguardNetworkError::Unexpected(_) | WireguardNetworkError::DeviceError(_) | WireguardNetworkError::DeviceNotAllowed(_) - | WireguardNetworkError::FirewallError(_) | WireguardNetworkError::TokenError(_) => Self::Http(StatusCode::INTERNAL_SERVER_ERROR), } } @@ -188,3 +191,29 @@ impl From for WebError { } } } + +impl From for WebError { + fn from(err: UserError) -> Self { + error!("{}", err); + match err { + UserError::InvalidMfaState { username: _ } | UserError::DbError(_) => { + WebError::Http(StatusCode::INTERNAL_SERVER_ERROR) + } + UserError::EmailMfaError(msg) => WebError::EmailMfa(msg), + } + } +} + +impl From for WebError { + fn from(err: LocationManagementError) -> Self { + error!("{}", err); + match err { + LocationManagementError::FirewallError(firewall_error) => firewall_error.into(), + LocationManagementError::DbError(error) => error.into(), + LocationManagementError::WireguardNetworkError(wireguard_network_error) => { + wireguard_network_error.into() + } + LocationManagementError::ModelError(model_error) => model_error.into(), + } + } +} diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index dc316c0992..9a288fa280 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -3,15 +3,15 @@ use std::net::IpAddr; use chrono::{NaiveDateTime, Utc}; use defguard_common::db::{ Id, - models::{AuthenticationKey, MFAMethod, Settings}, + models::{ + AuthenticationKey, Device, MFAMethod, Settings, User, WebAuthn, WireguardNetwork, + group::Group, oauth2client::OAuth2Client, + }, }; use defguard_proto::proxy::MfaMethod; use crate::{ - db::{ - Device, Group, User, WebAuthn, WebHook, WireguardNetwork, - models::oauth2client::OAuth2Client, - }, + db::WebHook, enterprise::db::models::{ activity_log_stream::ActivityLogStream, api_tokens::ApiToken, openid_provider::OpenIdProvider, snat::UserSnatBinding, diff --git a/crates/defguard_core/src/grpc/auth.rs b/crates/defguard_core/src/grpc/auth.rs index c8099389f6..b64b7f0f57 100644 --- a/crates/defguard_core/src/grpc/auth.rs +++ b/crates/defguard_core/src/grpc/auth.rs @@ -1,6 +1,9 @@ use std::sync::{Arc, Mutex}; -use defguard_common::auth::claims::{Claims, ClaimsType}; +use defguard_common::{ + auth::claims::{Claims, ClaimsType}, + db::models::User, +}; use defguard_proto::auth::{AuthenticateRequest, AuthenticateResponse, auth_service_server}; use jsonwebtoken::errors::Error as JWTError; use sqlx::PgPool; @@ -8,7 +11,6 @@ use tonic::{Request, Response, Status}; use crate::{ auth::failed_login::{FailedLoginMap, check_failed_logins, log_failed_login_attempt}, - db::User, server_config, }; diff --git a/crates/defguard_core/src/grpc/client_mfa.rs b/crates/defguard_core/src/grpc/client_mfa.rs index f688a41a48..11c7b4d1a9 100644 --- a/crates/defguard_core/src/grpc/client_mfa.rs +++ b/crates/defguard_core/src/grpc/client_mfa.rs @@ -5,8 +5,13 @@ use defguard_common::{ auth::claims::{Claims, ClaimsType}, db::{ Id, - models::{BiometricAuth, BiometricChallenge}, + models::{ + BiometricAuth, BiometricChallenge, Device, DeviceNetworkInfo, User, WireguardNetwork, + device::{DeviceInfo, WireguardNetworkDevice}, + wireguard::LocationMfaMode, + }, }, + types::user_info::UserInfo, }; use defguard_mail::Mail; use defguard_proto::proxy::{ @@ -23,16 +28,9 @@ use tokio::sync::{ use tonic::{Code, Status}; use crate::{ - db::{ - Device, GatewayEvent, User, UserInfo, WireguardNetwork, - models::{ - device::{DeviceInfo, DeviceNetworkInfo, WireguardNetworkDevice}, - wireguard::LocationMfaMode, - }, - }, enterprise::{db::models::openid_provider::OpenIdProvider, is_enterprise_enabled}, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, - grpc::utils::parse_client_ip_agent, + grpc::{gateway::events::GatewayEvent, utils::parse_client_ip_agent}, handlers::mail::send_email_mfa_code_email, }; diff --git a/crates/defguard_core/src/grpc/enrollment.rs b/crates/defguard_core/src/grpc/enrollment.rs index c7a43069ba..41c0e08d72 100644 --- a/crates/defguard_core/src/grpc/enrollment.rs +++ b/crates/defguard_core/src/grpc/enrollment.rs @@ -4,7 +4,11 @@ use defguard_common::{ csv::AsCsv, db::{ Id, - models::{BiometricAuth, MFAMethod, Settings, settings::defaults::WELCOME_EMAIL_SUBJECT}, + models::{ + BiometricAuth, Device, DeviceConfig, DeviceType, MFAMethod, Settings, User, + WireguardNetwork, device::DeviceInfo, polling_token::PollingToken, + settings::defaults::WELCOME_EMAIL_SUBJECT, wireguard::ServiceLocationMode, + }, }, }; use defguard_mail::{ @@ -13,11 +17,9 @@ use defguard_mail::{ }; use defguard_proto::proxy::{ ActivateUserRequest, AdminInfo, CodeMfaSetupFinishRequest, CodeMfaSetupFinishResponse, - CodeMfaSetupStartRequest, CodeMfaSetupStartResponse, Device as ProtoDevice, - DeviceConfig as ProtoDeviceConfig, DeviceConfigResponse, EnrollmentStartRequest, - EnrollmentStartResponse, ExistingDevice, InitialUserInfo, - LocationMfaMode as ProtoLocationMfaMode, MfaMethod, NewDevice, RegisterMobileAuthRequest, - ServiceLocationMode as ProtoServiceLocationMode, + CodeMfaSetupStartRequest, CodeMfaSetupStartResponse, DeviceConfigResponse, + EnrollmentStartRequest, EnrollmentStartResponse, ExistingDevice, InitialUserInfo, MfaMethod, + NewDevice, RegisterMobileAuthRequest, }; use sqlx::{PgPool, Transaction, query_scalar}; use tokio::sync::{ @@ -28,23 +30,17 @@ use tonic::Status; use super::InstanceInfo; use crate::{ - db::{ - Device, GatewayEvent, User, WireguardNetwork, - models::{ - device::{DeviceConfig, DeviceInfo, DeviceType}, - enrollment::{ENROLLMENT_TOKEN_TYPE, Token, TokenError}, - polling_token::PollingToken, - wireguard::{LocationMfaMode, ServiceLocationMode}, - }, - }, + db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token, TokenError}, enterprise::{ db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, + firewall::try_get_location_firewall_config, ldap::utils::ldap_add_user, limits::update_counts, }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, grpc::{ client_version::ClientFeature, + gateway::events::GatewayEvent, utils::{build_device_config_response, new_polling_token, parse_client_ip_agent}, }, handlers::{ @@ -741,13 +737,13 @@ impl EnrollmentServer { Status::internal("unexpected error") })? { - if let Some(firewall_config) = location - .try_get_firewall_config(&mut transaction) - .await - .map_err(|err| { - error!("Failed to get firewall config for location {location}: {err}",); - Status::internal("unexpected error") - })? + if let Some(firewall_config) = + try_get_location_firewall_config(&location, &mut transaction) + .await + .map_err(|err| { + error!("Failed to get firewall config for location {location}: {err}",); + Status::internal("unexpected error") + })? { debug!( "Sending firewall config update for location {location} affected by adding new device {}, user {}({})", @@ -1039,16 +1035,6 @@ impl EnrollmentServer { } } -impl From> for AdminInfo { - fn from(admin: User) -> Self { - Self { - name: format!("{} {}", admin.first_name, admin.last_name), - phone_number: admin.phone, - email: admin.email, - } - } -} - async fn initial_info_from_user( pool: &PgPool, user: User, @@ -1069,49 +1055,6 @@ async fn initial_info_from_user( is_admin, }) } - -impl From for ProtoDeviceConfig { - fn from(config: DeviceConfig) -> Self { - // DEPRECATED(1.5): superseeded by location_mfa_mode - let mfa_enabled = config.location_mfa_mode == LocationMfaMode::Internal; - Self { - network_id: config.network_id, - network_name: config.network_name, - config: config.config, - endpoint: config.endpoint, - assigned_ip: config.address.as_csv(), - pubkey: config.pubkey, - allowed_ips: config.allowed_ips.as_csv(), - dns: config.dns, - keepalive_interval: config.keepalive_interval, - #[allow(deprecated)] - mfa_enabled, - location_mfa_mode: Some( - >::into(config.location_mfa_mode) - .into(), - ), - service_location_mode: Some( - >::into( - config.service_location_mode, - ) - .into(), - ), - } - } -} - -impl From> for ProtoDevice { - fn from(device: Device) -> Self { - Self { - id: device.id, - name: device.name, - pubkey: device.wireguard_pubkey, - user_id: device.user_id, - created_at: device.created.and_utc().timestamp(), - } - } -} - impl Token { // Send configured welcome email to user after finishing enrollment async fn send_welcome_email( @@ -1194,7 +1137,7 @@ mod test { config::{DefGuardConfig, SERVER_CONFIG}, db::{ models::{ - Settings, + Settings, User, settings::{defaults::WELCOME_EMAIL_SUBJECT, initialize_current_settings}, }, setup_pool, @@ -1204,10 +1147,7 @@ mod test { use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::sync::mpsc::unbounded_channel; - use crate::db::{ - User, - models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, - }; + use crate::db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}; #[sqlx::test] async fn dg25_11_test_enrollment_welcome_email(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/src/grpc/gateway/client_state.rs b/crates/defguard_core/src/grpc/gateway/client_state.rs index 1bc49a404c..3402a17286 100644 --- a/crates/defguard_core/src/grpc/gateway/client_state.rs +++ b/crates/defguard_core/src/grpc/gateway/client_state.rs @@ -1,14 +1,14 @@ use std::{collections::HashMap, net::SocketAddr}; use chrono::{NaiveDateTime, TimeDelta, Utc}; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{Device, User, WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, +}; use thiserror::Error; use tonic::{Code, Status}; -use crate::{ - db::{Device, User, WireguardNetwork, models::wireguard_peer_stats::WireguardPeerStats}, - events::GrpcRequestContext, -}; +use crate::events::GrpcRequestContext; #[derive(Debug, Error)] pub enum ClientMapError { diff --git a/crates/defguard_core/src/grpc/gateway/events.rs b/crates/defguard_core/src/grpc/gateway/events.rs new file mode 100644 index 0000000000..9f4513bb4b --- /dev/null +++ b/crates/defguard_core/src/grpc/gateway/events.rs @@ -0,0 +1,17 @@ +use defguard_common::db::{ + Id, + models::{WireguardNetwork, device::DeviceInfo}, +}; +use defguard_proto::{enterprise::firewall::FirewallConfig, gateway::Peer}; + +#[derive(Clone, Debug)] +pub enum GatewayEvent { + NetworkCreated(Id, WireguardNetwork), + NetworkModified(Id, WireguardNetwork, Vec, Option), + NetworkDeleted(Id, String), + DeviceCreated(DeviceInfo), + DeviceModified(DeviceInfo), + DeviceDeleted(DeviceInfo), + FirewallConfigChanged(Id, FirewallConfig), + FirewallDisabled(Id), +} diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ff119fc0fc..73a2afd730 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -7,7 +7,13 @@ use std::{ use chrono::{DateTime, TimeDelta, Utc}; use client_state::ClientMap; -use defguard_common::db::{Id, NoId}; +use defguard_common::db::{ + Id, NoId, + models::{ + Device, User, WireguardNetwork, wireguard::ServiceLocationMode, + wireguard_peer_stats::WireguardPeerStats, + }, +}; use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, @@ -18,7 +24,7 @@ use defguard_proto::{ }; use defguard_version::version_info_from_metadata; use semver::Version; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query}; +use sqlx::PgPool; use thiserror::Error; use tokio::{ sync::{ @@ -33,14 +39,14 @@ use tonic::{Code, Request, Response, Status, metadata::MetadataMap}; use self::map::GatewayMap; use crate::{ - db::{ - Device, GatewayEvent, User, - models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, - }, + enterprise::{firewall::try_get_location_firewall_config, is_enterprise_enabled}, events::{GrpcEvent, GrpcRequestContext}, + grpc::gateway::events::GatewayEvent, + location_management::allowed_peers::get_location_allowed_peers, }; pub mod client_state; +pub mod events; pub mod map; pub(crate) mod state; @@ -66,6 +72,32 @@ pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender< } } +/// Helper used to convert peer stats coming from gRPC client +/// into an internal representation +fn protos_into_internal_stats( + proto_stats: PeerStats, + location_id: Id, + device_id: Id, +) -> WireguardPeerStats { + let endpoint = match proto_stats.endpoint { + endpoint if endpoint.is_empty() => None, + _ => Some(proto_stats.endpoint), + }; + WireguardPeerStats { + id: NoId, + network: location_id, + endpoint, + device_id, + collected_at: Utc::now().naive_utc(), + upload: proto_stats.upload as i64, + download: proto_stats.download as i64, + latest_handshake: DateTime::from_timestamp(proto_stats.latest_handshake as i64, 0) + .unwrap_or_default() + .naive_utc(), + allowed_ips: Some(proto_stats.allowed_ips), + } +} + #[allow(clippy::large_enum_variant)] #[derive(Debug, Error)] pub enum GatewayServerError { @@ -90,68 +122,11 @@ pub struct GatewayServer { grpc_event_tx: UnboundedSender, } -impl WireguardNetwork { - /// Get a list of all allowed peers - /// - /// Each device is marked as allowed or not allowed in a given network, - /// which enables enforcing peer disconnect in MFA-protected networks. - /// - /// If the location is a service location, only returns peers if enterprise features are enabled. - pub async fn get_peers<'e, E>(&self, executor: E) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - debug!("Fetching all peers for network {}", self.id); - - if self.should_prevent_service_location_usage() { - warn!( - "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", - self.name - ); - return Ok(Vec::new()); - } - - let rows = query!( - "SELECT d.wireguard_pubkey pubkey, preshared_key, \ - -- TODO possible to not use ARRAY-unnest here? - ARRAY( - SELECT host(ip) - FROM unnest(wnd.wireguard_ips) AS ip - ) \"allowed_ips!: Vec\" \ - FROM wireguard_network_device wnd \ - JOIN device d ON wnd.device_id = d.id \ - JOIN \"user\" u ON d.user_id = u.id \ - WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ - AND d.configured = true \ - AND u.is_active = true \ - ORDER BY d.id ASC", - self.id, - self.mfa_enabled() - ) - .fetch_all(executor) - .await?; - - // keepalive has to be added manually because Postgres - // doesn't support unsigned integers - let result = rows - .into_iter() - .map(|row| Peer { - pubkey: row.pubkey, - allowed_ips: row.allowed_ips, - // Don't send preshared key if MFA is not enabled, it can't be used and may - // cause issues with clients connecting if they expect no preshared key - // e.g. when you disable MFA on a location - preshared_key: if self.mfa_enabled() { - row.preshared_key - } else { - None - }, - keepalive_interval: Some(self.keepalive_interval as u32), - }) - .collect(); - - Ok(result) - } +/// If this location is marked as a service location, checks if all requirements are met for it to function: +/// - Enterprise is enabled +#[must_use] +pub fn should_prevent_service_location_usage(location: &WireguardNetwork) -> bool { + location.service_location_mode != ServiceLocationMode::Disabled && !is_enterprise_enabled() } /// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. @@ -332,28 +307,6 @@ fn gen_config( } } -impl WireguardPeerStats { - fn from_peer_stats(stats: PeerStats, network_id: Id, device_id: Id) -> Self { - let endpoint = match stats.endpoint { - endpoint if endpoint.is_empty() => None, - _ => Some(stats.endpoint), - }; - Self { - id: NoId, - network: network_id, - endpoint, - device_id, - collected_at: Utc::now().naive_utc(), - upload: stats.upload as i64, - download: stats.download as i64, - latest_handshake: DateTime::from_timestamp(stats.latest_handshake as i64, 0) - .unwrap_or_default() - .naive_utc(), - allowed_ips: Some(stats.allowed_ips), - } - } -} - /// Helper struct for handling gateway events struct GatewayUpdatesHandler { network_id: Id, @@ -837,7 +790,7 @@ impl gateway_service_server::GatewayService for GatewayServer { let location = self.fetch_location_from_db(network_id).await?; // convert stats to DB storage format - let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); + let stats = protos_into_internal_stats(peer_stats, network_id, device_id); // only perform client state update if stats include an endpoint IP // otherwise a peer was added to the gateway interface @@ -993,24 +946,29 @@ impl gateway_service_server::GatewayService for GatewayServer { error!("Failed to save updated network {network_id} in the database, status: {err}"); } - let peers = network.get_peers(&mut *conn).await.map_err(|error| { - error!("Failed to fetch peers from the database for network {network_id}: {error}",); - Status::new( - Code::Internal, - format!("Failed to retrieve peers from the database for network: {network_id}"), - ) - })?; - let maybe_firewall_config = - network - .try_get_firewall_config(&mut conn) + let peers = + get_location_allowed_peers(&network, &mut *conn) .await - .map_err(|err| { - error!("Failed to generate firewall config for network {network_id}: {err}"); + .map_err(|error| { + error!( + "Failed to fetch peers from the database for network {network_id}: {error}", + ); Status::new( Code::Internal, - format!("Failed to generate firewall config for network: {network_id}"), + format!( + "Failed to retrieve peers from the database for network: {network_id}" + ), ) })?; + let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn) + .await + .map_err(|err| { + error!("Failed to generate firewall config for network {network_id}: {err}"); + Status::new( + Code::Internal, + format!("Failed to generate firewall config for network: {network_id}"), + ) + })?; info!("Configuration sent to gateway client, network {network}."); diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 96dbd24e69..90ef3c681c 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -46,9 +46,10 @@ pub use crate::version::MIN_GATEWAY_VERSION; use crate::{ auth::failed_login::FailedLoginMap, db::{ - AppEvent, GatewayEvent, + AppEvent, models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, }, + enrollment_management::clear_unused_enrollment_tokens, enterprise::{ db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, directory_sync::sync_user_groups_if_configured, @@ -60,7 +61,7 @@ use crate::{ ldap::utils::ldap_update_user_state, }, events::{BidiStreamEvent, GrpcEvent}, - grpc::gateway::{client_state::ClientMap, map::GatewayMap}, + grpc::gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, server_config, version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, }; @@ -455,7 +456,7 @@ async fn handle_proxy_message_loop( .await { Ok(mut user) => { - user.clear_unused_enrollment_tokens(&pool).await?; + clear_unused_enrollment_tokens(&user, &pool).await?; if let Err(err) = sync_user_groups_if_configured( &user, &pool, diff --git a/crates/defguard_core/src/grpc/password_reset.rs b/crates/defguard_core/src/grpc/password_reset.rs index 4e8b35e6d4..f3d44b6d99 100644 --- a/crates/defguard_core/src/grpc/password_reset.rs +++ b/crates/defguard_core/src/grpc/password_reset.rs @@ -1,3 +1,4 @@ +use defguard_common::db::models::User; use defguard_mail::Mail; use defguard_proto::proxy::{ DeviceInfo, PasswordResetInitializeRequest, PasswordResetRequest, PasswordResetStartRequest, @@ -8,10 +9,7 @@ use tokio::sync::mpsc::{UnboundedSender, error::SendError}; use tonic::Status; use crate::{ - db::{ - User, - models::enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, - }, + db::models::enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, enterprise::ldap::utils::ldap_change_password, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, PasswordResetEvent}, grpc::utils::parse_client_ip_agent, diff --git a/crates/defguard_core/src/grpc/utils.rs b/crates/defguard_core/src/grpc/utils.rs index b82e91fa1c..6955e02e22 100644 --- a/crates/defguard_core/src/grpc/utils.rs +++ b/crates/defguard_core/src/grpc/utils.rs @@ -2,7 +2,15 @@ use std::{net::IpAddr, str::FromStr}; use defguard_common::{ csv::AsCsv, - db::{Id, models::Settings}, + db::{ + Id, + models::{ + Device, DeviceType, Settings, User, WireguardNetwork, + device::WireguardNetworkDevice, + polling_token::PollingToken, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, + }, }; use defguard_proto::proxy::{ DeviceConfig as ProtoDeviceConfig, DeviceConfigResponse, DeviceInfo, @@ -13,18 +21,10 @@ use tonic::Status; use super::InstanceInfo; use crate::{ - db::{ - Device, User, - models::{ - device::{DeviceType, WireguardNetworkDevice}, - polling_token::PollingToken, - wireguard::{LocationMfaMode, ServiceLocationMode, WireguardNetwork}, - }, - }, enterprise::db::models::{ enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider, }, - grpc::client_version::ClientFeature, + grpc::{client_version::ClientFeature, gateway::should_prevent_service_location_usage}, }; // Create a new token for configuration polling. @@ -179,7 +179,7 @@ pub(crate) async fn build_device_config_response( ); Status::internal(format!("unexpected error: {err}")) })?; - if network.should_prevent_service_location_usage() { + if should_prevent_service_location_usage(&network) { warn!( "Tried to use service location {} with disabled enterprise features.", network.name diff --git a/crates/defguard_core/src/grpc/worker.rs b/crates/defguard_core/src/grpc/worker.rs index 069492a687..2bc001adca 100644 --- a/crates/defguard_core/src/grpc/worker.rs +++ b/crates/defguard_core/src/grpc/worker.rs @@ -5,7 +5,7 @@ use std::{ time::Instant, }; -use defguard_common::db::models::{AuthenticationKey, AuthenticationKeyType}; +use defguard_common::db::models::{AuthenticationKey, AuthenticationKeyType, User, YubiKey}; pub use defguard_proto::worker::JobStatus; use defguard_proto::worker::{GetJobResponse, Worker, worker_service_server}; use sqlx::{PgPool, query}; @@ -13,7 +13,7 @@ use tokio::sync::mpsc::UnboundedSender; use tonic::{Request, Response, Status}; use super::{Job, JobResponse, WorkerDetail, WorkerInfo, WorkerState}; -use crate::db::{AppEvent, HWKeyUserData, User, YubiKey}; +use crate::db::{AppEvent, HWKeyUserData}; impl WorkerInfo { /// Create new `Worker` instance. diff --git a/crates/defguard_core/src/handlers/app_info.rs b/crates/defguard_core/src/handlers/app_info.rs index 344ee41925..daf574b1af 100644 --- a/crates/defguard_core/src/handlers/app_info.rs +++ b/crates/defguard_core/src/handlers/app_info.rs @@ -1,12 +1,14 @@ use axum::{extract::State, http::StatusCode}; -use defguard_common::{VERSION, db::models::Settings}; +use defguard_common::{ + VERSION, + db::models::{Settings, WireguardNetwork}, +}; use serde_json::json; use super::{ApiResponse, ApiResult}; use crate::{ appstate::AppState, auth::SessionInfo, - db::WireguardNetwork, enterprise::{ db::models::openid_provider::OpenIdProvider, is_enterprise_enabled, is_enterprise_free, diff --git a/crates/defguard_core/src/handlers/auth.rs b/crates/defguard_core/src/handlers/auth.rs index 357d68c814..8b3f213f69 100644 --- a/crates/defguard_core/src/handlers/auth.rs +++ b/crates/defguard_core/src/handlers/auth.rs @@ -13,9 +13,12 @@ use axum_extra::{ }, headers::UserAgent, }; -use defguard_common::db::{ - Id, - models::{MFAMethod, Settings}, +use defguard_common::{ + db::{ + Id, + models::{MFAInfo, MFAMethod, Session, SessionState, Settings, User, WebAuthn}, + }, + types::user_info::UserInfo, }; use defguard_mail::Mail; use serde_json::json; @@ -33,10 +36,9 @@ use super::{ use crate::{ appstate::AppState, auth::{ - SessionInfo, + SessionExtractor, SessionInfo, failed_login::{check_failed_logins, log_failed_login_attempt}, }, - db::{MFAInfo, Session, SessionState, User, UserInfo, WebAuthn}, enterprise::ldap::utils::login_through_ldap, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, @@ -298,7 +300,7 @@ pub(crate) async fn authenticate( /// Logout - forget the session cookie. pub async fn logout( cookies: CookieJar, - session: Session, + SessionExtractor(session): SessionExtractor, user_agent: TypedHeader, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, @@ -330,7 +332,7 @@ pub async fn logout( /// Enable MFA pub async fn mfa_enable( cookies: CookieJar, - _session: Session, + SessionExtractor(_session): SessionExtractor, session_info: SessionInfo, State(appstate): State, ) -> Result<(CookieJar, ApiResponse), WebError> { @@ -497,7 +499,10 @@ pub async fn webauthn_finish( } /// Start WebAuthn authentication -pub async fn webauthn_start(mut session: Session, State(appstate): State) -> ApiResult { +pub async fn webauthn_start( + SessionExtractor(mut session): SessionExtractor, + State(appstate): State, +) -> ApiResult { let passkeys = WebAuthn::passkeys_for_user(&appstate.pool, session.user_id).await?; match appstate.webauthn.start_passkey_authentication(&passkeys) { @@ -517,7 +522,7 @@ pub async fn webauthn_start(mut session: Session, State(appstate): State, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, @@ -688,7 +693,7 @@ pub async fn totp_disable( /// Validate one-time passcode pub async fn totp_code( private_cookies: PrivateCookieJar, - mut session: Session, + SessionExtractor(mut session): SessionExtractor, user_agent: TypedHeader, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, @@ -855,7 +860,7 @@ pub async fn email_mfa_disable( /// Send email code to user pub async fn request_email_mfa_code( - session: Session, + SessionExtractor(session): SessionExtractor, State(appstate): State, ) -> ApiResult { if let Some(user) = User::find_by_id(&appstate.pool, session.user_id).await? { @@ -875,7 +880,7 @@ pub async fn request_email_mfa_code( /// Validate email MFA code pub async fn email_mfa_code( private_cookies: PrivateCookieJar, - mut session: Session, + SessionExtractor(mut session): SessionExtractor, user_agent: TypedHeader, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, @@ -968,7 +973,7 @@ pub async fn email_mfa_code( /// Authenticate with a recovery code. pub async fn recovery_code( private_cookies: PrivateCookieJar, - mut session: Session, + SessionExtractor(mut session): SessionExtractor, user_agent: TypedHeader, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, diff --git a/crates/defguard_core/src/handlers/forward_auth.rs b/crates/defguard_core/src/handlers/forward_auth.rs index 1421932008..9799b79f80 100644 --- a/crates/defguard_core/src/handlers/forward_auth.rs +++ b/crates/defguard_core/src/handlers/forward_auth.rs @@ -4,10 +4,11 @@ use axum::{ response::{IntoResponse, Redirect, Response}, }; use axum_extra::extract::cookie::CookieJar; +use defguard_common::db::models::Session; use reqwest::Url; use super::SESSION_COOKIE_NAME; -use crate::{appstate::AppState, db::Session, error::WebError, server_config}; +use crate::{appstate::AppState, error::WebError, server_config}; // Header names static FORWARDED_HOST: &str = "x-forwarded-host"; diff --git a/crates/defguard_core/src/handlers/group.rs b/crates/defguard_core/src/handlers/group.rs index 41bebed9d9..f35f6eaaa8 100644 --- a/crates/defguard_core/src/handlers/group.rs +++ b/crates/defguard_core/src/handlers/group.rs @@ -4,7 +4,13 @@ use axum::{ extract::{Json, Path, State}, http::StatusCode, }; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{ + User, + group::{Group, Permission}, + }, +}; use serde_json::json; use sqlx::query_as; use utoipa::ToSchema; @@ -13,7 +19,6 @@ use super::{ApiResponse, ApiResult, EditGroupInfo, GroupInfo, Username}; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::{Group, User, WireguardNetwork, models::group::Permission}, enterprise::ldap::utils::{ ldap_add_user_to_groups, ldap_add_users_to_groups, ldap_delete_group, ldap_modify_group, ldap_remove_user_from_groups, ldap_remove_users_from_groups, ldap_update_user_state, @@ -22,6 +27,7 @@ use crate::{ error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, hashset, + location_management::sync_all_networks, }; #[derive(Serialize, ToSchema)] @@ -116,7 +122,7 @@ pub(crate) async fn bulk_assign_to_groups( } } - WireguardNetwork::sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; transaction.commit().await?; @@ -365,7 +371,7 @@ pub(crate) async fn create_group( .insert(&group_info.name); } - WireguardNetwork::sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; transaction.commit().await?; @@ -502,7 +508,7 @@ pub(crate) async fn modify_group( .insert(group.name.as_str()); } - WireguardNetwork::sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; + sync_all_networks(&mut transaction, &appstate.wireguard_tx).await?; let users_after = group.members(&mut *transaction).await?.clone(); transaction.commit().await?; @@ -602,7 +608,7 @@ pub(crate) async fn delete_group( // sync allowed devices for all locations let mut conn = appstate.pool.acquire().await?; - WireguardNetwork::sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; info!("User {} deleted group {name}", &session.user.username); appstate.emit_event(ApiEvent { @@ -656,7 +662,7 @@ pub(crate) async fn add_group_member( ldap_add_user_to_groups(&user, hashset![group.name.as_str()], &appstate.pool).await; ldap_update_user_state(&mut user, &appstate.pool).await; let mut conn = appstate.pool.acquire().await?; - WireguardNetwork::sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; info!("Added user: {} to group: {}", user.username, group.name); appstate.emit_event(ApiEvent { context, @@ -719,7 +725,7 @@ pub(crate) async fn remove_group_member( .await; let mut conn = appstate.pool.acquire().await?; - WireguardNetwork::sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; + sync_all_networks(&mut conn, &appstate.wireguard_tx).await?; info!("Removed user: {} from group: {}", user.username, group.name); appstate.emit_event(ApiEvent { context, diff --git a/crates/defguard_core/src/handlers/mail.rs b/crates/defguard_core/src/handlers/mail.rs index 9a9ce05e8c..7f78913bec 100644 --- a/crates/defguard_core/src/handlers/mail.rs +++ b/crates/defguard_core/src/handlers/mail.rs @@ -5,7 +5,10 @@ use axum::{ http::StatusCode, }; use chrono::{NaiveDateTime, Utc}; -use defguard_common::db::{Id, models::MFAMethod}; +use defguard_common::db::{ + Id, + models::{MFAMethod, User}, +}; use defguard_mail::{ Attachment, Mail, templates::{self, SessionContext, TemplateError, TemplateLocation, support_data_mail}, @@ -23,7 +26,7 @@ use crate::{ PgPool, appstate::AppState, auth::{AdminRole, SessionInfo}, - db::{User, models::enrollment::TokenError}, + db::models::enrollment::TokenError, error::WebError, server_config, support::dump_config, diff --git a/crates/defguard_core/src/handlers/mod.rs b/crates/defguard_core/src/handlers/mod.rs index 93263db869..60570d8e6d 100644 --- a/crates/defguard_core/src/handlers/mod.rs +++ b/crates/defguard_core/src/handlers/mod.rs @@ -6,7 +6,13 @@ use axum::{ }; use axum_client_ip::InsecureClientIp; use axum_extra::{TypedHeader, headers::UserAgent}; -use defguard_common::db::{Id, NoId}; +use defguard_common::{ + db::{ + Id, NoId, + models::{Device, User}, + }, + types::user_info::UserInfo, +}; use serde_json::{Value, json}; use sqlx::PgPool; use utoipa::ToSchema; @@ -15,7 +21,7 @@ use webauthn_rs::prelude::RegisterPublicKeyCredential; use crate::{ appstate::AppState, auth::SessionInfo, - db::{Device, User, UserInfo, WebHook}, + db::WebHook, enterprise::{db::models::acl::AclError, license::LicenseError}, error::WebError, events::ApiRequestContext, @@ -28,14 +34,14 @@ pub(crate) mod forward_auth; pub(crate) mod group; pub(crate) mod mail; pub mod network_devices; -pub(crate) mod openid_clients; +pub mod openid_clients; pub mod openid_flow; pub(crate) mod pagination; pub(crate) mod settings; pub(crate) mod ssh_authorized_keys; pub(crate) mod support; pub(crate) mod updates; -pub(crate) mod user; +pub mod user; pub(crate) mod webhooks; pub mod wireguard; pub mod worker; diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index b0b1863fd7..827838eefc 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -8,7 +8,17 @@ use axum::{ http::StatusCode, }; use chrono::NaiveDateTime; -use defguard_common::{csv::AsCsv, db::Id}; +use defguard_common::{ + csv::AsCsv, + db::{ + Id, + models::{ + Device, DeviceConfig, DeviceType, User, WireguardNetwork, + device::{DeviceInfo, WireguardNetworkDevice}, + wireguard::NetworkAddressError, + }, + }, +}; use defguard_mail::templates::TemplateLocation; use ipnetwork::IpNetwork; use serde_json::json; @@ -18,15 +28,10 @@ use super::{ApiResponse, ApiResult, WebError}; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::{ - Device, GatewayEvent, User, WireguardNetwork, - models::{ - device::{DeviceConfig, DeviceInfo, DeviceType, WireguardNetworkDevice}, - wireguard::NetworkAddressError, - }, - }, - enterprise::limits::update_counts, + enrollment_management::start_desktop_configuration, + enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, events::{ApiEvent, ApiEventType, ApiRequestContext}, + grpc::gateway::events::GatewayEvent, handlers::mail::send_new_device_added_email, server_config, }; @@ -460,18 +465,18 @@ pub(crate) async fn start_network_device_setup( device: NetworkDeviceInfo::from_device(device, &mut transaction).await?, }; let config = server_config(); - let configuration_token = user - .start_remote_desktop_configuration( - &mut transaction, - &user, - None, - config.enrollment_token_timeout.as_secs(), - config.enrollment_url.clone(), - false, - appstate.mail_tx.clone(), - Some(result.device.id), - ) - .await?; + let configuration_token = start_desktop_configuration( + &user, + &mut transaction, + &user, + None, + config.enrollment_token_timeout.as_secs(), + config.enrollment_url.clone(), + false, + appstate.mail_tx.clone(), + Some(result.device.id), + ) + .await?; debug!( "Generated a new device CLI configuration token for a network device {device_name} with ID {}: {configuration_token}", @@ -526,18 +531,18 @@ pub(crate) async fn start_network_device_setup_for_device( )) })?; let config = server_config(); - let configuration_token = user - .start_remote_desktop_configuration( - &mut transaction, - &user, - None, - config.enrollment_token_timeout.as_secs(), - config.enrollment_url.clone(), - false, - appstate.mail_tx.clone(), - Some(device.id), - ) - .await?; + let configuration_token = start_desktop_configuration( + &user, + &mut transaction, + &user, + None, + config.enrollment_token_timeout.as_secs(), + config.enrollment_url.clone(), + false, + appstate.mail_tx.clone(), + Some(device.id), + ) + .await?; transaction.commit().await?; debug!( @@ -629,7 +634,9 @@ pub(crate) async fn add_network_device( update_counts(&mut *transaction).await?; // send firewall update event if ACLs & enterprise features are enabled - if let Some(firewall_config) = network.try_get_firewall_config(&mut transaction).await? { + if let Some(firewall_config) = + try_get_location_firewall_config(&network, &mut transaction).await? + { appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( network.id, firewall_config, @@ -733,9 +740,8 @@ pub async fn modify_network_device( // send firewall update event if ACLs are enabled if device_network.acl_enabled { - if let Some(firewall_config) = device_network - .try_get_firewall_config(&mut transaction) - .await? + if let Some(firewall_config) = + try_get_location_firewall_config(&device_network, &mut transaction).await? { appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( device_network.id, diff --git a/crates/defguard_core/src/handlers/openid_clients.rs b/crates/defguard_core/src/handlers/openid_clients.rs index e0a911590d..767c5b16a1 100644 --- a/crates/defguard_core/src/handlers/openid_clients.rs +++ b/crates/defguard_core/src/handlers/openid_clients.rs @@ -2,19 +2,46 @@ use axum::{ extract::{Json, Path, State}, http::StatusCode, }; +use defguard_common::{ + db::{ + NoId, + models::oauth2client::{OAuth2Client, OAuth2ClientSafe}, + }, + random::gen_alphanumeric, +}; use serde_json::json; use super::{ApiResponse, ApiResult, webhooks::ChangeStateData}; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::models::{ - NewOpenIDClient, - oauth2client::{OAuth2Client, OAuth2ClientSafe}, - }, events::{ApiEvent, ApiEventType, ApiRequestContext}, }; +#[derive(Deserialize, Serialize)] +pub struct NewOpenIDClient { + pub name: String, + pub redirect_uri: Vec, + pub scope: Vec, + pub enabled: bool, +} + +impl From for OAuth2Client { + fn from(value: NewOpenIDClient) -> Self { + let client_id = gen_alphanumeric(16); + let client_secret = gen_alphanumeric(32); + Self { + id: NoId, + client_id, + client_secret, + redirect_uri: value.redirect_uri, + scope: value.scope, + name: value.name, + enabled: value.enabled, + } + } +} + pub async fn add_openid_client( _admin: AdminRole, session: SessionInfo, @@ -36,7 +63,8 @@ pub async fn add_openid_client( status: StatusCode::BAD_REQUEST, }); } - let client = OAuth2Client::from_new(data).save(&appstate.pool).await?; + let client: OAuth2Client = data.into(); + let client = client.save(&appstate.pool).await?; info!( "User {} added OpenID client {}", session.user.username, client.name diff --git a/crates/defguard_core/src/handlers/openid_flow.rs b/crates/defguard_core/src/handlers/openid_flow.rs index cad6c47f7d..6a2da694cf 100644 --- a/crates/defguard_core/src/handlers/openid_flow.rs +++ b/crates/defguard_core/src/handlers/openid_flow.rs @@ -5,7 +5,7 @@ use std::{ use axum::{ Form, - extract::{FromRef, OptionalFromRequestParts, Query, State}, + extract::{FromRef, FromRequestParts, Query, State}, http::{ HeaderMap, HeaderValue, StatusCode, header::{AUTHORIZATION, LOCATION}, @@ -15,7 +15,13 @@ use axum::{ use axum_extra::extract::cookie::{Cookie, CookieJar, PrivateCookieJar, SameSite}; use base64::{Engine, prelude::BASE64_STANDARD}; use chrono::Utc; -use defguard_common::db::{Id, NoId, models::AuthCode}; +use defguard_common::db::{ + Id, NoId, + models::{ + AuthCode, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User, + oauth2client::OAuth2Client, + }, +}; use openidconnect::{ AccessToken, AdditionalClaims, Audience, AuthUrl, AuthorizationCode, EmptyAdditionalProviderMetadata, EmptyExtraTokenFields, EndUserEmail, EndUserFamilyName, @@ -43,10 +49,6 @@ use super::{ApiResponse, ApiResult, SESSION_COOKIE_NAME}; use crate::{ appstate::AppState, auth::{SessionInfo, UserClaims}, - db::{ - OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User, - models::oauth2client::OAuth2Client, - }, error::WebError, handlers::{SIGN_IN_COOKIE_NAME, mail::send_new_device_ocid_login_email}, server_config, @@ -111,19 +113,17 @@ pub type DefguardIdTokenFields = IdTokenFields< >; pub type DefguardTokenResponse = StandardTokenResponse; +pub struct OAuth2ClientExtractor(Option>); /// Provide `OAuth2Client` when Basic Authorization header contains `client_id` and `client_secret`. -impl OptionalFromRequestParts for OAuth2Client +impl FromRequestParts for OAuth2ClientExtractor where S: Send + Sync, AppState: FromRef, { type Rejection = WebError; - async fn from_request_parts( - parts: &mut Parts, - state: &S, - ) -> Result, Self::Rejection> { + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { if let Some(basic_auth) = parts.headers.get(AUTHORIZATION).and_then(|value| { if let Ok(value) = value.to_str() { if value.starts_with("Basic ") { @@ -136,19 +136,17 @@ where if let Ok(auth_pair) = String::from_utf8(decoded) { if let Some((client_id, client_secret)) = auth_pair.split_once(':') { let appstate = AppState::from_ref(state); - return OAuth2Client::find_by_auth( - &appstate.pool, - client_id, - client_secret, - ) - .await - .map_err(Into::into); + return Ok(Self( + OAuth2Client::find_by_auth(&appstate.pool, client_id, client_secret) + .await + .map_err(Into::::into)?, + )); } } } Err(WebError::Authorization("Invalid credentials".into())) } else { - Ok(None) + Ok(Self(None)) } } } @@ -803,7 +801,7 @@ impl TokenRequest { /// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens pub async fn token( State(appstate): State, - oauth2client: Option>, + OAuth2ClientExtractor(oauth2client): OAuth2ClientExtractor, Form(form): Form, ) -> ApiResult { // TODO: cleanup branches diff --git a/crates/defguard_core/src/handlers/ssh_authorized_keys.rs b/crates/defguard_core/src/handlers/ssh_authorized_keys.rs index 8b6be6d37e..6d1b7c96d6 100644 --- a/crates/defguard_core/src/handlers/ssh_authorized_keys.rs +++ b/crates/defguard_core/src/handlers/ssh_authorized_keys.rs @@ -5,7 +5,7 @@ use axum::{ }; use defguard_common::db::{ Id, - models::{AuthenticationKey, AuthenticationKeyType}, + models::{AuthenticationKey, AuthenticationKeyType, User, group::Group}, }; use serde_json::json; use sqlx::{Error as SqlxError, PgExecutor, PgPool, query}; @@ -15,7 +15,6 @@ use super::{ApiResponse, ApiResult, user_for_admin_or_self}; use crate::{ appstate::AppState, auth::SessionInfo, - db::{Group, User}, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, }; diff --git a/crates/defguard_core/src/handlers/user.rs b/crates/defguard_core/src/handlers/user.rs index 128522ddd4..a901bde27e 100644 --- a/crates/defguard_core/src/handlers/user.rs +++ b/crates/defguard_core/src/handlers/user.rs @@ -4,9 +4,21 @@ use axum::{ extract::{Json, Path, State}, http::StatusCode, }; +use defguard_common::{ + db::{ + Id, + models::{ + BiometricAuth, OAuth2AuthorizedApp, User, WebAuthn, device::UserDevice, + user::SecurityKey, + }, + }, + types::{group_diff::GroupDiff, user_info::UserInfo}, +}; use defguard_mail::{Mail, templates}; use humantime::parse_duration; use serde_json::json; +use sqlx::{Error as SqlxError, PgPool}; +use utoipa::ToSchema; use super::{ AddUserData, ApiResponse, ApiResult, PasswordChange, PasswordChangeSelf, @@ -17,24 +29,26 @@ use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, db::{ - AppEvent, OAuth2AuthorizedApp, User, UserDetails, UserInfo, WebAuthn, - models::{ - GroupDiff, - enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, - }, + AppEvent, + models::enrollment::{PASSWORD_RESET_TOKEN_TYPE, Token}, }, + enrollment_management::{start_desktop_configuration, start_user_enrollment}, enterprise::{ db::models::api_tokens::ApiToken, handlers::CanManageDevices, - ldap::utils::{ - ldap_add_user, ldap_add_user_to_groups, ldap_change_password, ldap_delete_user, - ldap_handle_user_modify, ldap_remove_user_from_groups, ldap_update_user_state, + ldap::{ + model::{ldap_sync_allowed_for_user, maybe_update_rdn}, + utils::{ + ldap_add_user, ldap_add_user_to_groups, ldap_change_password, ldap_delete_user, + ldap_handle_user_modify, ldap_remove_user_from_groups, ldap_update_user_state, + }, }, limits::update_counts, }, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, is_valid_phone_number, server_config, + user_management::{delete_user_and_cleanup_devices, sync_allowed_user_devices}, }; /// The maximum length for the commonName (CN) attribute in LDAP schemas is commonly set to 64 @@ -107,6 +121,35 @@ pub(crate) fn check_password_strength(password: &str) -> Result<(), WebError> { Ok(()) } +// Full user info with related objects +#[derive(Deserialize, Serialize, Debug, ToSchema)] +pub struct UserDetails { + pub user: UserInfo, + #[serde(default)] + pub devices: Vec, + pub biometric_enabled_devices: Vec, + #[serde(default)] + pub security_keys: Vec, +} + +impl UserDetails { + pub async fn from_user(pool: &PgPool, user: &User) -> Result { + let devices = user.user_devices(pool).await?; + let security_keys = user.security_keys(pool).await?; + let biometric_enabled_devices = BiometricAuth::find_by_user_id(pool, user.id) + .await? + .iter() + .map(|a| a.device_id) + .collect::>(); + Ok(Self { + user: UserInfo::from_user(pool, user).await?, + devices, + security_keys, + biometric_enabled_devices, + }) + } +} + /// List of all users /// /// Retrieves list of users. @@ -441,17 +484,17 @@ pub async fn start_enrollment( None => config.enrollment_token_timeout.as_secs(), }; - let enrollment_token = user - .start_enrollment( - &mut transaction, - &session.user, - data.email, - token_expiration_time_seconds, - config.enrollment_url.clone(), - data.send_enrollment_notification, - appstate.mail_tx.clone(), - ) - .await?; + let enrollment_token = start_user_enrollment( + &mut user, + &mut transaction, + &session.user, + data.email, + token_expiration_time_seconds, + config.enrollment_url.clone(), + data.send_enrollment_notification, + appstate.mail_tx.clone(), + ) + .await?; debug!("Try to commit transaction to save the enrollment token into the database."); transaction.commit().await?; @@ -544,18 +587,18 @@ pub async fn start_remote_desktop_configuration( session.user.username ); let config = server_config(); - let desktop_configuration_token = user - .start_remote_desktop_configuration( - &mut transaction, - &session.user, - Some(email), - config.enrollment_token_timeout.as_secs(), - config.enrollment_url.clone(), - data.send_enrollment_notification, - appstate.mail_tx.clone(), - None, - ) - .await?; + let desktop_configuration_token = start_desktop_configuration( + &user, + &mut transaction, + &session.user, + Some(email), + config.enrollment_token_timeout.as_secs(), + config.enrollment_url.clone(), + data.send_enrollment_notification, + appstate.mail_tx.clone(), + None, + ) + .await?; debug!("Try to submit transaction to save the desktop configuration token into the databse."); transaction.commit().await?; @@ -700,7 +743,7 @@ pub async fn modify_user( let status_changing = user_info.is_active != user.is_active; let mut transaction = appstate.pool.begin().await?; - let ldap_sync_allowed = user.ldap_sync_allowed(&mut *transaction).await?; + let ldap_sync_allowed = ldap_sync_allowed_for_user(&user, &mut *transaction).await?; // remove authorized apps if needed let request_app_ids: Vec = user_info @@ -742,8 +785,7 @@ pub async fn modify_user( "User {} changed {username} groups or status, syncing allowed network devices.", session.user.username ); - user.sync_allowed_devices(&mut transaction, &appstate.wireguard_tx) - .await?; + sync_allowed_user_devices(&user, &mut transaction, &appstate.wireguard_tx).await?; } // remove API tokens when deactivating a user @@ -766,7 +808,7 @@ pub async fn modify_user( ldap_handle_user_modify(&old_username, &mut user, &appstate.pool).await; } - user.maybe_update_rdn(); + maybe_update_rdn(&mut user); user.save(&appstate.pool).await?; Box::pin(ldap_update_user_state(&mut user, &appstate.pool)).await; @@ -874,13 +916,12 @@ pub async fn delete_user( session.user.username ); let mut transaction = appstate.pool.begin().await?; - let user_for_ldap = if user.ldap_sync_allowed(&mut *transaction).await? { + let user_for_ldap = if ldap_sync_allowed_for_user(&user, &mut *transaction).await? { Some(user.clone().as_noid()) } else { None }; - user.clone() - .delete_and_cleanup(&mut transaction, &appstate.wireguard_tx) + delete_user_and_cleanup_devices(user.clone(), &mut transaction, &appstate.wireguard_tx) .await?; appstate.trigger_action(AppEvent::UserDeleted(username.clone())); diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 9410134bdd..3c07815458 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -11,7 +11,22 @@ use axum::{ http::StatusCode, }; use chrono::{DateTime, NaiveDateTime, TimeDelta, Utc}; -use defguard_common::{csv::AsCsv, db::Id}; +use defguard_common::{ + csv::AsCsv, + db::{ + Id, + models::{ + Device, DeviceConfig, DeviceNetworkInfo, DeviceType, WireguardNetwork, + device::{AddDevice, DeviceInfo, ModifyDevice, WireguardNetworkDevice}, + wireguard::{ + DateTimeAggregation, LocationMfaMode, MappedDevice, ServiceLocationMode, + WireguardDeviceStatsRow, WireguardNetworkStats, WireguardUserStatsRow, + networks_stats, + }, + }, + }, + utils::{parse_address_list, parse_network_address_list}, +}; use defguard_mail::templates::TemplateLocation; use ipnetwork::IpNetwork; use serde_json::{Value, json}; @@ -23,53 +38,31 @@ use super::{ApiResponse, ApiResult, WebError, device_for_admin_or_self, user_for use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::{ - AddDevice, Device, GatewayEvent, WireguardNetwork, - models::{ - device::{ - DeviceConfig, DeviceInfo, DeviceNetworkInfo, DeviceType, ModifyDevice, - WireguardNetworkDevice, - }, - wireguard::{ - DateTimeAggregation, LocationMfaMode, MappedDevice, ServiceLocationMode, - WireguardDeviceStatsRow, WireguardNetworkInfo, WireguardNetworkStats, - WireguardUserStatsRow, networks_stats, - }, - }, - }, enterprise::{ db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, + firewall::try_get_location_firewall_config, handlers::CanManageDevices, is_enterprise_enabled, limits::update_counts, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::gateway::map::GatewayMap, + grpc::gateway::{events::GatewayEvent, map::GatewayMap, state::GatewayState}, handlers::mail::send_new_device_added_email, + location_management::{ + allowed_peers::get_location_allowed_peers, handle_imported_devices, handle_mapped_devices, + sync_location_allowed_devices, + }, server_config, wg_config::{ImportedDevice, parse_wireguard_config}, }; -/// Parse a string with comma-separated IP addresses. -/// Invalid addresses will be silently ignored. -pub(crate) fn parse_address_list(ips: &str) -> Vec { - ips.split(',') - .filter_map(|ip| ip.trim().parse().ok()) - .collect() -} - -/// Parse a string with comma-separated IP network addresses. -/// Host bits will be stripped. -/// Invalid addresses will be silently ignored. -pub(crate) fn parse_network_address_list(ips: &str) -> Vec { - ips.split(',') - .filter_map(|ip| ip.trim().parse().ok()) - .filter_map(|ip: IpNetwork| { - let network_address = ip.network(); - let network_mask = ip.mask(); - IpNetwork::with_netmask(network_address, network_mask).ok() - }) - .collect() +#[derive(Serialize, ToSchema)] +pub struct WireguardNetworkInfo { + #[serde(flatten)] + pub network: WireguardNetwork, + pub connected: bool, + pub gateways: Vec, + pub allowed_groups: Vec, } #[derive(Deserialize, Serialize, ToSchema)] @@ -335,10 +328,11 @@ pub(crate) async fn modify_network( network .set_allowed_groups(&mut transaction, data.allowed_groups) .await?; - let _events = network.sync_allowed_devices(&mut transaction, None).await?; + let _events = sync_location_allowed_devices(&network, &mut transaction, None).await?; - let peers = network.get_peers(&mut *transaction).await?; - let maybe_firewall_config = network.try_get_firewall_config(&mut transaction).await?; + let peers = get_location_allowed_peers(&network, &mut *transaction).await?; + let maybe_firewall_config = + try_get_location_firewall_config(&network, &mut transaction).await?; appstate.send_wireguard_event(GatewayEvent::NetworkModified( network.id, network.clone(), @@ -626,16 +620,14 @@ pub(crate) async fn import_network( .iter() .flat_map(|dev| dev.wireguard_ips.clone()) .collect(); - let (devices, gateway_events) = network - .handle_imported_devices(&mut transaction, imported_devices) - .await?; + let (devices, gateway_events) = + handle_imported_devices(&network, &mut transaction, imported_devices).await?; appstate.send_multiple_wireguard_events(gateway_events); // assign IPs for other existing devices debug!("Assigning IPs in imported network for remaining existing devices"); - let gateway_events = network - .sync_allowed_devices(&mut transaction, Some(&reserved_ips)) - .await?; + let gateway_events = + sync_location_allowed_devices(&network, &mut transaction, Some(&reserved_ips)).await?; appstate.send_multiple_wireguard_events(gateway_events); debug!("Assigned IPs in imported network for remaining existing devices"); @@ -685,9 +677,7 @@ pub(crate) async fn add_user_devices( if let Some(network) = WireguardNetwork::find_by_id(&appstate.pool, network_id).await? { // wrap loop in transaction to abort if a device is invalid let mut transaction = appstate.pool.begin().await?; - let events = network - .handle_mapped_devices(&mut transaction, mapped_devices) - .await?; + let events = handle_mapped_devices(&network, &mut transaction, mapped_devices).await?; appstate.send_multiple_wireguard_events(events); transaction.commit().await?; @@ -866,7 +856,7 @@ pub(crate) async fn add_device( if let Some(location) = WireguardNetwork::find_by_id(&mut *transaction, location_id).await? { if let Some(firewall_config) = - location.try_get_firewall_config(&mut transaction).await? + try_get_location_firewall_config(&location, &mut transaction).await? { debug!( "Sending firewall config update for location {location} affected by adding new user {username} devices" @@ -1176,7 +1166,7 @@ pub(crate) async fn delete_device( WireguardNetwork::find_by_id(&mut *transaction, info.network_id).await? { if let Some(firewall_config) = - location.try_get_firewall_config(&mut transaction).await? + try_get_location_firewall_config(&location, &mut transaction).await? { debug!( "Sending firewall config update for location {location} affected by deleting user {username} device" diff --git a/crates/defguard_core/src/handlers/worker.rs b/crates/defguard_core/src/handlers/worker.rs index cab88e8989..9c4b2fccec 100644 --- a/crates/defguard_core/src/handlers/worker.rs +++ b/crates/defguard_core/src/handlers/worker.rs @@ -4,14 +4,16 @@ use axum::{ extract::{Extension, Json, Path, State}, http::StatusCode, }; -use defguard_common::auth::claims::{Claims, ClaimsType}; +use defguard_common::{ + auth::claims::{Claims, ClaimsType}, + db::models::User, +}; use serde_json::json; use super::{ApiResponse, ApiResult}; use crate::{ appstate::AppState, auth::{AdminRole, SessionInfo}, - db::User, error::WebError, grpc::WorkerState, }; diff --git a/crates/defguard_core/src/handlers/yubikey.rs b/crates/defguard_core/src/handlers/yubikey.rs index 8f112ea296..b8a83b0c0c 100644 --- a/crates/defguard_core/src/handlers/yubikey.rs +++ b/crates/defguard_core/src/handlers/yubikey.rs @@ -3,10 +3,11 @@ use axum::{ extract::{Path, State}, http::StatusCode, }; +use defguard_common::db::models::YubiKey; use serde_json::json; use super::{ApiResponse, ApiResult, user_for_admin_or_self}; -use crate::{appstate::AppState, auth::SessionInfo, db::YubiKey, error::WebError}; +use crate::{appstate::AppState, auth::SessionInfo, error::WebError}; pub async fn delete_yubikey( State(appstate): State, diff --git a/crates/defguard_core/src/headers.rs b/crates/defguard_core/src/headers.rs index cb2fbbbd4c..ca9c2b78b2 100644 --- a/crates/defguard_core/src/headers.rs +++ b/crates/defguard_core/src/headers.rs @@ -1,7 +1,10 @@ use std::{borrow::Borrow, sync::LazyLock}; use axum::http::{HeaderName, HeaderValue}; -use defguard_common::db::{Id, models::DeviceLoginEvent}; +use defguard_common::db::{ + Id, + models::{DeviceLoginEvent, User}, +}; use defguard_mail::{ Mail, templates::{SessionContext, TemplateError}, @@ -10,7 +13,7 @@ use sqlx::PgPool; use tokio::sync::mpsc::UnboundedSender; use uaparser::{Client, Parser, UserAgentParser}; -use crate::{db::User, handlers::mail::send_new_device_login_email}; +use crate::handlers::mail::send_new_device_login_email; pub(crate) const CONTENT_SECURITY_POLICY_HEADER_NAME: HeaderName = HeaderName::from_static("content-security-policy"); diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 0c4f17c3c4..c67f9c8e31 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -13,12 +13,21 @@ use axum::{ routing::{delete, get, post, put}, serve, }; -use db::models::{device::DeviceType, wireguard::LocationMfaMode}; use defguard_common::{ VERSION, auth::claims::{Claims, ClaimsType}, config::{DefGuardConfig, InitVpnLocationArgs, server_config}, - db::init_db, + db::{ + init_db, + models::{ + Device, DeviceType, User, WireguardNetwork, + oauth2client::OAuth2Client, + wireguard::{ + DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, LocationMfaMode, + ServiceLocationMode, + }, + }, + }, }; use defguard_mail::Mail; use defguard_version::server::DefguardVersionLayer; @@ -91,13 +100,7 @@ use utoipa_swagger_ui::SwaggerUi; use self::{ appstate::AppState, auth::failed_login::FailedLoginMap, - db::{ - AppEvent, Device, GatewayEvent, User, WireguardNetwork, - models::{ - oauth2client::OAuth2Client, - wireguard::{DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL}, - }, - }, + db::AppEvent, grpc::{WorkerState, gateway::map::GatewayMap}, handlers::{ app_info::get_app_info, @@ -146,19 +149,25 @@ use self::{ worker::{create_job, create_worker_token, job_status, list_workers, remove_worker}, }, }; -use crate::{db::models::wireguard::ServiceLocationMode, version::IncompatibleComponents}; +use crate::{ + grpc::gateway::events::GatewayEvent, location_management::sync_location_allowed_devices, + version::IncompatibleComponents, +}; pub mod appstate; pub mod auth; pub mod db; +pub mod enrollment_management; pub mod enterprise; mod error; pub mod events; pub mod grpc; pub mod handlers; pub mod headers; +pub mod location_management; pub mod support; pub mod updates; +pub mod user_management; pub mod utility_thread; pub mod version; pub mod wg_config; @@ -176,13 +185,13 @@ static PHONE_NUMBER_REGEX: LazyLock = LazyLock::new(|| { .expect("Failed to parse phone number regex") }); -// WireGuard key length in bytes. -pub(crate) const KEY_LENGTH: usize = 32; - mod openapi { - use db::{ - AddDevice, UserDetails, UserInfo, - models::device::{ModifyDevice, UserDevice}, + use defguard_common::{ + db::models::{ + Device, + device::{AddDevice, ModifyDevice, UserDevice}, + }, + types::user_info::UserInfo, }; use handlers::{ ApiResponse, EditGroupInfo, GroupInfo, PasswordChange, PasswordChangeSelf, @@ -197,7 +206,7 @@ mod openapi { }; use super::*; - use crate::{enterprise::snat::handlers as snat, error::WebError}; + use crate::{enterprise::snat::handlers as snat, error::WebError, handlers::user::UserDetails}; #[derive(OpenApi)] #[openapi( @@ -855,7 +864,7 @@ pub async fn init_vpn_location( network.dns.clone_from(&args.dns); network.allowed_ips.clone_from(&args.allowed_ips); network.save(&mut *transaction).await?; - network.sync_allowed_devices(&mut transaction, None).await?; + sync_location_allowed_devices(&network, &mut transaction, None).await?; network } // Otherwise create it with the predefined ID diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs new file mode 100644 index 0000000000..7349d70d45 --- /dev/null +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -0,0 +1,263 @@ +use defguard_common::db::{Id, models::WireguardNetwork}; +use defguard_proto::gateway::Peer; +use sqlx::{Error as SqlxError, PgExecutor, query}; + +use crate::grpc::gateway::should_prevent_service_location_usage; + +/// Get a list of all allowed peers for a given location +/// +/// Each device is marked as allowed or not allowed in a given network, +/// which enables enforcing peer disconnect in MFA-protected networks. +/// +/// If the location is a service location, only returns peers if enterprise features are enabled. +pub async fn get_location_allowed_peers<'e, E>( + location: &WireguardNetwork, + executor: E, +) -> Result, SqlxError> +where + E: PgExecutor<'e>, +{ + debug!("Fetching all peers for network {}", location.id); + + if should_prevent_service_location_usage(location) { + warn!( + "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", + location.name + ); + return Ok(Vec::new()); + } + + let rows = query!( + "SELECT d.wireguard_pubkey pubkey, preshared_key, \ + -- TODO possible to not use ARRAY-unnest here? + ARRAY( + SELECT host(ip) + FROM unnest(wnd.wireguard_ips) AS ip + ) \"allowed_ips!: Vec\" \ + FROM wireguard_network_device wnd \ + JOIN device d ON wnd.device_id = d.id \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ + AND d.configured = true \ + AND u.is_active = true \ + ORDER BY d.id ASC", + location.id, + location.mfa_enabled() + ) + .fetch_all(executor) + .await?; + + // keepalive has to be added manually because Postgres + // doesn't support unsigned integers + let result = rows + .into_iter() + .map(|row| Peer { + pubkey: row.pubkey, + allowed_ips: row.allowed_ips, + // Don't send preshared key if MFA is not enabled, it can't be used and may + // cause issues with clients connecting if they expect no preshared key + // e.g. when you disable MFA on a location + preshared_key: if location.mfa_enabled() { + row.preshared_key + } else { + None + }, + keepalive_interval: Some(location.keepalive_interval as u32), + }) + .collect(); + + Ok(result) +} + +#[cfg(test)] +mod test { + use std::{net::IpAddr, str::FromStr}; + + use defguard_common::db::{ + models::{ + Device, DeviceType, WireguardNetwork, + device::WireguardNetworkDevice, + user::User, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, + setup_pool, + }; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + + use crate::location_management::allowed_peers::get_location_allowed_peers; + + #[sqlx::test] + async fn test_get_peers_service_location_modes(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user = User::new( + "testuser", + Some("password123"), + "Test", + "User", + "test@example.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "pubkey1".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "pubkey2".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + // Normal location (service_location_mode = Disabled) should return peers + let mut network_normal = WireguardNetwork { + name: "normal-location".to_string(), + service_location_mode: ServiceLocationMode::Disabled, + location_mfa_mode: LocationMfaMode::Disabled, + ..Default::default() + }; + network_normal.try_set_address("10.1.1.1/24").unwrap(); + let network_normal = network_normal.save(&pool).await.unwrap(); + + WireguardNetworkDevice::new( + network_normal.id, + device1.id, + vec![IpAddr::from_str("10.1.1.2").unwrap()], + ) + .insert(&pool) + .await + .unwrap(); + + let peers_normal = get_location_allowed_peers(&network_normal, &pool) + .await + .unwrap(); + assert_eq!(peers_normal.len(), 1, "Normal location should return peers"); + assert_eq!(peers_normal[0].pubkey, "pubkey1"); + + // Service location with PreLogon mode returns peers when enterprise is enabled (test env default) + let mut network_prelogon = WireguardNetwork { + name: "prelogon-service-location".to_string(), + service_location_mode: ServiceLocationMode::PreLogon, + location_mfa_mode: LocationMfaMode::Disabled, + ..Default::default() + }; + network_prelogon.try_set_address("10.2.1.1/24").unwrap(); + let network_prelogon = network_prelogon.save(&pool).await.unwrap(); + + WireguardNetworkDevice::new( + network_prelogon.id, + device2.id, + vec![IpAddr::from_str("10.2.1.2").unwrap()], + ) + .insert(&pool) + .await + .unwrap(); + + // PreLogon service location should return peers when enterprise is enabled + let peers_prelogon = get_location_allowed_peers(&network_prelogon, &pool) + .await + .unwrap(); + assert_eq!( + peers_prelogon.len(), + 1, + "PreLogon service location should return peers when enterprise is enabled" + ); + assert_eq!(peers_prelogon[0].pubkey, "pubkey2"); + + // Service location with AlwaysOn mode also returns peers when enterprise is enabled + let mut network_alwayson = WireguardNetwork { + name: "alwayson-service-location".to_string(), + service_location_mode: ServiceLocationMode::AlwaysOn, + location_mfa_mode: LocationMfaMode::Disabled, + ..Default::default() + }; + network_alwayson.try_set_address("10.3.1.1/24").unwrap(); + let network_alwayson = network_alwayson.save(&pool).await.unwrap(); + + let device3 = Device::new( + "device3".into(), + "pubkey3".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + WireguardNetworkDevice::new( + network_alwayson.id, + device3.id, + vec![IpAddr::from_str("10.3.1.2").unwrap()], + ) + .insert(&pool) + .await + .unwrap(); + + // AlwaysOn service location should return peers when enterprise is enabled + let peers_alwayson = get_location_allowed_peers(&network_alwayson, &pool) + .await + .unwrap(); + assert_eq!( + peers_alwayson.len(), + 1, + "AlwaysOn service location should return peers when enterprise is enabled" + ); + assert_eq!(peers_alwayson[0].pubkey, "pubkey3"); + + // Now test the negative case: service locations with enterprise disabled + // Exceed the enterprise limits to disable enterprise features + use crate::enterprise::limits::{Counts, DEFAULT_LOCATIONS_LIMIT, set_counts}; + let over_limit_counts = Counts::new(1, 1, DEFAULT_LOCATIONS_LIMIT + 1, 0); + set_counts(over_limit_counts); + + // Test that normal location still returns peers even without enterprise + let peers_normal_no_ent = get_location_allowed_peers(&network_normal, &pool) + .await + .unwrap(); + assert_eq!( + peers_normal_no_ent.len(), + 1, + "Normal location should still return peers without enterprise" + ); + + // Test that PreLogon service location returns NO peers without enterprise + let peers_prelogon_no_ent = get_location_allowed_peers(&network_prelogon, &pool) + .await + .unwrap(); + assert!( + peers_prelogon_no_ent.is_empty(), + "PreLogon service location should return NO peers when enterprise is disabled" + ); + + // Test that AlwaysOn service location returns NO peers without enterprise + let peers_alwayson_no_ent = get_location_allowed_peers(&network_alwayson, &pool) + .await + .unwrap(); + assert!( + peers_alwayson_no_ent.is_empty(), + "AlwaysOn service location should return NO peers when enterprise is disabled" + ); + + let normal_counts = Counts::new(0, 0, 0, 0); + set_counts(normal_counts); + } +} diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs new file mode 100644 index 0000000000..8ff7a45642 --- /dev/null +++ b/crates/defguard_core/src/location_management/mod.rs @@ -0,0 +1,660 @@ +use std::{collections::HashMap, net::IpAddr}; + +use defguard_common::{ + csv::AsCsv, + db::{ + Id, + models::{ + Device, DeviceNetworkInfo, DeviceType, ModelError, WireguardNetwork, + WireguardNetworkError, + device::{DeviceInfo, WireguardNetworkDevice}, + user::User, + wireguard::MappedDevice, + }, + }, +}; +use sqlx::PgConnection; +use thiserror::Error; +use tokio::sync::broadcast::Sender; + +use crate::{ + enterprise::firewall::{FirewallError, try_get_location_firewall_config}, + grpc::gateway::{events::GatewayEvent, send_multiple_wireguard_events}, + wg_config::ImportedDevice, +}; + +pub mod allowed_peers; + +#[derive(Debug, Error)] +pub enum LocationManagementError { + #[error(transparent)] + FirewallError(#[from] FirewallError), + #[error(transparent)] + DbError(#[from] sqlx::Error), + #[error(transparent)] + WireguardNetworkError(#[from] WireguardNetworkError), + #[error(transparent)] + ModelError(#[from] ModelError), +} + +// run sync_allowed_devices on all wireguard networks +pub(crate) async fn sync_all_networks( + conn: &mut PgConnection, + wireguard_tx: &Sender, +) -> Result<(), LocationManagementError> { + info!("Syncing allowed devices for all WireGuard locations"); + let locations = WireguardNetwork::all(&mut *conn).await?; + for network in locations { + // sync allowed devices for location + let mut gateway_events = sync_location_allowed_devices(&network, &mut *conn, None).await?; + + // send firewall config update if ACLs are enabled for a given location + if let Some(firewall_config) = + try_get_location_firewall_config(&network, &mut *conn).await? + { + gateway_events.push(GatewayEvent::FirewallConfigChanged( + network.id, + firewall_config, + )); + } + // check if any gateway events need to be sent + if !gateway_events.is_empty() { + send_multiple_wireguard_events(gateway_events, wireguard_tx); + } + } + Ok(()) +} + +/// Refresh network IPs for all relevant devices +/// +/// If the list of allowed devices has changed add/remove devices accordingly +/// +/// If the network address has changed readdress existing devices +pub(crate) async fn sync_location_allowed_devices( + location: &WireguardNetwork, + conn: &mut PgConnection, + reserved_ips: Option<&[IpAddr]>, +) -> Result, LocationManagementError> { + info!("Synchronizing IPs in network {location} for all allowed devices "); + // list all allowed devices + let mut allowed_devices = location.get_allowed_devices(&mut *conn).await?; + + // network devices are always allowed, make sure to take only network devices already assigned to that network + let network_devices = + Device::find_by_type_and_network(&mut *conn, DeviceType::Network, location.id).await?; + allowed_devices.extend(network_devices); + + // convert to a map for easier processing + let allowed_devices: HashMap> = allowed_devices + .into_iter() + .map(|dev| (dev.id, dev)) + .collect(); + + // check if all devices can fit within network + // include address, network, and broadcast in the calculation + let count = allowed_devices.len() + 3; + location.validate_network_size(count)?; + + // list all assigned IPs + let assigned_ips = WireguardNetworkDevice::all_for_network(&mut *conn, location.id).await?; + + let events = process_device_access_changes( + location, + &mut *conn, + allowed_devices, + assigned_ips, + reserved_ips, + ) + .await?; + + Ok(events) +} + +/// Refresh network IPs for all relevant devices of a given user +/// If the list of allowed devices has changed add/remove devices accordingly +/// If the network address has changed readdress existing devices +pub(crate) async fn sync_allowed_devices_for_user( + location: &WireguardNetwork, + transaction: &mut PgConnection, + user: &User, + reserved_ips: Option<&[IpAddr]>, +) -> Result, WireguardNetworkError> { + info!("Synchronizing IPs in network {location} for all allowed devices "); + // list all allowed devices + let allowed_devices = location + .get_allowed_devices_for_user(&mut *transaction, user.id) + .await?; + + // convert to a map for easier processing + let allowed_devices: HashMap> = allowed_devices + .into_iter() + .map(|dev| (dev.id, dev)) + .collect(); + + // check if all devices can fit within network + // include address, network, and broadcast in the calculation + let count = allowed_devices.len() + 3; + location.validate_network_size(count)?; + + // list all assigned IPs + let assigned_ips = + WireguardNetworkDevice::all_for_network_and_user(&mut *transaction, location.id, user.id) + .await?; + + let events = process_device_access_changes( + location, + &mut *transaction, + allowed_devices, + assigned_ips, + reserved_ips, + ) + .await?; + + Ok(events) +} + +/// Works out which devices need to be added, removed, or readdressed based on the list +/// of currently configured devices and the list of devices which should be allowed. +pub async fn process_device_access_changes( + location: &WireguardNetwork, + transaction: &mut PgConnection, + mut allowed_devices: HashMap>, + currently_configured_devices: Vec, + reserved_ips: Option<&[IpAddr]>, +) -> Result, WireguardNetworkError> { + // Loop through current device configurations; remove no longer allowed, readdress + // when necessary; remove processed entry from all devices list initial list should + // now contain only devices to be added. + let mut events: Vec = Vec::new(); + for device_network_config in currently_configured_devices { + // Device is allowed and an IP was already assigned + if let Some(device) = allowed_devices.remove(&device_network_config.device_id) { + // Network address has changed and IP addresses need to be updated + if !location.contains_all(&device_network_config.wireguard_ips) + || location.address.len() != device_network_config.wireguard_ips.len() + { + let wireguard_network_device = device + .assign_next_network_ip( + &mut *transaction, + location, + reserved_ips, + Some(&device_network_config.wireguard_ips), + ) + .await?; + events.push(GatewayEvent::DeviceModified(DeviceInfo { + device, + network_info: vec![DeviceNetworkInfo { + network_id: location.id, + device_wireguard_ips: wireguard_network_device.wireguard_ips, + preshared_key: wireguard_network_device.preshared_key, + is_authorized: wireguard_network_device.is_authorized, + }], + })); + } + // Device is no longer allowed + } else { + debug!( + "Device {} no longer allowed, removing network config for {location}", + device_network_config.device_id + ); + device_network_config.delete(&mut *transaction).await?; + if let Some(device) = + Device::find_by_id(&mut *transaction, device_network_config.device_id).await? + { + events.push(GatewayEvent::DeviceDeleted(DeviceInfo { + device, + network_info: vec![DeviceNetworkInfo { + network_id: location.id, + device_wireguard_ips: device_network_config.wireguard_ips, + preshared_key: device_network_config.preshared_key, + is_authorized: device_network_config.is_authorized, + }], + })); + } else { + let msg = format!("Device {} does not exist", device_network_config.device_id); + error!(msg); + return Err(WireguardNetworkError::Unexpected(msg)); + } + } + } + + // Add configs for new allowed devices + for device in allowed_devices.into_values() { + let wireguard_network_device = device + .assign_next_network_ip(&mut *transaction, location, reserved_ips, None) + .await?; + events.push(GatewayEvent::DeviceCreated(DeviceInfo { + device, + network_info: vec![DeviceNetworkInfo { + network_id: location.id, + device_wireguard_ips: wireguard_network_device.wireguard_ips, + preshared_key: wireguard_network_device.preshared_key, + is_authorized: wireguard_network_device.is_authorized, + }], + })); + } + + Ok(events) +} + +/// Check if devices found in an imported config file exist already, +/// if they do assign a specified IP. +/// Return a list of imported devices which need to be manually mapped to a user +/// and a list of WireGuard events to be sent out. +pub(crate) async fn handle_imported_devices( + location: &WireguardNetwork, + transaction: &mut PgConnection, + imported_devices: Vec, +) -> Result<(Vec, Vec), WireguardNetworkError> { + let allowed_devices = location.get_allowed_devices(&mut *transaction).await?; + // convert to a map for easier processing + let allowed_devices: HashMap> = allowed_devices + .into_iter() + .map(|dev| (dev.id, dev)) + .collect(); + + let mut devices_to_map = Vec::new(); + let mut assigned_device_ids = Vec::new(); + let mut events = Vec::new(); + for imported_device in imported_devices { + // check if device with a given pubkey exists already + match Device::find_by_pubkey(&mut *transaction, &imported_device.wireguard_pubkey).await? { + Some(existing_device) => { + // check if device is allowed in network + match allowed_devices.get(&existing_device.id) { + Some(_) => { + info!( + "Device with pubkey {} exists already, assigning IPs {} for new network: {location}", + existing_device.wireguard_pubkey, + imported_device.wireguard_ips.as_csv() + ); + let wireguard_network_device = WireguardNetworkDevice::new( + location.id, + existing_device.id, + imported_device.wireguard_ips, + ); + wireguard_network_device.insert(&mut *transaction).await?; + // store ID of device with already generated config + assigned_device_ids.push(existing_device.id); + // send device to connected gateways + events.push(GatewayEvent::DeviceModified(DeviceInfo { + device: existing_device, + network_info: vec![DeviceNetworkInfo { + network_id: location.id, + device_wireguard_ips: wireguard_network_device.wireguard_ips, + preshared_key: wireguard_network_device.preshared_key, + is_authorized: wireguard_network_device.is_authorized, + }], + })); + } + None => { + warn!( + "Device with pubkey {} exists already, but is not allowed in network {location}. Skipping...", + existing_device.wireguard_pubkey + ); + } + } + } + None => devices_to_map.push(imported_device), + } + } + + Ok((devices_to_map, events)) +} + +/// Handle device -> user mapping in second step of network import wizard +pub(crate) async fn handle_mapped_devices( + location: &WireguardNetwork, + transaction: &mut PgConnection, + mapped_devices: Vec, +) -> Result, WireguardNetworkError> { + info!("Mapping user devices for network {}", location); + // get allowed groups for network + let allowed_groups = location.get_allowed_groups(&mut *transaction).await?; + + let mut events = Vec::new(); + // use a helper hashmap to avoid repeated queries + let mut user_groups = HashMap::new(); + for mapped_device in &mapped_devices { + debug!("Mapping device {}", mapped_device.name); + // validate device pubkey + Device::validate_pubkey(&mapped_device.wireguard_pubkey).map_err(|_| { + WireguardNetworkError::InvalidDevicePubkey(mapped_device.wireguard_pubkey.clone()) + })?; + // save a new device + let device = Device::new( + mapped_device.name.clone(), + mapped_device.wireguard_pubkey.clone(), + mapped_device.user_id, + DeviceType::User, + None, + true, + ) + .save(&mut *transaction) + .await?; + debug!("Saved new device {device}"); + + // get a list of groups user is assigned to + let groups = match user_groups.get(&device.user_id) { + // user info has already been fetched before + Some(groups) => groups, + // fetch user info + None => match User::find_by_id(&mut *transaction, device.user_id).await? { + Some(user) => { + let groups = user.member_of_names(&mut *transaction).await?; + user_groups.insert(device.user_id, groups); + // FIXME: ugly workaround to get around `groups` being dropped + user_groups.get(&device.user_id).unwrap() + } + None => return Err(WireguardNetworkError::from(ModelError::NotFound)), + }, + }; + + let mut network_info = Vec::new(); + match &allowed_groups { + None => { + let wireguard_network_device = WireguardNetworkDevice::new( + location.id, + device.id, + mapped_device.wireguard_ips.clone(), + ); + wireguard_network_device.insert(&mut *transaction).await?; + network_info.push(DeviceNetworkInfo { + network_id: location.id, + device_wireguard_ips: wireguard_network_device.wireguard_ips, + preshared_key: wireguard_network_device.preshared_key, + is_authorized: wireguard_network_device.is_authorized, + }); + } + Some(allowed) => { + // check if user belongs to an allowed group + if allowed.iter().any(|group| groups.contains(group)) { + // assign specified IP in imported network + let wireguard_network_device = WireguardNetworkDevice::new( + location.id, + device.id, + mapped_device.wireguard_ips.clone(), + ); + wireguard_network_device.insert(&mut *transaction).await?; + network_info.push(DeviceNetworkInfo { + network_id: location.id, + device_wireguard_ips: wireguard_network_device.wireguard_ips, + preshared_key: wireguard_network_device.preshared_key, + is_authorized: wireguard_network_device.is_authorized, + }); + } + } + } + + // assign IPs in other networks + let (mut all_network_info, _configs) = + device.add_to_all_networks(&mut *transaction).await?; + + network_info.append(&mut all_network_info); + + // send device to connected gateways + if !network_info.is_empty() { + events.push(GatewayEvent::DeviceCreated(DeviceInfo { + device, + network_info, + })); + } + } + + Ok(events) +} + +#[cfg(test)] +mod test { + use defguard_common::db::{models::group::Group, setup_pool}; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + + use super::*; + use crate::grpc::gateway::events::GatewayEvent; + + #[sqlx::test] + async fn test_sync_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/29").unwrap(); + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "testuser1", + Some("pass1"), + "Tester1", + "Test1", + "test1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "testuser2", + Some("pass2"), + "Tester2", + "Test2", + "test2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "key2".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "device3".into(), + "key3".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + // user1 sync + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user1, None) + .await + .unwrap(); + + assert_eq!(events.len(), 2); + assert!(events.iter().any(|e| match e { + GatewayEvent::DeviceCreated(info) => info.device.id == device1.id, + _ => false, + })); + assert!(events.iter().any(|e| match e { + GatewayEvent::DeviceCreated(info) => info.device.id == device2.id, + _ => false, + })); + + // user 2 sync + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user2, None) + .await + .unwrap(); + + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device3.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + // Second sync should not generate any events + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user1, None) + .await + .unwrap(); + assert_eq!(events.len(), 0); + + transaction.commit().await.unwrap(); + } + + #[sqlx::test] + async fn test_sync_allowed_devices_for_user_with_groups( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let mut network = WireguardNetwork::default(); + network.try_set_address("10.1.1.1/29").unwrap(); + let network = network.save(&pool).await.unwrap(); + + let user1 = User::new( + "testuser1", + Some("pass1"), + "Tester1", + "Test1", + "test1@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user2 = User::new( + "testuser2", + Some("pass2"), + "Tester2", + "Test2", + "test2@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let user3 = User::new( + "testuser3", + Some("pass3"), + "Tester3", + "Test3", + "test3@test.com", + None, + ) + .save(&pool) + .await + .unwrap(); + + let device1 = Device::new( + "device1".into(), + "key1".into(), + user1.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device2 = Device::new( + "device2".into(), + "key2".into(), + user2.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let device3 = Device::new( + "device3".into(), + "key3".into(), + user3.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .unwrap(); + + let group1 = Group::new("group1").save(&pool).await.unwrap(); + let group2 = Group::new("group2").save(&pool).await.unwrap(); + + let mut transaction = pool.begin().await.unwrap(); + + network + .set_allowed_groups( + &mut transaction, + vec![group1.name.clone(), group2.name.clone()], + ) + .await + .unwrap(); + + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user1, None) + .await + .unwrap(); + assert_eq!(events.len(), 0); + + user1.add_to_group(&pool, &group1).await.unwrap(); + user2.add_to_group(&pool, &group1).await.unwrap(); + user3.add_to_group(&pool, &group2).await.unwrap(); + + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user1, None) + .await + .unwrap(); + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device1.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user2, None) + .await + .unwrap(); + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device2.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + let events = sync_allowed_devices_for_user(&network, &mut transaction, &user3, None) + .await + .unwrap(); + assert_eq!(events.len(), 1); + match &events[0] { + GatewayEvent::DeviceCreated(info) => { + assert_eq!(info.device.id, device3.id); + } + _ => panic!("Expected DeviceCreated event"), + } + + transaction.commit().await.unwrap(); + } +} diff --git a/crates/defguard_core/src/support.rs b/crates/defguard_core/src/support.rs index e37bf4e770..f9d0a99d77 100644 --- a/crates/defguard_core/src/support.rs +++ b/crates/defguard_core/src/support.rs @@ -2,16 +2,16 @@ use std::{collections::HashMap, fmt::Display}; use defguard_common::{ VERSION, - db::{Id, models::Settings}, + db::{ + Id, + models::{Settings, User, WireguardNetwork, device::WireguardNetworkDevice}, + }, }; use serde::Serialize; use serde_json::{Value, json, value::to_value}; use sqlx::PgPool; -use crate::{ - db::{User, WireguardNetwork, models::device::WireguardNetworkDevice}, - server_config, -}; +use crate::server_config; /// Unwraps the result returning a JSON representation of value or error fn unwrap_json(result: Result) -> Value { diff --git a/crates/defguard_core/src/user_management.rs b/crates/defguard_core/src/user_management.rs new file mode 100644 index 0000000000..449e5d5b25 --- /dev/null +++ b/crates/defguard_core/src/user_management.rs @@ -0,0 +1,111 @@ +use std::collections::HashSet; + +use defguard_common::db::{ + Id, + models::{User, WireguardNetwork, device::DeviceInfo}, +}; +use sqlx::PgConnection; +use tokio::sync::broadcast::Sender; + +use crate::{ + enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, + error::WebError, + grpc::gateway::{events::GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + location_management::sync_allowed_devices_for_user, +}; + +/// Deletes the user and cleans up his devices from gateways +pub async fn delete_user_and_cleanup_devices( + user: User, + conn: &mut PgConnection, + wg_tx: &Sender, +) -> Result<(), WebError> { + let username = user.username.clone(); + debug!("Deleting user {username}, removing his devices from gateways and updating ldap...",); + let devices = user.devices(&mut *conn).await?; + let mut events = Vec::new(); + + // get all locations affected by devices being deleted + let mut affected_location_ids = HashSet::new(); + + for device in devices { + let device_info = DeviceInfo::from_device(&mut *conn, device).await?; + for network_info in &device_info.network_info { + affected_location_ids.insert(network_info.network_id); + } + events.push(GatewayEvent::DeviceDeleted(device_info)); + } + + user.delete(&mut *conn).await?; + update_counts(&mut *conn).await?; + + // send firewall config updates to affected locations + // if they have ACL enabled & enterprise features are active + for location_id in affected_location_ids { + if let Some(location) = WireguardNetwork::find_by_id(&mut *conn, location_id).await? { + if let Some(firewall_config) = + try_get_location_firewall_config(&location, &mut *conn).await? + { + debug!( + "Sending firewall config update for location {location} affected by deleting user {username} devices" + ); + events.push(GatewayEvent::FirewallConfigChanged( + location_id, + firewall_config, + )); + } + } + } + + send_multiple_wireguard_events(events, wg_tx); + info!( + "The user {} has been deleted and his devices removed from gateways.", + &username + ); + Ok(()) +} + +/// Disable user, log out all his sessions and update gateways state. +pub async fn disable_user( + user: &mut User, + conn: &mut PgConnection, + wg_tx: &Sender, +) -> Result<(), WebError> { + user.is_active = false; + user.save(&mut *conn).await?; + user.logout_all_sessions(&mut *conn).await?; + sync_allowed_user_devices(user, conn, wg_tx).await?; + Ok(()) +} + +/// Update gateway state based on this user device access rights +pub async fn sync_allowed_user_devices( + user: &User, + conn: &mut PgConnection, + wg_tx: &Sender, +) -> Result<(), WebError> { + debug!("Syncing allowed devices of user {}", user.username); + let locations = WireguardNetwork::all(&mut *conn).await?; + for location in locations { + let gateway_events = + sync_allowed_devices_for_user(&location, &mut *conn, user, None).await?; + + // check if any peers were updated + if !gateway_events.is_empty() { + // send peer update events + send_multiple_wireguard_events(gateway_events, wg_tx); + } + + // send firewall config update if ACLs & enterprise features are enabled + if let Some(firewall_config) = + try_get_location_firewall_config(&location, &mut *conn).await? + { + send_wireguard_event( + GatewayEvent::FirewallConfigChanged(location.id, firewall_config), + wg_tx, + ); + } + } + info!("Allowed devices of user {} synced", user.username); + Ok(()) +} diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index 5a0de67b32..621c64d120 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -1,6 +1,9 @@ use std::{collections::HashSet, time::Duration}; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{WireguardNetwork, wireguard::ServiceLocationMode}, +}; use sqlx::{PgPool, query_as}; use tokio::{ sync::broadcast::Sender, @@ -9,14 +12,16 @@ use tokio::{ use tracing::Instrument; use crate::{ - db::{GatewayEvent, WireguardNetwork, models::wireguard::ServiceLocationMode}, enterprise::{ db::models::acl::{AclRule, RuleState}, directory_sync::{do_directory_sync, get_directory_sync_interval}, + firewall::try_get_location_firewall_config, is_enterprise_enabled, ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, limits::do_count_update, }, + grpc::gateway::events::GatewayEvent, + location_management::allowed_peers::get_location_allowed_peers, updates::do_new_version_check, }; @@ -175,14 +180,14 @@ async fn enterprise_status_check( let mut transaction = pool.begin().await?; for location in locations { debug!("Re-enabling gateway firewall configuration for location {location:?}"); - let firewall_config = location - .try_get_firewall_config(&mut transaction) + let firewall_config = try_get_location_firewall_config(&location, &mut transaction) .await? .expect("ACL-enabled location must have firewall config"); // Handle service location update or just update the firewall if location.service_location_mode != ServiceLocationMode::Disabled { - let new_peers = location.get_peers(&mut *transaction).await?; + let new_peers = + get_location_allowed_peers(&location, &mut *transaction).await?; wireguard_tx.send(GatewayEvent::NetworkModified( location.id, location, @@ -266,7 +271,7 @@ async fn expired_acl_rules_check( let mut conn = pool.acquire().await?; for location in affected_locations { - match location.try_get_firewall_config(&mut conn).await? { + match try_get_location_firewall_config(&location, &mut conn).await? { Some(firewall_config) => { debug!("Sending firewall update event for location {location}"); wireguard_tx.send(GatewayEvent::FirewallConfigChanged( diff --git a/crates/defguard_core/src/wg_config.rs b/crates/defguard_core/src/wg_config.rs index 38d40c6f0c..cde33aa48b 100644 --- a/crates/defguard_core/src/wg_config.rs +++ b/crates/defguard_core/src/wg_config.rs @@ -5,11 +5,11 @@ use ipnetwork::{IpNetwork, IpNetworkError}; use thiserror::Error; use x25519_dalek::{PublicKey, StaticSecret}; -use crate::{ +use defguard_common::{ KEY_LENGTH, - db::{ + db::models::{ Device, WireguardNetwork, - models::wireguard::{ + wireguard::{ DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, LocationMfaMode, ServiceLocationMode, WireguardNetworkError, }, diff --git a/crates/defguard_core/src/wireguard_peer_disconnect.rs b/crates/defguard_core/src/wireguard_peer_disconnect.rs index 12a5d02acb..94a4102d2c 100644 --- a/crates/defguard_core/src/wireguard_peer_disconnect.rs +++ b/crates/defguard_core/src/wireguard_peer_disconnect.rs @@ -11,7 +11,14 @@ use std::{ }; use chrono::NaiveDateTime; -use defguard_common::db::{Id, models::ModelError}; +use defguard_common::db::{ + Id, + models::{ + Device, DeviceNetworkInfo, DeviceType, ModelError, WireguardNetwork, WireguardNetworkError, + device::{DeviceInfo, WireguardNetworkDevice}, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, +}; use sqlx::{Error as SqlxError, PgPool, query_as}; use thiserror::Error; use tokio::{ @@ -23,14 +30,8 @@ use tokio::{ }; use crate::{ - db::{ - Device, GatewayEvent, WireguardNetwork, - models::{ - device::{DeviceInfo, DeviceNetworkInfo, DeviceType, WireguardNetworkDevice}, - wireguard::{LocationMfaMode, ServiceLocationMode, WireguardNetworkError}, - }, - }, events::{InternalEvent, InternalEventContext}, + grpc::gateway::events::GatewayEvent, }; // How long to sleep between loop iterations diff --git a/crates/defguard_core/src/wireguard_stats_purge.rs b/crates/defguard_core/src/wireguard_stats_purge.rs index 25932685f5..ff36e1da16 100644 --- a/crates/defguard_core/src/wireguard_stats_purge.rs +++ b/crates/defguard_core/src/wireguard_stats_purge.rs @@ -1,12 +1,11 @@ use std::time::Duration; use chrono::{TimeDelta, Utc}; +use defguard_common::db::models::wireguard_peer_stats::WireguardPeerStats; use humantime::format_duration; use sqlx::PgPool; use tokio::time::sleep; -use crate::db::models::wireguard_peer_stats::WireguardPeerStats; - // How long to sleep between loop iterations const PURGE_LOOP_SLEEP: Duration = Duration::from_secs(300); // 5 minutes diff --git a/crates/defguard_core/tests/integration/api/acl.rs b/crates/defguard_core/tests/integration/api/acl.rs index 1272d5d8d9..469aa4a756 100644 --- a/crates/defguard_core/tests/integration/api/acl.rs +++ b/crates/defguard_core/tests/integration/api/acl.rs @@ -1,15 +1,16 @@ use defguard_common::{ config::DefGuardConfig, - db::{Id, models::settings::initialize_current_settings}, -}; -use defguard_core::{ db::{ - Device, Group, User, WireguardNetwork, + Id, models::{ - device::DeviceType, + Device, DeviceType, User, WireguardNetwork, + group::Group, + settings::initialize_current_settings, wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, +}; +use defguard_core::{ enterprise::{ db::models::acl::{AclAlias, AclRule, AliasKind, AliasState, RuleState}, handlers::acl::{ApiAclAlias, ApiAclRule, EditAclAlias, EditAclRule}, diff --git a/crates/defguard_core/tests/integration/api/api_tokens.rs b/crates/defguard_core/tests/integration/api/api_tokens.rs index a692c156c2..81c563251b 100644 --- a/crates/defguard_core/tests/integration/api/api_tokens.rs +++ b/crates/defguard_core/tests/integration/api/api_tokens.rs @@ -1,6 +1,9 @@ use chrono::Utc; +use defguard_common::{ + db::models::group::{Group, Permission}, + types::user_info::UserInfo, +}; use defguard_core::{ - db::{Group, UserInfo, models::group::Permission}, enterprise::{ db::models::api_tokens::{ApiToken, ApiTokenInfo}, handlers::api_tokens::{AddApiTokenData, RenameRequest}, diff --git a/crates/defguard_core/tests/integration/api/auth.rs b/crates/defguard_core/tests/integration/api/auth.rs index 60d3c29e2b..08b32b55cf 100644 --- a/crates/defguard_core/tests/integration/api/auth.rs +++ b/crates/defguard_core/tests/integration/api/auth.rs @@ -2,12 +2,14 @@ use std::time::SystemTime; use chrono::DateTime; use claims::{assert_err, assert_ok}; -use defguard_common::db::models::{MFAMethod, Settings, settings::update_current_settings}; +use defguard_common::db::models::{ + MFAInfo, MFAMethod, Settings, User, + settings::update_current_settings, + user::{TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD}, +}; use defguard_core::{ - auth::{TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD}, - db::{MFAInfo, User, UserDetails}, events::ApiEventType, - handlers::{Auth, AuthCode, AuthResponse, AuthTotp}, + handlers::{Auth, AuthCode, AuthResponse, AuthTotp, user::UserDetails}, }; use reqwest::{StatusCode, header::USER_AGENT}; use serde::Deserialize; diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 1c4e222445..96c3b86652 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -9,16 +9,22 @@ pub use defguard_common::db::setup_pool; use defguard_common::{ VERSION, config::DefGuardConfig, - db::{Id, NoId, models::settings::initialize_current_settings}, + db::{ + Id, NoId, + models::{Device, User, WireguardNetwork, settings::initialize_current_settings}, + }, }; use defguard_core::{ auth::failed_login::FailedLoginMap, build_webapp, - db::{AppEvent, Device, GatewayEvent, User, UserDetails, WireguardNetwork}, + db::AppEvent, enterprise::license::{License, set_cached_license}, events::ApiEvent, - grpc::{WorkerState, gateway::map::GatewayMap}, - handlers::Auth, + grpc::{ + WorkerState, + gateway::{events::GatewayEvent, map::GatewayMap}, + }, + handlers::{Auth, user::UserDetails}, }; use defguard_mail::Mail; use reqwest::{StatusCode, header::HeaderName}; diff --git a/crates/defguard_core/tests/integration/api/enrollment.rs b/crates/defguard_core/tests/integration/api/enrollment.rs index 5ca0c8eea5..1997d38135 100644 --- a/crates/defguard_core/tests/integration/api/enrollment.rs +++ b/crates/defguard_core/tests/integration/api/enrollment.rs @@ -1,6 +1,7 @@ use chrono::Duration; +use defguard_common::db::models::User; use defguard_core::{ - db::{User, models::enrollment::Token}, + db::models::enrollment::Token, handlers::{AddUserData, Auth}, }; use reqwest::StatusCode; diff --git a/crates/defguard_core/tests/integration/api/oauth.rs b/crates/defguard_core/tests/integration/api/oauth.rs index 4e7329734f..7705d96b26 100644 --- a/crates/defguard_core/tests/integration/api/oauth.rs +++ b/crates/defguard_core/tests/integration/api/oauth.rs @@ -1,16 +1,13 @@ use std::borrow::Cow; -use defguard_common::db::Id; -use defguard_core::{ - db::{ +use defguard_common::db::{ + Id, + models::{ OAuth2AuthorizedApp, - models::{ - NewOpenIDClient, - oauth2client::{OAuth2Client, OAuth2ClientSafe}, - }, + oauth2client::{OAuth2Client, OAuth2ClientSafe}, }, - handlers::Auth, }; +use defguard_core::handlers::{Auth, openid_clients::NewOpenIDClient}; use reqwest::{StatusCode, Url, header::CONTENT_TYPE}; use serde_json::json; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_core/tests/integration/api/openid.rs b/crates/defguard_core/tests/integration/api/openid.rs index a53a0f0c84..ceb643ebe3 100644 --- a/crates/defguard_core/tests/integration/api/openid.rs +++ b/crates/defguard_core/tests/integration/api/openid.rs @@ -2,14 +2,11 @@ use std::str::FromStr; use axum::http::header::ToStrError; use claims::assert_err; -use defguard_common::db::Id; -use defguard_core::{ - db::{ - User, - models::{NewOpenIDClient, oauth2client::OAuth2Client}, - }, - handlers::Auth, +use defguard_common::db::{ + Id, + models::{OAuth2AuthorizedApp, User, oauth2client::OAuth2Client}, }; +use defguard_core::handlers::{Auth, openid_clients::NewOpenIDClient}; use openidconnect::{ AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, HttpRequest, HttpResponse, IssuerUrl, Nonce, OAuth2TokenResponse, @@ -977,7 +974,6 @@ async fn dg25_23_test_openid_client_scope_change_clears_authorizations( assert_eq!(response.status(), StatusCode::FOUND); // Verify that the authorization was created - use defguard_core::db::OAuth2AuthorizedApp; let authorized_app = OAuth2AuthorizedApp::find_by_user_and_oauth2client_id( &state.pool, admin.id, diff --git a/crates/defguard_core/tests/integration/api/openid_login.rs b/crates/defguard_core/tests/integration/api/openid_login.rs index 923633fe1a..8ef0ce1b42 100644 --- a/crates/defguard_core/tests/integration/api/openid_login.rs +++ b/crates/defguard_core/tests/integration/api/openid_login.rs @@ -1,13 +1,15 @@ use chrono::{Duration, Utc}; -use defguard_common::db::{Id, models::settings::OpenidUsernameHandling}; +use defguard_common::db::{ + Id, + models::{oauth2client::OAuth2Client, settings::OpenidUsernameHandling}, +}; use defguard_core::{ - db::models::{NewOpenIDClient, oauth2client::OAuth2Client}, enterprise::{ db::models::openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior}, handlers::openid_providers::AddProviderData, license::{License, set_cached_license}, }, - handlers::Auth, + handlers::{Auth, openid_clients::NewOpenIDClient}, }; use reqwest::{StatusCode, Url}; use serde::Deserialize; diff --git a/crates/defguard_core/tests/integration/api/user.rs b/crates/defguard_core/tests/integration/api/user.rs index 1028c6093e..dfcf4fdb15 100644 --- a/crates/defguard_core/tests/integration/api/user.rs +++ b/crates/defguard_core/tests/integration/api/user.rs @@ -1,11 +1,16 @@ -use defguard_common::db::Id; -use defguard_core::{ +use defguard_common::{ db::{ - AddDevice, UserInfo, - models::{NewOpenIDClient, oauth2client::OAuth2Client}, + Id, + models::{device::AddDevice, oauth2client::OAuth2Client}, }, + types::user_info::UserInfo, +}; +use defguard_core::{ events::ApiEventType, - handlers::{AddUserData, Auth, PasswordChange, PasswordChangeSelf, Username}, + handlers::{ + AddUserData, Auth, PasswordChange, PasswordChangeSelf, Username, + openid_clients::NewOpenIDClient, + }, }; use reqwest::{StatusCode, header::USER_AGENT}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index 36c5ac4e2b..edac444096 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -1,22 +1,24 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use defguard_common::db::{Id, models::settings::OpenidUsernameHandling}; -use defguard_core::{ - db::{ - Device, GatewayEvent, WireguardNetwork, - models::{ - device::WireguardNetworkDevice, - wireguard::{ - DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, LocationMfaMode, - ServiceLocationMode, - }, +use defguard_common::db::{ + Id, + models::{ + Device, WireguardNetwork, + device::WireguardNetworkDevice, + settings::OpenidUsernameHandling, + wireguard::{ + DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, LocationMfaMode, + ServiceLocationMode, }, }, +}; +use defguard_core::{ enterprise::{ db::models::openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior}, handlers::openid_providers::AddProviderData, license::{get_cached_license, set_cached_license}, }, + grpc::gateway::events::GatewayEvent, handlers::{Auth, GroupInfo, wireguard::WireguardNetworkData}, }; use ipnetwork::IpNetwork; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs index d641de589e..275bef35b9 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs @@ -1,10 +1,17 @@ use std::net::IpAddr; use claims::assert_err; -use defguard_common::{csv::AsCsv, db::Id}; +use defguard_common::{ + csv::AsCsv, + db::{ + Id, + models::{Device, DeviceType, User, WireguardNetwork, group::Group}, + }, +}; use defguard_core::{ - db::{Device, GatewayEvent, Group, User, WireguardNetwork, models::device::DeviceType}, + grpc::gateway::events::GatewayEvent, handlers::{Auth, wireguard::ImportedNetworkData}, + location_management::allowed_peers::get_location_allowed_peers, }; use matches::assert_matches; use reqwest::StatusCode; @@ -164,7 +171,9 @@ async fn test_create_new_network(_: PgPoolOptions, options: PgConnectOptions) { assert_err!(wg_rx.try_recv()); // network configuration was created only for admin and allowed user - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 2); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); @@ -210,7 +219,9 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_matches!(event, GatewayEvent::NetworkCreated(..)); // network configuration was created for all devices - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 4); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); @@ -240,7 +251,9 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::OK); assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); - let new_peers = network.get_peers(&client_state.pool).await.unwrap(); + let new_peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(new_peers.len(), 2); assert_eq!(new_peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(new_peers[1].pubkey, devices[1].wireguard_pubkey); @@ -268,7 +281,9 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::OK); assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); - let new_peers = network.get_peers(&client_state.pool).await.unwrap(); + let new_peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(new_peers.len(), 3); assert_eq!(new_peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(new_peers[1].pubkey, devices[1].wireguard_pubkey); @@ -297,7 +312,9 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::OK); assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); - let new_peers = network.get_peers(&client_state.pool).await.unwrap(); + let new_peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(new_peers.len(), 2); assert_eq!(new_peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(new_peers[1].pubkey, devices[2].wireguard_pubkey); @@ -325,7 +342,9 @@ async fn test_modify_network(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::OK); assert_matches!(wg_rx.try_recv().unwrap(), GatewayEvent::NetworkModified(..)); - let new_peers = network.get_peers(&client_state.pool).await.unwrap(); + let new_peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(new_peers.len(), 4); assert_eq!(new_peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(new_peers[1].pubkey, devices[1].wireguard_pubkey); @@ -395,7 +414,9 @@ async fn test_import_network_existing_devices(_: PgPoolOptions, options: PgConne ); let network = response.network; - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 2); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); @@ -499,7 +520,9 @@ PersistentKeepalive = 300 .await; assert_eq!(response.status(), StatusCode::CREATED); - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 4); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); @@ -580,7 +603,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_err!(wg_rx.try_recv()); // network configuration was created only for admin and allowed user - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 2); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); @@ -599,7 +624,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_matches!(event, GatewayEvent::DeviceDeleted(..)); assert_err!(wg_rx.try_recv()); - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 1); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); @@ -615,7 +642,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_err!(wg_rx.try_recv()); - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 1); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); @@ -633,7 +662,9 @@ async fn test_modify_user(_: PgPoolOptions, options: PgConnectOptions) { assert_matches!(event, GatewayEvent::DeviceCreated(..)); assert_err!(wg_rx.try_recv()); - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 2); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[3].wireguard_pubkey); @@ -678,7 +709,9 @@ async fn test_delete_only_allowed_group(_: PgPoolOptions, options: PgConnectOpti let event = wg_rx.try_recv().unwrap(); assert_matches!(event, GatewayEvent::NetworkCreated(..)); - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 2); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); @@ -688,7 +721,9 @@ async fn test_delete_only_allowed_group(_: PgPoolOptions, options: PgConnectOpti assert_eq!(response.status(), StatusCode::OK); // network configuration was created for all devices - let peers = network.get_peers(&client_state.pool).await.unwrap(); + let peers = get_location_allowed_peers(&network, &client_state.pool) + .await + .unwrap(); assert_eq!(peers.len(), 4); assert_eq!(peers[0].pubkey, devices[0].wireguard_pubkey); assert_eq!(peers[1].pubkey, devices[1].wireguard_pubkey); diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs index 225df95bb2..8cfb4995de 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs @@ -1,8 +1,11 @@ use std::{net::IpAddr, str::FromStr}; -use defguard_common::db::Id; +use defguard_common::db::{ + Id, + models::{Device, WireguardNetwork}, +}; use defguard_core::{ - db::{Device, GatewayEvent, WireguardNetwork}, + grpc::gateway::events::GatewayEvent, handlers::{Auth, network_devices::AddNetworkDevice}, }; use ipnetwork::IpNetwork; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index 9db3ce63eb..e6056d5ab8 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -1,16 +1,15 @@ use std::net::IpAddr; -use defguard_core::{ - db::{ - Device, GatewayEvent, WireguardNetwork, - models::{ - device::{DeviceType, UserDevice}, - wireguard::{ - DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, LocationMfaMode, - ServiceLocationMode, - }, - }, +use defguard_common::db::models::{ + Device, DeviceType, WireguardNetwork, + device::UserDevice, + wireguard::{ + DEFAULT_DISCONNECT_THRESHOLD, DEFAULT_KEEPALIVE_INTERVAL, LocationMfaMode, + ServiceLocationMode, }, +}; +use defguard_core::{ + grpc::gateway::events::GatewayEvent, handlers::{Auth, wireguard::ImportedNetworkData}, }; use matches::assert_matches; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_stats.rs b/crates/defguard_core/tests/integration/api/wireguard_network_stats.rs index 5a03e8169b..21718da443 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_stats.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_stats.rs @@ -1,16 +1,16 @@ use chrono::{Datelike, Duration, NaiveDate, SubsecRound, Timelike, Utc}; -use defguard_common::db::{Id, NoId}; -use defguard_core::{ - db::models::{ - device::Device, +use defguard_common::db::{ + Id, NoId, + models::{ + Device, wireguard::{ WireguardDeviceStatsRow, WireguardDeviceTransferRow, WireguardNetworkStats, WireguardUserStatsRow, }, wireguard_peer_stats::WireguardPeerStats, }, - handlers::Auth, }; +use defguard_core::handlers::Auth; use reqwest::StatusCode; use serde::Deserialize; use serde_json::json; diff --git a/crates/defguard_core/tests/integration/common.rs b/crates/defguard_core/tests/integration/common.rs index a5ef802aa2..6370300b1d 100644 --- a/crates/defguard_core/tests/integration/common.rs +++ b/crates/defguard_core/tests/integration/common.rs @@ -1,5 +1,7 @@ -use defguard_common::config::{DefGuardConfig, SERVER_CONFIG}; -use defguard_core::db::User; +use defguard_common::{ + config::{DefGuardConfig, SERVER_CONFIG}, + db::models::User, +}; use reqwest::Url; use secrecy::ExposeSecret; use sqlx::PgPool; diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 96609dbfa7..7b2e43a17c 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -4,12 +4,12 @@ use axum::http::Uri; use defguard_common::db::models::settings::initialize_current_settings; use defguard_core::{ auth::failed_login::FailedLoginMap, - db::{AppEvent, GatewayEvent}, + db::AppEvent, enterprise::license::{License, set_cached_license}, events::GrpcEvent, grpc::{ WorkerState, build_grpc_service_router, - gateway::{client_state::ClientMap, map::GatewayMap}, + gateway::{client_state::ClientMap, events::GatewayEvent, map::GatewayMap}, }, }; use defguard_mail::Mail; diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index d27fca1e72..841cbcdf5b 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -5,19 +5,19 @@ use std::{ use chrono::{Days, Utc}; use claims::{assert_err_eq, assert_matches}; -use defguard_common::db::{Id, NoId, setup_pool}; -use defguard_core::{ - db::{ - Device, User, WireguardNetwork, - models::{ - device::DeviceType, - wireguard::{LocationMfaMode, ServiceLocationMode}, - wireguard_peer_stats::WireguardPeerStats, - }, +use defguard_common::db::{ + Id, NoId, + models::{ + Device, DeviceType, User, WireguardNetwork, + wireguard::{LocationMfaMode, ServiceLocationMode}, + wireguard_peer_stats::WireguardPeerStats, }, + setup_pool, +}; +use defguard_core::{ enterprise::{license::set_cached_license, limits::update_counts}, events::GrpcEvent, - grpc::MIN_GATEWAY_VERSION, + grpc::{MIN_GATEWAY_VERSION, gateway::events::GatewayEvent}, }; use defguard_proto::{ enterprise::firewall::FirewallPolicy, @@ -428,7 +428,7 @@ async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions gateway_2.connect_to_updates_stream().await; // send update for location 1 - test_server.send_wireguard_event(defguard_core::db::GatewayEvent::NetworkDeleted( + test_server.send_wireguard_event(GatewayEvent::NetworkDeleted( test_location.id, "network name".into(), )); @@ -450,7 +450,7 @@ async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions assert_eq!(update, expected_update); // send update for location 2 - test_server.send_wireguard_event(defguard_core::db::GatewayEvent::NetworkDeleted( + test_server.send_wireguard_event(GatewayEvent::NetworkDeleted( test_location_2.id, "network name 2".into(), )); @@ -472,10 +472,7 @@ async fn test_gateway_update_routing(_: PgPoolOptions, options: PgConnectOptions assert_eq!(update, expected_update); // send update for location which does not exist - test_server.send_wireguard_event(defguard_core::db::GatewayEvent::NetworkDeleted( - 1234, - "does not exist".into(), - )); + test_server.send_wireguard_event(GatewayEvent::NetworkDeleted(1234, "does not exist".into())); // no gateway should receive this update assert!(gateway_1.receive_next_update().await.is_none()); diff --git a/crates/defguard_event_logger/src/message.rs b/crates/defguard_event_logger/src/message.rs index 0dc0eebcda..1befe676c0 100644 --- a/crates/defguard_event_logger/src/message.rs +++ b/crates/defguard_event_logger/src/message.rs @@ -3,13 +3,13 @@ use std::net::IpAddr; use chrono::NaiveDateTime; use defguard_common::db::{ Id, - models::{AuthenticationKey, MFAMethod, Settings}, + models::{ + AuthenticationKey, Device, MFAMethod, Settings, User, WebAuthn, WireguardNetwork, + group::Group, oauth2client::OAuth2Client, + }, }; use defguard_core::{ - db::{ - Device, Group, User, WebAuthn, WebHook, WireguardNetwork, - models::oauth2client::OAuth2Client, - }, + db::WebHook, enterprise::db::models::{ activity_log_stream::ActivityLogStream, api_tokens::ApiToken, openid_provider::OpenIdProvider, snat::UserSnatBinding, diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index b9633f9f07..636d41de63 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use defguard_core::{ - db::GatewayEvent, events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, + grpc::gateway::events::GatewayEvent, }; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use defguard_mail::Mail; diff --git a/crates/defguard_mail/src/templates.rs b/crates/defguard_mail/src/templates.rs index b209d8db13..9e8af09b8c 100644 --- a/crates/defguard_mail/src/templates.rs +++ b/crates/defguard_mail/src/templates.rs @@ -1,7 +1,17 @@ use std::collections::HashMap; use chrono::{Datelike, NaiveDateTime, Utc}; -use defguard_common::{VERSION, config::server_config, db::models::user::MFAMethod}; +use defguard_common::{ + VERSION, + config::server_config, + db::{ + Id, + models::{ + Session, + user::{MFAMethod, User}, + }, + }, +}; use reqwest::Url; use serde::Serialize; use serde_json::Value; @@ -67,11 +77,29 @@ pub struct SessionContext { pub device_info: Option, } +impl From for SessionContext { + fn from(value: Session) -> Self { + Self { + ip_address: value.ip_address, + device_info: value.device_info, + } + } +} + pub struct UserContext { pub last_name: String, pub first_name: String, } +impl From> for UserContext { + fn from(value: User) -> Self { + Self { + last_name: value.last_name, + first_name: value.first_name, + } + } +} + fn get_base_tera( external_context: Option, session: Option<&SessionContext>, diff --git a/crates/defguard_proto/Cargo.toml b/crates/defguard_proto/Cargo.toml index 5a749a3eb1..3400b5ed92 100644 --- a/crates/defguard_proto/Cargo.toml +++ b/crates/defguard_proto/Cargo.toml @@ -8,6 +8,7 @@ repository.workspace = true rust-version.workspace = true [dependencies] +defguard_common.workspace = true prost.workspace = true serde.workspace = true tonic.workspace = true diff --git a/crates/defguard_proto/src/lib.rs b/crates/defguard_proto/src/lib.rs index 0313436669..df37922bb0 100644 --- a/crates/defguard_proto/src/lib.rs +++ b/crates/defguard_proto/src/lib.rs @@ -18,6 +18,16 @@ pub mod enterprise { } } +use defguard_common::{ + csv::AsCsv, + db::{ + Id, + models::{ + Device, DeviceConfig, User, + wireguard::{LocationMfaMode, ServiceLocationMode}, + }, + }, +}; use proxy::{CoreError, MfaMethod}; use serde::Serialize; use tonic::Status; @@ -64,3 +74,75 @@ impl From for CoreError { } } } + +impl From for proxy::DeviceConfig { + fn from(config: DeviceConfig) -> Self { + // DEPRECATED(1.5): superseeded by location_mfa_mode + let mfa_enabled = config.location_mfa_mode == LocationMfaMode::Internal; + Self { + network_id: config.network_id, + network_name: config.network_name, + config: config.config, + endpoint: config.endpoint, + assigned_ip: config.address.as_csv(), + pubkey: config.pubkey, + allowed_ips: config.allowed_ips.as_csv(), + dns: config.dns, + keepalive_interval: config.keepalive_interval, + #[allow(deprecated)] + mfa_enabled, + location_mfa_mode: Some( + >::into(config.location_mfa_mode) + .into(), + ), + service_location_mode: Some( + >::into( + config.service_location_mode, + ) + .into(), + ), + } + } +} + +impl From> for proxy::Device { + fn from(device: Device) -> Self { + Self { + id: device.id, + name: device.name, + pubkey: device.wireguard_pubkey, + user_id: device.user_id, + created_at: device.created.and_utc().timestamp(), + } + } +} + +impl From> for proxy::AdminInfo { + fn from(admin: User) -> Self { + Self { + name: format!("{} {}", admin.first_name, admin.last_name), + phone_number: admin.phone, + email: admin.email, + } + } +} + +impl From for proxy::LocationMfaMode { + fn from(value: LocationMfaMode) -> Self { + match value { + LocationMfaMode::Disabled => proxy::LocationMfaMode::Disabled, + LocationMfaMode::Internal => proxy::LocationMfaMode::Internal, + LocationMfaMode::External => proxy::LocationMfaMode::External, + } + } +} + +impl From for proxy::ServiceLocationMode { + fn from(value: ServiceLocationMode) -> Self { + match value { + ServiceLocationMode::Disabled => proxy::ServiceLocationMode::Disabled, + ServiceLocationMode::PreLogon => proxy::ServiceLocationMode::Prelogon, + ServiceLocationMode::AlwaysOn => proxy::ServiceLocationMode::Alwayson, + } + } +} diff --git a/crates/defguard_session_manager/Cargo.toml b/crates/defguard_session_manager/Cargo.toml new file mode 100644 index 0000000000..be683e13f0 --- /dev/null +++ b/crates/defguard_session_manager/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "defguard_session_manager" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common.workspace = true +sqlx.workspace = true +tokio.workspace = true + diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs new file mode 100644 index 0000000000..beb3865559 --- /dev/null +++ b/crates/defguard_session_manager/src/lib.rs @@ -0,0 +1,10 @@ +use defguard_common::db::models::wireguard_peer_stats::WireguardPeerStats; +use sqlx::PgPool; +use tokio::sync::mpsc::UnboundedReceiver; + +pub async fn run_session_manager( + _pool: PgPool, + _peer_stats_rx: UnboundedReceiver, +) { + unimplemented!() +} diff --git a/deny.toml b/deny.toml index 4bbfa288c1..18f5a9b7fa 100644 --- a/deny.toml +++ b/deny.toml @@ -133,6 +133,9 @@ exceptions = [ { allow = [ "AGPL-3.0-only", ], crate = "defguard_event_logger" }, + { allow = [ + "AGPL-3.0-only", + ], crate = "defguard_session_manager" }, { allow = [ "AGPL-3.0-only", ], crate = "defguard_version" }, diff --git a/web/package.json b/web/package.json index d3d9ddbd62..2a4e8a48b4 100644 --- a/web/package.json +++ b/web/package.json @@ -50,8 +50,8 @@ "@react-rxjs/core": "^0.10.8", "@stablelib/base64": "^2.0.1", "@stablelib/x25519": "^2.0.1", - "@tanstack/query-core": "^5.90.9", - "@tanstack/react-query": "^5.90.9", + "@tanstack/query-core": "^5.90.10", + "@tanstack/react-query": "^5.90.10", "@tanstack/react-virtual": "3.13.12", "@tanstack/virtual-core": "3.13.12", "@use-gesture/react": "^10.3.1", @@ -70,7 +70,7 @@ "fuse.js": "^7.1.0", "get-text-width": "^1.0.3", "hex-rgb": "^5.0.0", - "html-react-parser": "^5.2.8", + "html-react-parser": "^5.2.10", "humanize-duration": "^3.33.1", "ipaddr.js": "^2.2.0", "itertools": "^2.5.0", @@ -124,7 +124,7 @@ "@types/lodash-es": "^4.17.12", "@types/node": "^24.10.1", "@types/qs": "^6.14.0", - "@types/react": "^19.2.4", + "@types/react": "^19.2.5", "@types/react-dom": "^19.2.3", "@types/react-router-dom": "^5.3.3", "@vitejs/plugin-react-swc": "^4.2.2", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index af6ae57457..21e6759ba9 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -30,11 +30,11 @@ importers: specifier: ^2.0.1 version: 2.0.1 '@tanstack/query-core': - specifier: ^5.90.9 - version: 5.90.9 + specifier: ^5.90.10 + version: 5.90.10 '@tanstack/react-query': - specifier: ^5.90.9 - version: 5.90.9(react@19.2.0) + specifier: ^5.90.10 + version: 5.90.10(react@19.2.0) '@tanstack/react-virtual': specifier: 3.13.12 version: 3.13.12(react-dom@19.2.0(react@19.2.0))(react@19.2.0) @@ -90,8 +90,8 @@ importers: specifier: ^5.0.0 version: 5.0.0 html-react-parser: - specifier: ^5.2.8 - version: 5.2.8(@types/react@19.2.4)(react@19.2.0) + specifier: ^5.2.10 + version: 5.2.10(@types/react@19.2.5)(react@19.2.0) humanize-duration: specifier: ^3.33.1 version: 3.33.1 @@ -109,7 +109,7 @@ importers: version: 4.17.21 merge-refs: specifier: ^2.0.0 - version: 2.0.0(@types/react@19.2.4) + version: 2.0.0(@types/react@19.2.5) millify: specifier: ^6.1.0 version: 6.1.0 @@ -157,7 +157,7 @@ importers: version: 3.5.0(react@19.2.0) react-markdown: specifier: ^10.1.0 - version: 10.1.0(@types/react@19.2.4)(react@19.2.0) + version: 10.1.0(@types/react@19.2.5)(react@19.2.0) react-qr-code: specifier: ^2.0.18 version: 2.0.18(react@19.2.0) @@ -178,7 +178,7 @@ importers: version: 1.0.26(react-dom@19.2.0(react@19.2.0))(react@19.2.0) recharts: specifier: ^3.4.1 - version: 3.4.1(@types/react@19.2.4)(react-dom@19.2.0(react@19.2.0))(react-is@19.2.0)(react@19.2.0)(redux@5.0.1) + version: 3.4.1(@types/react@19.2.5)(react-dom@19.2.0(react@19.2.0))(react-is@19.2.0)(react@19.2.0)(redux@5.0.1) rehype-external-links: specifier: ^3.0.0 version: 3.0.0 @@ -208,7 +208,7 @@ importers: version: 3.25.76 zustand: specifier: ^5.0.8 - version: 5.0.8(@types/react@19.2.4)(immer@10.2.0)(react@19.2.0)(use-sync-external-store@1.6.0(react@19.2.0)) + version: 5.0.8(@types/react@19.2.5)(immer@10.2.0)(react@19.2.0)(use-sync-external-store@1.6.0(react@19.2.0)) devDependencies: '@babel/core': specifier: ^7.28.5 @@ -224,10 +224,10 @@ importers: version: 3.0.4 '@hookform/devtools': specifier: ^4.4.0 - version: 4.4.0(@types/react@19.2.4)(react-dom@19.2.0(react@19.2.0))(react@19.2.0) + version: 4.4.0(@types/react@19.2.5)(react-dom@19.2.0(react@19.2.0))(react@19.2.0) '@tanstack/react-query-devtools': specifier: ^5.90.2 - version: 5.90.2(@tanstack/react-query@5.90.9(react@19.2.0))(react@19.2.0) + version: 5.90.2(@tanstack/react-query@5.90.10(react@19.2.0))(react@19.2.0) '@types/byte-size': specifier: ^8.1.2 version: 8.1.2 @@ -247,11 +247,11 @@ importers: specifier: ^6.14.0 version: 6.14.0 '@types/react': - specifier: ^19.2.4 - version: 19.2.4 + specifier: ^19.2.5 + version: 19.2.5 '@types/react-dom': specifier: ^19.2.3 - version: 19.2.3(@types/react@19.2.4) + version: 19.2.3(@types/react@19.2.5) '@types/react-router-dom': specifier: ^5.3.3 version: 5.3.3 @@ -963,8 +963,8 @@ packages: '@swc/types@0.1.25': resolution: {integrity: sha512-iAoY/qRhNH8a/hBvm3zKj9qQ4oc2+3w1unPJa2XvTK3XjeLXtzcCingVPw/9e5mn1+0yPqxcBGp9Jf0pkfMb1g==} - '@tanstack/query-core@5.90.9': - resolution: {integrity: sha512-UFOCQzi6pRGeVTVlPNwNdnAvT35zugcIydqjvFUzG62dvz2iVjElmNp/hJkUoM5eqbUPfSU/GJIr/wbvD8bTUw==} + '@tanstack/query-core@5.90.10': + resolution: {integrity: sha512-EhZVFu9rl7GfRNuJLJ3Y7wtbTnENsvzp+YpcAV7kCYiXni1v8qZh++lpw4ch4rrwC0u/EZRnBHIehzCGzwXDSQ==} '@tanstack/query-devtools@5.90.1': resolution: {integrity: sha512-GtINOPjPUH0OegJExZ70UahT9ykmAhmtNVcmtdnOZbxLwT7R5OmRztR5Ahe3/Cu7LArEmR6/588tAycuaWb1xQ==} @@ -975,8 +975,8 @@ packages: '@tanstack/react-query': ^5.90.2 react: ^18 || ^19 - '@tanstack/react-query@5.90.9': - resolution: {integrity: sha512-Zke2AaXiaSfnG8jqPZR52m8SsclKT2d9//AgE/QIzyNvbpj/Q2ln+FsZjb1j69bJZUouBvX2tg9PHirkTm8arw==} + '@tanstack/react-query@5.90.10': + resolution: {integrity: sha512-BKLss9Y8PQ9IUjPYQiv3/Zmlx92uxffUOX8ZZNoQlCIZBJPT5M+GOMQj7xislvVQ6l1BstBjcX0XB/aHfFYVNw==} peerDependencies: react: ^18 || ^19 @@ -1078,8 +1078,8 @@ packages: '@types/react-router@5.1.20': resolution: {integrity: sha512-jGjmu/ZqS7FjSH6owMcD5qpq19+1RS9DeVRqfl1FeBMxTDQAGwlMWOcs52NDoXaNKyG3d1cYQFMs9rCrb88o9Q==} - '@types/react@19.2.4': - resolution: {integrity: sha512-tBFxBp9Nfyy5rsmefN+WXc1JeW/j2BpBHFdLZbEVfs9wn3E3NRFxwV0pJg8M1qQAexFpvz73hJXFofV0ZAu92A==} + '@types/react@19.2.5': + resolution: {integrity: sha512-keKxkZMqnDicuvFoJbzrhbtdLSPhj/rZThDlKWCDbgXmUg0rEUFtRssDXKYmtXluZlIqiC5VqkCgRwzuyLHKHw==} '@types/unist@2.0.11': resolution: {integrity: sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==} @@ -1165,8 +1165,8 @@ packages: balanced-match@1.0.2: resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} - baseline-browser-mapping@2.8.28: - resolution: {integrity: sha512-gYjt7OIqdM0PcttNYP2aVrr2G0bMALkBaoehD4BuRGjAOtipg0b6wHg1yNL+s5zSnLZZrGHOw4IrND8CD+3oIQ==} + baseline-browser-mapping@2.8.29: + resolution: {integrity: sha512-sXdt2elaVnhpDNRDz+1BDx1JQoJRuNk7oVlAlbGiFkLikHCAQiccexF/9e91zVi6RCgqspl04aP+6Cnl9zRLrA==} hasBin: true bignumber.js@9.3.1: @@ -1220,8 +1220,8 @@ packages: resolution: {integrity: sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==} engines: {node: '>=6'} - caniuse-lite@1.0.30001754: - resolution: {integrity: sha512-x6OeBXueoAceOmotzx3PO4Zpt4rzpeIFsSr6AAePTZxSkXiYDUmpypEl7e2+8NCd9bD7bXjqyef8CJYPC1jfxg==} + caniuse-lite@1.0.30001755: + resolution: {integrity: sha512-44V+Jm6ctPj7R52Na4TLi3Zri4dWUljJd+RDm+j8LtNCc/ihLCT+X1TzoOAkRETEWqjuLnh9581Tl80FvK7jVA==} ccount@2.0.1: resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} @@ -1388,8 +1388,8 @@ packages: resolution: {integrity: sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==} engines: {node: '>=10'} - csstype@3.2.0: - resolution: {integrity: sha512-si++xzRAY9iPp60roQiFta7OFbhrgvcthrhlNAGeQptSY25uJjkfUV8OArC3KLocB8JT8ohz+qgxWCmz8RhjIg==} + csstype@3.2.3: + resolution: {integrity: sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==} d3-array@3.2.4: resolution: {integrity: sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==} @@ -1533,8 +1533,8 @@ packages: resolution: {integrity: sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} engines: {node: '>= 0.4'} - electron-to-chromium@1.5.252: - resolution: {integrity: sha512-53uTpjtRgS7gjIxZ4qCgFdNO2q+wJt/Z8+xAvxbCqXPJrY6h7ighUkadQmNMXH96crtpa6gPFNP7BF4UBGDuaA==} + electron-to-chromium@1.5.254: + resolution: {integrity: sha512-DcUsWpVhv9svsKRxnSCZ86SjD+sp32SGidNB37KpqXJncp1mfUgKbHvBomE89WJDbfVKw1mdv5+ikrvd43r+Bg==} emoji-regex@8.0.0: resolution: {integrity: sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==} @@ -1566,8 +1566,8 @@ packages: resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} engines: {node: '>= 0.4'} - es-toolkit@1.41.0: - resolution: {integrity: sha512-bDd3oRmbVgqZCJS6WmeQieOrzpl3URcWBUVDXxOELlUW2FuW+0glPOz1n0KnRie+PdyvUZcXz2sOn00c6pPRIA==} + es-toolkit@1.42.0: + resolution: {integrity: sha512-SLHIyY7VfDJBM8clz4+T2oquwTQxEzu263AyhVK4jREOAwJ+8eebaa4wM3nlvnAqhDrMm2EsA6hWHaQsMPQ1nA==} esbuild@0.25.12: resolution: {integrity: sha512-bbPBYYrtZbkt6Os6FiTLCTFxvq4tt3JKall1vRwshA3fdVztsLAatFaZobhkBC8/BrPetoa0oksYoKXoG4ryJg==} @@ -1650,8 +1650,8 @@ packages: debug: optional: true - form-data@4.0.4: - resolution: {integrity: sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==} + form-data@4.0.5: + resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} engines: {node: '>= 6'} fraction.js@5.3.4: @@ -1809,11 +1809,11 @@ packages: resolution: {integrity: sha512-kyCuEOWjJqZuDbRHzL8V93NzQhwIB71oFWSyzVo+KPZI+pnQPPxucdkrOZvkLRnrf5URsQM+IJ09Dw29cRALIA==} engines: {node: '>=10'} - html-dom-parser@5.1.1: - resolution: {integrity: sha512-+o4Y4Z0CLuyemeccvGN4bAO20aauB2N9tFEAep5x4OW34kV4PTarBHm6RL02afYt2BMKcr0D2Agep8S3nJPIBg==} + html-dom-parser@5.1.2: + resolution: {integrity: sha512-9nD3Rj3/FuQt83AgIa1Y3ruzspwFFA54AJbQnohXN+K6fL1/bhcDQJJY5Ne4L4A163ADQFVESd/0TLyNoV0mfg==} - html-react-parser@5.2.8: - resolution: {integrity: sha512-09WaI81tbpwhXWeMe1m9VptZVJUcigo0l59zVt+2HUIQT7+baU38/oNhllj6MKhOuGXqh0nrlwOgxbxbm6xXHw==} + html-react-parser@5.2.10: + resolution: {integrity: sha512-DjOLloguuDA+Ed7Q7PKhvMQmCl2+Yk/pfvvca68fvn15QFBbL4uHGxXwoXQ4sqS0UyuRH2lJb0S8yZCL3lvehQ==} peerDependencies: '@types/react': 0.14 || 15 || 16 || 17 || 18 || 19 react: 0.14 || 15 || 16 || 17 || 18 || 19 @@ -1853,8 +1853,8 @@ packages: ini@1.3.8: resolution: {integrity: sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==} - inline-style-parser@0.2.6: - resolution: {integrity: sha512-gtGXVaBdl5mAes3rPcMedEBm12ibjt1kDMFfheul1wUAOVEJW60voNdMVzVkfLN06O7ZaD/rxhfKgtlgtTbMjg==} + inline-style-parser@0.2.7: + resolution: {integrity: sha512-Nb2ctOyNR8DqQoR0OwRG95uNWIC0C1lCgf5Naz5H6Ji72KZ8OcFZLz2P5sNgwlyoJ8Yif11oMuYs5pBQa86csA==} internmap@2.0.3: resolution: {integrity: sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==} @@ -2701,11 +2701,11 @@ packages: resolution: {integrity: sha512-laJTa3Jb+VQpaC6DseHhF7dXVqHTfJPCRDaEbid/drOhgitgYku/letMUqOXFoWV0zIIUbjpdH2t+tYj4bQMRQ==} engines: {node: '>=8'} - style-to-js@1.1.19: - resolution: {integrity: sha512-Ev+SgeqiNGT1ufsXyVC5RrJRXdrkRJ1Gol9Qw7Pb72YCKJXrBvP0ckZhBeVSrw2m06DJpei2528uIpjMb4TsoQ==} + style-to-js@1.1.21: + resolution: {integrity: sha512-RjQetxJrrUJLQPHbLku6U/ocGtzyjbJMP9lCNK7Ag0CNh690nSH8woqWH9u16nMjYBAok+i7JO1NP2pOy8IsPQ==} - style-to-object@1.0.12: - resolution: {integrity: sha512-ddJqYnoT4t97QvN2C95bCgt+m7AAgXjVnkk/jxAfmp7EAB8nnqqZYEbMd3em7/vEomDb2LAQKAy1RFfv41mdNw==} + style-to-object@1.0.14: + resolution: {integrity: sha512-LIN7rULI0jBscWQYaSswptyderlarFkjQ+t79nzty8tcIAceVomEVlLzH5VP4Cmsv6MtKhs7qaAiwlcp+Mgaxw==} stylis@4.2.0: resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==} @@ -3262,7 +3262,7 @@ snapshots: '@emotion/memoize@0.9.0': {} - '@emotion/react@11.14.0(@types/react@19.2.4)(react@19.2.0)': + '@emotion/react@11.14.0(@types/react@19.2.5)(react@19.2.0)': dependencies: '@babel/runtime': 7.28.4 '@emotion/babel-plugin': 11.13.5 @@ -3274,7 +3274,7 @@ snapshots: hoist-non-react-statics: 3.3.2 react: 19.2.0 optionalDependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 transitivePeerDependencies: - supports-color @@ -3284,22 +3284,22 @@ snapshots: '@emotion/memoize': 0.9.0 '@emotion/unitless': 0.10.0 '@emotion/utils': 1.4.2 - csstype: 3.2.0 + csstype: 3.2.3 '@emotion/sheet@1.4.0': {} - '@emotion/styled@11.14.1(@emotion/react@11.14.0(@types/react@19.2.4)(react@19.2.0))(@types/react@19.2.4)(react@19.2.0)': + '@emotion/styled@11.14.1(@emotion/react@11.14.0(@types/react@19.2.5)(react@19.2.0))(@types/react@19.2.5)(react@19.2.0)': dependencies: '@babel/runtime': 7.28.4 '@emotion/babel-plugin': 11.13.5 '@emotion/is-prop-valid': 1.4.0 - '@emotion/react': 11.14.0(@types/react@19.2.4)(react@19.2.0) + '@emotion/react': 11.14.0(@types/react@19.2.5)(react@19.2.0) '@emotion/serialize': 1.3.3 '@emotion/use-insertion-effect-with-fallbacks': 1.2.0(react@19.2.0) '@emotion/utils': 1.4.2 react: 19.2.0 optionalDependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 transitivePeerDependencies: - supports-color @@ -3418,10 +3418,10 @@ snapshots: '@github/webauthn-json@2.1.1': {} - '@hookform/devtools@4.4.0(@types/react@19.2.4)(react-dom@19.2.0(react@19.2.0))(react@19.2.0)': + '@hookform/devtools@4.4.0(@types/react@19.2.5)(react-dom@19.2.0(react@19.2.0))(react@19.2.0)': dependencies: - '@emotion/react': 11.14.0(@types/react@19.2.4)(react@19.2.0) - '@emotion/styled': 11.14.1(@emotion/react@11.14.0(@types/react@19.2.4)(react@19.2.0))(@types/react@19.2.4)(react@19.2.0) + '@emotion/react': 11.14.0(@types/react@19.2.5)(react@19.2.0) + '@emotion/styled': 11.14.1(@emotion/react@11.14.0(@types/react@19.2.5)(react@19.2.0))(@types/react@19.2.5)(react@19.2.0) '@types/lodash': 4.17.20 little-state-machine: 4.8.1(react@19.2.0) lodash: 4.17.21 @@ -3487,7 +3487,7 @@ snapshots: rxjs: 7.8.2 use-sync-external-store: 1.6.0(react@19.2.0) - '@reduxjs/toolkit@2.10.1(react-redux@9.2.0(@types/react@19.2.4)(react@19.2.0)(redux@5.0.1))(react@19.2.0)': + '@reduxjs/toolkit@2.10.1(react-redux@9.2.0(@types/react@19.2.5)(react@19.2.0)(redux@5.0.1))(react@19.2.0)': dependencies: '@standard-schema/spec': 1.0.0 '@standard-schema/utils': 0.3.0 @@ -3497,7 +3497,7 @@ snapshots: reselect: 5.1.1 optionalDependencies: react: 19.2.0 - react-redux: 9.2.0(@types/react@19.2.4)(react@19.2.0)(redux@5.0.1) + react-redux: 9.2.0(@types/react@19.2.5)(react@19.2.0)(redux@5.0.1) '@remix-run/router@1.23.1': {} @@ -3656,19 +3656,19 @@ snapshots: dependencies: '@swc/counter': 0.1.3 - '@tanstack/query-core@5.90.9': {} + '@tanstack/query-core@5.90.10': {} '@tanstack/query-devtools@5.90.1': {} - '@tanstack/react-query-devtools@5.90.2(@tanstack/react-query@5.90.9(react@19.2.0))(react@19.2.0)': + '@tanstack/react-query-devtools@5.90.2(@tanstack/react-query@5.90.10(react@19.2.0))(react@19.2.0)': dependencies: '@tanstack/query-devtools': 5.90.1 - '@tanstack/react-query': 5.90.9(react@19.2.0) + '@tanstack/react-query': 5.90.10(react@19.2.0) react: 19.2.0 - '@tanstack/react-query@5.90.9(react@19.2.0)': + '@tanstack/react-query@5.90.10(react@19.2.0)': dependencies: - '@tanstack/query-core': 5.90.9 + '@tanstack/query-core': 5.90.10 react: 19.2.0 '@tanstack/react-virtual@3.13.12(react-dom@19.2.0(react@19.2.0))(react@19.2.0)': @@ -3749,24 +3749,24 @@ snapshots: '@types/qs@6.14.0': {} - '@types/react-dom@19.2.3(@types/react@19.2.4)': + '@types/react-dom@19.2.3(@types/react@19.2.5)': dependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 '@types/react-router-dom@5.3.3': dependencies: '@types/history': 4.7.11 - '@types/react': 19.2.4 + '@types/react': 19.2.5 '@types/react-router': 5.1.20 '@types/react-router@5.1.20': dependencies: '@types/history': 4.7.11 - '@types/react': 19.2.4 + '@types/react': 19.2.5 - '@types/react@19.2.4': + '@types/react@19.2.5': dependencies: - csstype: 3.2.0 + csstype: 3.2.3 '@types/unist@2.0.11': {} @@ -3825,7 +3825,7 @@ snapshots: autoprefixer@10.4.22(postcss@8.5.6): dependencies: browserslist: 4.28.0 - caniuse-lite: 1.0.30001754 + caniuse-lite: 1.0.30001755 fraction.js: 5.3.4 normalize-range: 0.1.2 picocolors: 1.1.1 @@ -3835,7 +3835,7 @@ snapshots: axios@1.13.2: dependencies: follow-redirects: 1.15.11 - form-data: 4.0.4 + form-data: 4.0.5 proxy-from-env: 1.1.0 transitivePeerDependencies: - debug @@ -3850,7 +3850,7 @@ snapshots: balanced-match@1.0.2: {} - baseline-browser-mapping@2.8.28: {} + baseline-browser-mapping@2.8.29: {} bignumber.js@9.3.1: {} @@ -3867,9 +3867,9 @@ snapshots: browserslist@4.28.0: dependencies: - baseline-browser-mapping: 2.8.28 - caniuse-lite: 1.0.30001754 - electron-to-chromium: 1.5.252 + baseline-browser-mapping: 2.8.29 + caniuse-lite: 1.0.30001755 + electron-to-chromium: 1.5.254 node-releases: 2.0.27 update-browserslist-db: 1.1.4(browserslist@4.28.0) @@ -3897,7 +3897,7 @@ snapshots: camelcase@5.3.1: {} - caniuse-lite@1.0.30001754: {} + caniuse-lite@1.0.30001755: {} ccount@2.0.1: {} @@ -4124,7 +4124,7 @@ snapshots: path-type: 4.0.0 yaml: 1.10.2 - csstype@3.2.0: {} + csstype@3.2.3: {} d3-array@3.2.4: dependencies: @@ -4246,7 +4246,7 @@ snapshots: es-errors: 1.3.0 gopd: 1.2.0 - electron-to-chromium@1.5.252: {} + electron-to-chromium@1.5.254: {} emoji-regex@8.0.0: {} @@ -4273,7 +4273,7 @@ snapshots: has-tostringtag: 1.0.2 hasown: 2.0.2 - es-toolkit@1.41.0: {} + es-toolkit@1.42.0: {} esbuild@0.25.12: optionalDependencies: @@ -4356,7 +4356,7 @@ snapshots: follow-redirects@1.15.11: {} - form-data@4.0.4: + form-data@4.0.5: dependencies: asynckit: 0.4.0 combined-stream: 1.0.8 @@ -4526,7 +4526,7 @@ snapshots: mdast-util-mdxjs-esm: 2.0.1 property-information: 7.1.0 space-separated-tokens: 2.0.2 - style-to-js: 1.1.19 + style-to-js: 1.1.21 unist-util-position: 5.0.0 vfile-message: 4.0.3 transitivePeerDependencies: @@ -4566,20 +4566,20 @@ snapshots: dependencies: lru-cache: 6.0.0 - html-dom-parser@5.1.1: + html-dom-parser@5.1.2: dependencies: domhandler: 5.0.3 htmlparser2: 10.0.0 - html-react-parser@5.2.8(@types/react@19.2.4)(react@19.2.0): + html-react-parser@5.2.10(@types/react@19.2.5)(react@19.2.0): dependencies: domhandler: 5.0.3 - html-dom-parser: 5.1.1 + html-dom-parser: 5.1.2 react: 19.2.0 react-property: 2.0.2 - style-to-js: 1.1.19 + style-to-js: 1.1.21 optionalDependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 html-url-attributes@3.0.1: {} @@ -4609,7 +4609,7 @@ snapshots: ini@1.3.8: {} - inline-style-parser@0.2.6: {} + inline-style-parser@0.2.7: {} internmap@2.0.3: {} @@ -4843,9 +4843,9 @@ snapshots: type-fest: 0.18.1 yargs-parser: 20.2.9 - merge-refs@2.0.0(@types/react@19.2.4): + merge-refs@2.0.0(@types/react@19.2.5): optionalDependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 micromark-core-commonmark@2.0.3: dependencies: @@ -5228,11 +5228,11 @@ snapshots: dependencies: react: 19.2.0 - react-markdown@10.1.0(@types/react@19.2.4)(react@19.2.0): + react-markdown@10.1.0(@types/react@19.2.5)(react@19.2.0): dependencies: '@types/hast': 3.0.4 '@types/mdast': 4.0.4 - '@types/react': 19.2.4 + '@types/react': 19.2.5 devlop: 1.1.0 hast-util-to-jsx-runtime: 2.3.6 html-url-attributes: 3.0.1 @@ -5254,18 +5254,18 @@ snapshots: qr.js: 0.0.0 react: 19.2.0 - react-redux@9.2.0(@types/react@19.2.4)(react@19.2.0)(redux@5.0.1): + react-redux@9.2.0(@types/react@19.2.5)(react@19.2.0)(redux@5.0.1): dependencies: '@types/use-sync-external-store': 0.0.6 react: 19.2.0 use-sync-external-store: 1.6.0(react@19.2.0) optionalDependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 redux: 5.0.1 react-resize-detector@12.3.0(react@19.2.0): dependencies: - es-toolkit: 1.41.0 + es-toolkit: 1.42.0 react: 19.2.0 react-router-dom@6.30.2(react-dom@19.2.0(react@19.2.0))(react@19.2.0): @@ -5342,18 +5342,18 @@ snapshots: dependencies: picomatch: 2.3.1 - recharts@3.4.1(@types/react@19.2.4)(react-dom@19.2.0(react@19.2.0))(react-is@19.2.0)(react@19.2.0)(redux@5.0.1): + recharts@3.4.1(@types/react@19.2.5)(react-dom@19.2.0(react@19.2.0))(react-is@19.2.0)(react@19.2.0)(redux@5.0.1): dependencies: - '@reduxjs/toolkit': 2.10.1(react-redux@9.2.0(@types/react@19.2.4)(react@19.2.0)(redux@5.0.1))(react@19.2.0) + '@reduxjs/toolkit': 2.10.1(react-redux@9.2.0(@types/react@19.2.5)(react@19.2.0)(redux@5.0.1))(react@19.2.0) clsx: 2.1.1 decimal.js-light: 2.5.1 - es-toolkit: 1.41.0 + es-toolkit: 1.42.0 eventemitter3: 5.0.1 immer: 10.2.0 react: 19.2.0 react-dom: 19.2.0(react@19.2.0) react-is: 19.2.0 - react-redux: 9.2.0(@types/react@19.2.4)(react@19.2.0)(redux@5.0.1) + react-redux: 9.2.0(@types/react@19.2.5)(react@19.2.0)(redux@5.0.1) reselect: 5.1.1 tiny-invariant: 1.3.3 use-sync-external-store: 1.6.0(react@19.2.0) @@ -5592,13 +5592,13 @@ snapshots: dependencies: min-indent: 1.0.1 - style-to-js@1.1.19: + style-to-js@1.1.21: dependencies: - style-to-object: 1.0.12 + style-to-object: 1.0.14 - style-to-object@1.0.12: + style-to-object@1.0.14: dependencies: - inline-style-parser: 0.2.6 + inline-style-parser: 0.2.7 stylis@4.2.0: {} @@ -5974,9 +5974,9 @@ snapshots: zod@3.25.76: {} - zustand@5.0.8(@types/react@19.2.4)(immer@10.2.0)(react@19.2.0)(use-sync-external-store@1.6.0(react@19.2.0)): + zustand@5.0.8(@types/react@19.2.5)(immer@10.2.0)(react@19.2.0)(use-sync-external-store@1.6.0(react@19.2.0)): optionalDependencies: - '@types/react': 19.2.4 + '@types/react': 19.2.5 immer: 10.2.0 react: 19.2.0 use-sync-external-store: 1.6.0(react@19.2.0)