Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/defguard_core/src/db/models/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub struct DeviceConfig {
// Network: A stand-alone device added by a user permanently bound to one network, e.g. a printer
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema, Type)]
#[sqlx(type_name = "device_type", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum DeviceType {
User,
Network,
Expand Down
8 changes: 3 additions & 5 deletions crates/defguard_core/src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,10 @@ impl WireguardNetwork {
acl_enabled: bool,
acl_default_allow: bool,
location_mfa_mode: LocationMfaMode,
) -> Result<Self, WireguardNetworkError> {
) -> Self {
let prvkey = StaticSecret::random_from_rng(OsRng);
let pubkey = PublicKey::from(&prvkey);
Ok(Self {
Self {
id: NoId,
name,
address,
Expand All @@ -266,7 +266,7 @@ impl WireguardNetwork {
acl_enabled,
acl_default_allow,
location_mfa_mode,
})
}
}

/// Try to set `address` from `&str`.
Expand Down Expand Up @@ -2036,7 +2036,6 @@ mod test {
false,
LocationMfaMode::Disabled,
)
.unwrap()
.save(&pool)
.await
.unwrap();
Expand Down Expand Up @@ -2168,7 +2167,6 @@ mod test {
false,
LocationMfaMode::Disabled,
)
.unwrap()
.save(&pool)
.await
.unwrap();
Expand Down
2 changes: 0 additions & 2 deletions crates/defguard_core/src/enterprise/db/models/acl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) {
false,
LocationMfaMode::Disabled,
)
.unwrap()
.save(&pool)
.await
.unwrap();
Expand All @@ -200,7 +199,6 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) {
false,
LocationMfaMode::Disabled,
)
.unwrap()
.save(&pool)
.await
.unwrap();
Expand Down
1 change: 0 additions & 1 deletion crates/defguard_core/src/enterprise/directory_sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,6 @@ mod test {
false,
LocationMfaMode::Disabled,
)
.unwrap()
.save(pool)
.await
.unwrap();
Expand Down
7 changes: 7 additions & 0 deletions crates/defguard_core/src/handlers/app_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
auth::SessionInfo,
db::{Settings, WireguardNetwork},
enterprise::{
db::models::openid_provider::OpenIdProvider,
is_enterprise_enabled, is_enterprise_free,
license::get_cached_license,
limits::{LimitsExceeded, get_counts},
Expand Down Expand Up @@ -41,19 +42,24 @@ pub struct AppInfo {
smtp_enabled: bool,
license_info: LicenseInfo,
ldap_info: LdapInfo,
external_openid_enabled: bool,
}

pub(crate) async fn get_app_info(
State(appstate): State<AppState>,
_session: SessionInfo,
) -> ApiResult {
// both `await`s are executed upfront to avoid holding license `RwLock` across an await point
let networks = WireguardNetwork::all(&appstate.pool).await?;
let external_openid_enabled = OpenIdProvider::get_current(&appstate.pool).await?.is_some();

let settings = Settings::get_current_settings();
let enterprise = is_enterprise_enabled();
let license = get_cached_license();
let counts = get_counts();
let limits_exceeded = counts.get_exceeded_limits(license.as_ref());
let any_limit_exceeded = limits_exceeded.any();

let res = AppInfo {
network_present: !networks.is_empty(),
smtp_enabled: settings.smtp_configured(),
Expand All @@ -68,6 +74,7 @@ pub(crate) async fn get_app_info(
enabled: settings.ldap_enabled,
ad: settings.ldap_uses_ad,
},
external_openid_enabled,
};

Ok(ApiResponse::new(json!(res), StatusCode::OK))
Expand Down
43 changes: 40 additions & 3 deletions crates/defguard_core/src/handlers/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ use crate::{
},
},
},
enterprise::{handlers::CanManageDevices, limits::update_counts},
enterprise::{
db::models::openid_provider::OpenIdProvider, handlers::CanManageDevices,
is_enterprise_enabled, limits::update_counts,
},
events::{ApiEvent, ApiEventType, ApiRequestContext},
grpc::GatewayMap,
handlers::mail::send_new_device_added_email,
Expand Down Expand Up @@ -88,6 +91,36 @@ impl WireguardNetworkData {
.as_ref()
.map_or(Vec::new(), |ips| parse_network_address_list(ips))
}

pub(crate) async fn validate_location_mfa_mode<'e, E: sqlx::PgExecutor<'e>>(
&self,
executor: E,
) -> Result<(), WebError> {
// if external MFA was chosen verify if enterprise features are enabled
// and external OpenID provider is configured
if self.location_mfa_mode == LocationMfaMode::External {
if !is_enterprise_enabled() {
error!(
"Unable to create location with external MFA. External OpenID provider is not configured"
);

return Err(WebError::Forbidden(
"Cannot enable external MFA. Enterprise features are disabled".into(),
));
}

if OpenIdProvider::get_current(executor).await?.is_none() {
error!(
"Unable to create location with external MFA. External OpenID provider is not configured"
);
return Err(WebError::BadRequest(
"Cannot enable external MFA. External OpenID provider is not configured".into(),
));
}
}

Ok(())
}
}

// Used in process of importing network from WireGuard config
Expand Down Expand Up @@ -137,6 +170,9 @@ pub(crate) async fn create_network(
"User {} creating WireGuard network {network_name}",
session.user.username
);

data.validate_location_mfa_mode(&appstate.pool).await?;

let allowed_ips = data.parse_allowed_ips();
let network = WireguardNetwork::new(
data.name,
Expand All @@ -150,8 +186,7 @@ pub(crate) async fn create_network(
data.acl_enabled,
data.acl_default_allow,
data.location_mfa_mode,
)
.map_err(|_| WebError::Serialization("Invalid network address".into()))?;
);

let mut transaction = appstate.pool.begin().await?;
let network = network.save(&mut *transaction).await?;
Expand Down Expand Up @@ -220,6 +255,8 @@ pub(crate) async fn modify_network(
"User {} updating WireGuard network {network_id}",
session.user.username
);
data.validate_location_mfa_mode(&appstate.pool).await?;

let mut network = find_network(network_id, &appstate.pool).await?;
// store network before mods
let before = network.clone();
Expand Down
Loading