diff --git a/crates/defguard_core/build.rs b/crates/defguard_core/build.rs index 96c212192d..d3deeca182 100644 --- a/crates/defguard_core/build.rs +++ b/crates/defguard_core/build.rs @@ -9,6 +9,6 @@ fn main() -> Result<(), Box> { &["src/enterprise/proto/license.proto"], &["src/enterprise/proto"], )?; - println!("cargo:rerun-if-changed=src/enterprise"); + println!("cargo:rerun-if-changed=src/enterprise/proto"); Ok(()) } diff --git a/crates/defguard_core/src/auth/mod.rs b/crates/defguard_core/src/auth/mod.rs index 462f904fa6..3077df406e 100644 --- a/crates/defguard_core/src/auth/mod.rs +++ b/crates/defguard_core/src/auth/mod.rs @@ -18,7 +18,7 @@ use crate::{ Group, OAuth2Token, Session, SessionState, User, models::{group::Permission, oauth2client::OAuth2Client}, }, - enterprise::{db::models::api_tokens::ApiToken, is_enterprise_enabled}, + enterprise::{db::models::api_tokens::ApiToken, is_business_license_active}, error::WebError, handlers::SESSION_COOKIE_NAME, }; @@ -38,7 +38,7 @@ where let appstate = AppState::from_ref(state); // first try to authenticate by API token if one is found in header - if is_enterprise_enabled() { + if is_business_license_active() { let maybe_auth_header: Option>> = as OptionalFromRequestParts>::from_request_parts(parts, state) .await diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 33c26e4989..91966330f0 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -40,7 +40,7 @@ use super::{ wireguard_peer_stats::WireguardPeerStats, }; use crate::{ - enterprise::{firewall::FirewallError, is_enterprise_enabled}, + enterprise::{firewall::FirewallError, is_enterprise_license_active}, grpc::gateway::{send_multiple_wireguard_events, state::GatewayState}, wg_config::ImportedDevice, }; @@ -1335,7 +1335,8 @@ impl WireguardNetwork { /// - Enterprise is enabled #[must_use] pub fn should_prevent_service_location_usage(&self) -> bool { - self.service_location_mode != ServiceLocationMode::Disabled && !is_enterprise_enabled() + self.service_location_mode != ServiceLocationMode::Disabled + && !is_enterprise_license_active() } } diff --git a/crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs b/crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs index 3faff81ab0..e1a2a467b4 100644 --- a/crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs +++ b/crates/defguard_core/src/enterprise/activity_log_stream/activity_log_stream_manager.rs @@ -10,7 +10,7 @@ use super::ActivityLogStreamReconfigurationNotification; use crate::enterprise::{ activity_log_stream::http_stream::{HttpActivityLogStreamConfig, run_http_stream_task}, db::models::activity_log_stream::{ActivityLogStream, ActivityLogStreamConfig}, - is_enterprise_enabled, + is_business_license_active, }; // check if enterprise features are enabled every minute @@ -27,7 +27,7 @@ pub async fn run_activity_log_stream_manager( let mut enterprise_check_timer = interval(Duration::from_secs(ENTERPRISE_CHECK_PERIOD_SECS)); // initialize enterprise features status - let mut enterprise_features_enabled = is_enterprise_enabled(); + let mut enterprise_features_enabled = is_business_license_active(); loop { let mut handles = JoinSet::<()>::new(); @@ -94,7 +94,7 @@ pub async fn run_activity_log_stream_manager( } _ = enterprise_check_timer.tick() => { // check if enterprise features status has changed - let current_enterprise_features_enabled = is_enterprise_enabled(); + let current_enterprise_features_enabled = is_business_license_active(); if current_enterprise_features_enabled != enterprise_features_enabled { warn!("Activity log stream manager will reload, detected license enterprise features status has changed"); enterprise_features_enabled = current_enterprise_features_enabled; diff --git a/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs b/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs index d1c9be350b..916417a973 100644 --- a/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs +++ b/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs @@ -1,7 +1,7 @@ use sqlx::{PgExecutor, Type, query, query_as}; use struct_patch::Patch; -use crate::enterprise::is_enterprise_enabled; +use crate::enterprise::is_business_license_active; #[derive(Debug, Deserialize, Patch, Serialize)] #[patch(attribute(derive(Deserialize, Serialize)))] @@ -35,7 +35,7 @@ impl EnterpriseSettings { { // avoid holding the rwlock across await, makes the future !Send // and therefore unusable in axum handlers - if is_enterprise_enabled() { + if is_business_license_active() { let settings = query_as!( Self, "SELECT admin_device_management, \ diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index b37fccba56..9f56e8f211 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -11,12 +11,12 @@ use sqlx::{PgConnection, PgPool, error::Error as SqlxError}; use thiserror::Error; use tokio::sync::broadcast::Sender; -#[cfg(not(test))] -use super::is_enterprise_enabled; use super::{ db::models::openid_provider::{DirectorySyncTarget, OpenIdProvider}, ldap::utils::ldap_update_users_state, }; +#[cfg(not(test))] +use crate::enterprise::is_business_license_active; use crate::{ db::{GatewayEvent, Group, User}, enterprise::{ @@ -383,7 +383,7 @@ pub(crate) async fn test_directory_sync_connection( pool: &PgPool, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] - if !is_enterprise_enabled() { + if !is_business_license_active() { debug!("Enterprise is not enabled, skipping testing directory sync connection"); return Ok(()); } @@ -408,7 +408,7 @@ pub(crate) async fn sync_user_groups_if_configured( wg_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] - if !is_enterprise_enabled() { + if !is_business_license_active() { debug!("Enterprise is not enabled, skipping syncing user groups"); return Ok(()); } @@ -966,7 +966,7 @@ pub(crate) async fn do_directory_sync( wireguard_tx: &Sender, ) -> Result<(), DirectorySyncError> { #[cfg(not(test))] - if !is_enterprise_enabled() { + if !is_business_license_active() { debug!("Enterprise is not enabled, skipping performing directory sync"); return Ok(()); } diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index 5e2b7e8d97..eebb8bcc50 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -23,7 +23,7 @@ use crate::{ db::{Device, User, WireguardNetwork}, enterprise::{ db::models::{acl::AliasKind, snat::UserSnatBinding}, - is_enterprise_enabled, + is_business_license_active, }, }; @@ -903,7 +903,7 @@ impl WireguardNetwork { conn: &mut PgConnection, ) -> Result, FirewallError> { // do a license check - if !is_enterprise_enabled() { + if !is_business_license_active() { debug!( "Enterprise features are disabled, skipping generating firewall config for \ location {self}" diff --git a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs index 6e5e8d032d..b76c9170cf 100644 --- a/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs +++ b/crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs @@ -6,7 +6,7 @@ use tonic::Status; use crate::{ enterprise::{ handlers::openid_login::{extract_state_data, user_from_claims}, - is_enterprise_enabled, + is_business_license_active, }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, grpc::{ @@ -23,7 +23,7 @@ impl ClientMfaServer { info: Option, ) -> Result<(), Status> { debug!("Received OIDC MFA authentication request: {request:?}"); - if !is_enterprise_enabled() { + if !is_business_license_active() { error!("OIDC MFA method requires enterprise feature to be enabled"); return Err(Status::invalid_argument("OIDC MFA method is not supported")); } diff --git a/crates/defguard_core/src/enterprise/grpc/polling.rs b/crates/defguard_core/src/enterprise/grpc/polling.rs index 8e04e9a411..e210dabbd0 100644 --- a/crates/defguard_core/src/enterprise/grpc/polling.rs +++ b/crates/defguard_core/src/enterprise/grpc/polling.rs @@ -5,7 +5,7 @@ use tonic::Status; use crate::{ db::{Device, User, models::polling_token::PollingToken}, - enterprise::is_enterprise_enabled, + enterprise::is_business_license_active, grpc::utils::build_device_config_response, }; @@ -24,7 +24,7 @@ impl PollingServer { debug!("Validating polling token. Token: {token}"); // Polling service is enterprise-only, check the lincense - if !is_enterprise_enabled() { + if !is_business_license_active() { debug!("Instance has enterprise features disabled, denying instance polling info"); return Err(Status::failed_precondition("no valid license")); } diff --git a/crates/defguard_core/src/enterprise/handlers/mod.rs b/crates/defguard_core/src/enterprise/handlers/mod.rs index dac781bcfd..5686717e6a 100644 --- a/crates/defguard_core/src/enterprise/handlers/mod.rs +++ b/crates/defguard_core/src/enterprise/handlers/mod.rs @@ -17,7 +17,7 @@ use axum::{ }; use super::{ - db::models::enterprise_settings::EnterpriseSettings, is_enterprise_enabled, + db::models::enterprise_settings::EnterpriseSettings, is_business_license_active, license::get_cached_license, }; use crate::{appstate::AppState, error::WebError}; @@ -37,7 +37,7 @@ where type Rejection = WebError; async fn from_request_parts(_parts: &mut Parts, _state: &S) -> Result { - if is_enterprise_enabled() { + if is_business_license_active() { Ok(LicenseInfo { valid: true }) } else { Err(WebError::Forbidden( diff --git a/crates/defguard_core/src/enterprise/ldap/mod.rs b/crates/defguard_core/src/enterprise/ldap/mod.rs index 17152ee73d..b3e97a5564 100644 --- a/crates/defguard_core/src/enterprise/ldap/mod.rs +++ b/crates/defguard_core/src/enterprise/ldap/mod.rs @@ -18,7 +18,7 @@ use sync::{get_ldap_sync_status, is_ldap_desynced, set_ldap_sync_status}; use self::error::LdapError; use crate::{ db::{self, User}, - enterprise::{is_enterprise_enabled, ldap::model::extract_dn_path, limits::update_counts}, + enterprise::{is_business_license_active, ldap::model::extract_dn_path, limits::update_counts}, }; #[cfg(not(test))] @@ -54,7 +54,7 @@ pub(crate) async fn do_ldap_sync(pool: &PgPool) -> Result<(), LdapError> { return Ok(()); } - if !is_enterprise_enabled() { + if !is_business_license_active() { info!( "Enterprise features are disabled, not performing LDAP sync and automatically disabling it" ); @@ -100,7 +100,7 @@ where F: Future>, { let settings = Settings::get_current_settings(); - if !is_enterprise_enabled() { + if !is_business_license_active() { info!("Enterprise features are disabled, not performing LDAP operation"); set_ldap_sync_status(LdapSyncStatus::OutOfSync, pool).await?; return Err(LdapError::EnterpriseDisabled("LDAP".to_string())); diff --git a/crates/defguard_core/src/enterprise/license.rs b/crates/defguard_core/src/enterprise/license.rs index 2bec3c4747..2f0f8b67e9 100644 --- a/crates/defguard_core/src/enterprise/license.rs +++ b/crates/defguard_core/src/enterprise/license.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{fmt::Display, time::Duration}; use anyhow::Result; use base64::prelude::*; @@ -20,7 +20,9 @@ use thiserror::Error; use tokio::time::sleep; use super::limits::Counts; -use crate::grpc::proto::enterprise::license::{LicenseKey, LicenseLimits, LicenseMetadata}; +use crate::grpc::proto::enterprise::license::{ + LicenseKey, LicenseLimits, LicenseMetadata, LicenseTier as LicenseTierProto, +}; const LICENSE_SERVER_URL: &str = "https://pkgs.defguard.net/api/license/renew"; @@ -195,6 +197,8 @@ pub enum LicenseError { "License limits exceeded. To upgrade your license please contact salesdefguard.net" )] LicenseLimitsExceeded, + #[error("License tier is lower than required minimum")] + LicenseTierTooLow, } #[derive(Debug, Serialize, Deserialize)] @@ -202,6 +206,28 @@ struct RefreshRequestResponse { key: String, } +/// Represents license tiers +/// +/// Variant order must be maintained to go from lowest (first) to highest (last) tier +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, PartialOrd)] +pub enum LicenseTier { + Business, // this corresponds to both Team & Business level in our current pricing structure + Enterprise, +} + +impl Display for LicenseTier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Business => { + write!(f, "Business") + } + Self::Enterprise => { + write!(f, "Enterprise") + } + } + } +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct License { pub customer_id: String, @@ -209,6 +235,7 @@ pub struct License { pub valid_until: Option>, pub limits: Option, pub version_date_limit: Option>, + pub tier: LicenseTier, } impl License { @@ -219,6 +246,7 @@ impl License { valid_until: Option>, limits: Option, version_date_limit: Option>, + tier: LicenseTier, ) -> Self { Self { customer_id, @@ -226,6 +254,7 @@ impl License { valid_until, limits, version_date_limit, + tier, } } @@ -306,12 +335,27 @@ impl License { None => None, }; + let license_tier = match LicenseTierProto::try_from(metadata.tier) { + Ok(LicenseTierProto::Enterprise) => LicenseTier::Enterprise, + // fall back to Business tier for legacy licenses + Ok(LicenseTierProto::Business | LicenseTierProto::Unspecified) => { + LicenseTier::Business + } + Err(err) => { + error!("Failed to read license tier from license metadata: {err}"); + return Err(LicenseError::DecodeError( + "Failed to decode license tier metadata".into(), + )); + } + }; + let license = License::new( metadata.customer_id, metadata.subscription, valid_until, metadata.limits, version_date_limit, + license_tier, ); if license.requires_renewal() { @@ -448,6 +492,14 @@ impl License { self.is_expired() } } + + // Checks if License tier is lower than specified minimum + // + // Ordering is implemented by the `LicenseTier` enum itself + #[must_use] + pub(crate) fn is_lower_tier(&self, minimum_tier: LicenseTier) -> bool { + self.tier < minimum_tier + } } /// Exchange the currently stored key for a new one from the license server. @@ -510,9 +562,11 @@ async fn renew_license() -> Result { /// 1. Does the cached license exist /// 2. Is the cached license past its maximum expiry date /// 3. Does current object count exceed license limits +/// 4. Is the license of at least the specified tier (or higher) pub(crate) fn validate_license( license: Option<&License>, counts: &Counts, + minimum_tier: LicenseTier, ) -> Result<(), LicenseError> { debug!("Validating if the license is present, not expired and not exceeding limits..."); match license { @@ -523,6 +577,9 @@ pub(crate) fn validate_license( if counts.is_over_license_limits(license) { return Err(LicenseError::LicenseLimitsExceeded); } + if license.is_lower_tier(minimum_tier) { + return Err(LicenseError::LicenseTierTooLow); + } Ok(()) } None => Err(LicenseError::LicenseNotFound), @@ -695,6 +752,9 @@ mod test { assert_eq!(limits.users, 10); assert_eq!(limits.devices, 100); assert_eq!(limits.locations, 5); + + // pre-1.6 license defaults to Business tier + assert_eq!(license.tier, LicenseTier::Business); } #[test] @@ -713,6 +773,9 @@ mod test { // legacy license is unlimited assert!(license.limits.is_none()); + + // legacy license defaults to Business tier + assert_eq!(license.tier, LicenseTier::Business); } #[test] @@ -728,6 +791,9 @@ mod test { license.valid_until.unwrap(), Utc.with_ymd_and_hms(2024, 12, 26, 13, 57, 54).unwrap() ); + + // pre-1.6 license defaults to Business tier + assert_eq!(license.tier, LicenseTier::Business); } #[test] @@ -735,8 +801,8 @@ mod test { let license = "CigKIDBjNGRjYjU0MDA1NDRkNDdhZDg2MTdmY2RmMjcwNGNiGOLBtbsGErUBiLMEAAEIAB0WIQSaLjwX4m6jCO3NypmohGwBApqEhAUCZ3ZjywAKCRCohGwBApqEhEwFBACpHDnIszU2+KZcGhi3kycd3a12PyXJuFhhY4cuSyC8YEND85BplSWK1L8nu5ghFULFlddXP9HTHdxhJbtx4SgOQ8pxUY3+OpBN4rfJOMF61tvMRLaWlz7FWm/RnHe8cpoAOYm4oKRS0+FA2qLThxSsVa+S907ty19c6mcDgi6V5g=="; let license = License::from_base64(license).unwrap(); let counts = Counts::default(); - assert!(validate_license(Some(&license), &counts).is_err()); - assert!(validate_license(None, &counts).is_err()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_err()); + assert!(validate_license(None, &counts, LicenseTier::Business).is_err()); // One day past the expiry date, non-subscription license let license = License::new( @@ -745,8 +811,9 @@ mod test { Some(Utc::now() - TimeDelta::days(1)), None, None, + LicenseTier::Business, ); - assert!(validate_license(Some(&license), &counts).is_err()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_err()); // One day before the expiry date, non-subscription license let license = License::new( @@ -755,12 +822,20 @@ mod test { Some(Utc::now() + TimeDelta::days(1)), None, None, + LicenseTier::Business, ); - assert!(validate_license(Some(&license), &counts).is_ok()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_ok()); // No expiry date, non-subscription license - let license = License::new("test".to_string(), false, None, None, None); - assert!(validate_license(Some(&license), &counts).is_ok()); + let license = License::new( + "test".to_string(), + false, + None, + None, + None, + LicenseTier::Business, + ); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_ok()); // One day past the maximum overdue date let license = License::new( @@ -769,8 +844,9 @@ mod test { Some(Utc::now() - MAX_OVERDUE_TIME - TimeDelta::days(1)), None, None, + LicenseTier::Business, ); - assert!(validate_license(Some(&license), &counts).is_err()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_err()); // One day before the maximum overdue date let license = License::new( @@ -779,8 +855,9 @@ mod test { Some(Utc::now() - MAX_OVERDUE_TIME + TimeDelta::days(1)), None, None, + LicenseTier::Business, ); - assert!(validate_license(Some(&license), &counts).is_ok()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_ok()); let counts = Counts::new(5, 5, 5, 5); @@ -796,8 +873,9 @@ mod test { network_devices: Some(1), }), None, + LicenseTier::Business, ); - assert!(validate_license(Some(&license), &counts).is_err()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_err()); // Below object count limits let license = License::new( @@ -811,7 +889,23 @@ mod test { network_devices: Some(10), }), None, + LicenseTier::Business, ); - assert!(validate_license(Some(&license), &counts).is_ok()); + assert!(validate_license(Some(&license), &counts, LicenseTier::Business).is_ok()); + } + + #[test] + fn test_license_tiers() { + let legacy_license = "CjAKIDBjNGRjYjU0MDA1NDRkNDdhZDg2MTdmY2RmMjcwNGNiGOLBtbsGIgYIChBkGAUStQGIswQAAQgAHRYhBJouPBfibqMI7c3KmaiEbAECmoSEBQJnd9EMAAoJEKiEbAECmoSE/0kEAIb18pVTEYWQo0w6813nShJqi7++Uo/fX4pxaAzEiG9r5HGpZSbsceCarMiK1rBr93HOIMeDRsbZmJBA/MAYGi32uXgzLE8fGSd4lcUPAbpvlj7KNvQNH6sMelzQVw+AJVY+IASqO84nfy92taEVagbLqIwl/eSQUnehJBS+B5/z"; + let legacy_license = License::from_base64(legacy_license).unwrap(); + assert_eq!(legacy_license.tier, LicenseTier::Business); + + let business_license = "Ci4KJGEyYjE1M2MzLWYwZmEtNGUzNC05ZThkLWY0Nzk1NTA4OWMwNRiI7KTKBjABErUBiLMEAAEIAB0WIQSaLjwX4m6jCO3NypmohGwBApqEhAUCaT/7iAAKCRCohGwBApqEhHdaA/0QqDNiryYSzWTEayBMwEBE6KAxTEtwRzXOxQxsnULjbQMol/SRjqfu8iwlI4IeBQP3CuAR9kglewvwg3osXDldIns46W/cDBd0jxANebLY9SPz0JS6pStMnSzhZ6rFW5ns3nCz86EOyAA9npx0/qxHCbtT6Qzi//5JYQe6VvvCmw=="; + let business_license = License::from_base64(business_license).unwrap(); + assert_eq!(business_license.tier, LicenseTier::Business); + + let enterprise_license = "Ci4KJDRiYjMzZTUyLWUzNGMtNGQyMS1iNDVhLTkxY2EzYTMzNGMwORiy7KTKBjACErUBiLMEAAEIAB0WIQSaLjwX4m6jCO3NypmohGwBApqEhAUCaT/7sgAKCRCohGwBApqEhIMzBACGd7vIyLaRVGV/MAD8bpgWURG1x1tlxD9ehaSNkk01GkfZc+6+QwiTUBUOSp0MKPtuLmow5AIRKS9M75CQQ4bGtjLWO5cXJm1sduRpTvXwPLXNkRFPSxhjHmo4yjFFHMHMySqQE2WUjcz/b5dMT/WNqWYg7tSfT72eiK18eSVFTA=="; + let enterprise_license = License::from_base64(enterprise_license).unwrap(); + assert_eq!(enterprise_license.tier, LicenseTier::Enterprise); } } diff --git a/crates/defguard_core/src/enterprise/limits.rs b/crates/defguard_core/src/enterprise/limits.rs index 6e7d0f66b7..ff78223a6e 100644 --- a/crates/defguard_core/src/enterprise/limits.rs +++ b/crates/defguard_core/src/enterprise/limits.rs @@ -18,7 +18,7 @@ pub struct Counts { user: u32, user_device: u32, network_device: u32, - wireguard_network: u32, + location: u32, } global_value!(COUNTS, Counts, Counts::default(), set_counts, get_counts); @@ -52,7 +52,7 @@ pub async fn update_counts<'e, E: sqlx::PgExecutor<'e>>(executor: E) -> Result<( .network_devices .try_into() .expect("device count should never be negative"), - wireguard_network: result + location: result .wireguard_networks .try_into() .expect("network count should never be negative"), @@ -77,22 +77,17 @@ impl Counts { Self { user: 0, user_device: 0, - wireguard_network: 0, + location: 0, network_device: 0, } } #[cfg(test)] - pub(crate) fn new( - user: u32, - user_device: u32, - wireguard_network: u32, - network_device: u32, - ) -> Self { + pub(crate) fn new(user: u32, user_device: u32, location: u32, network_device: u32) -> Self { Self { user, user_device, - wireguard_network, + location, network_device, } } @@ -114,7 +109,7 @@ impl Counts { debug!("Cached license not found. Using default limits for validation..."); self.user > DEFAULT_USERS_LIMIT || self.user_device > DEFAULT_DEVICES_LIMIT - || self.wireguard_network > DEFAULT_LOCATIONS_LIMIT + || self.location > DEFAULT_LOCATIONS_LIMIT || self.network_device > DEFAULT_NETWORK_DEVICES_LIMIT } } @@ -135,19 +130,19 @@ impl Counts { Some(limits) => { self.user > limits.users || self.is_over_device_limit(limits) - || self.wireguard_network > limits.locations + || self.location > limits.locations } // unlimited license None => false, } } - /// Checks if current object count exceeds default limits - pub(crate) fn needs_enterprise_license(&self) -> bool { - debug!("Checking if current object counts ({self:?}) exceed default limits"); + /// Checks if current object count exceeds default free tier limits + pub(crate) fn needs_paid_license(&self) -> bool { + debug!("Checking if current object counts ({self:?}) exceed default free tier limits"); self.user > DEFAULT_USERS_LIMIT || self.user_device > DEFAULT_DEVICES_LIMIT - || self.wireguard_network > DEFAULT_LOCATIONS_LIMIT + || self.location > DEFAULT_LOCATIONS_LIMIT || self.network_device > DEFAULT_NETWORK_DEVICES_LIMIT } @@ -161,7 +156,7 @@ impl Counts { } else { self.user_device + self.network_device > limits.devices }, - wireguard_network: self.wireguard_network > limits.locations, + wireguard_network: self.location > limits.locations, network_device: match limits.network_devices { Some(devices) => self.network_device > devices, None => false, @@ -179,7 +174,7 @@ impl Counts { LimitsExceeded { user: self.user > DEFAULT_DEVICES_LIMIT, device: self.user_device > DEFAULT_DEVICES_LIMIT, - wireguard_network: self.wireguard_network > DEFAULT_LOCATIONS_LIMIT, + wireguard_network: self.location > DEFAULT_LOCATIONS_LIMIT, network_device: self.network_device > DEFAULT_NETWORK_DEVICES_LIMIT, } } @@ -208,7 +203,7 @@ mod test { use super::*; use crate::{ - enterprise::license::{License, set_cached_license}, + enterprise::license::{License, LicenseTier, set_cached_license}, grpc::proto::enterprise::license::LicenseLimits, }; @@ -223,7 +218,7 @@ mod test { let counts = Counts { user: 5, user_device: 15, - wireguard_network: 3, + location: 3, network_device: 6, }; assert!(counts.is_over_device_limit(&limits)); @@ -231,7 +226,7 @@ mod test { let counts = Counts { user: 5, user_device: 10, - wireguard_network: 3, + location: 3, network_device: 5, }; assert!(!counts.is_over_device_limit(&limits)); @@ -246,7 +241,7 @@ mod test { let counts = Counts { user: 5, user_device: 15, - wireguard_network: 3, + location: 3, network_device: 6, }; assert!(!counts.is_over_device_limit(&limits)); @@ -254,7 +249,7 @@ mod test { let counts = Counts { user: 5, user_device: 15, - wireguard_network: 3, + location: 3, network_device: 11, }; assert!(counts.is_over_device_limit(&limits)); @@ -265,7 +260,7 @@ mod test { let counts = Counts { user: 1, user_device: 2, - wireguard_network: 3, + location: 3, network_device: 4, }; @@ -275,7 +270,7 @@ mod test { assert_eq!(counts.user, 1); assert_eq!(counts.user_device, 2); - assert_eq!(counts.wireguard_network, 3); + assert_eq!(counts.location, 3); } #[test] @@ -285,7 +280,7 @@ mod test { let counts = Counts { user: DEFAULT_USERS_LIMIT + 1, user_device: 1, - wireguard_network: 1, + location: 1, network_device: 1, }; set_counts(counts); @@ -298,7 +293,7 @@ mod test { let counts = Counts { user: 1, user_device: DEFAULT_DEVICES_LIMIT + 1, - wireguard_network: 1, + location: 1, network_device: 1, }; set_counts(counts); @@ -311,7 +306,7 @@ mod test { let counts = Counts { user: 1, user_device: 1, - wireguard_network: DEFAULT_LOCATIONS_LIMIT + 1, + location: DEFAULT_LOCATIONS_LIMIT + 1, network_device: 1, }; set_counts(counts); @@ -324,7 +319,7 @@ mod test { let counts = Counts { user: 1, user_device: 1, - wireguard_network: 1, + location: 1, network_device: 1, }; set_counts(counts); @@ -337,7 +332,7 @@ mod test { let counts = Counts { user: DEFAULT_USERS_LIMIT + 1, user_device: DEFAULT_DEVICES_LIMIT, - wireguard_network: DEFAULT_LOCATIONS_LIMIT, + location: DEFAULT_LOCATIONS_LIMIT, network_device: 1, }; set_counts(counts); @@ -365,6 +360,7 @@ mod test { Some(Utc::now() + TimeDelta::days(1)), Some(limits), None, + LicenseTier::Business, ); set_cached_license(Some(license)); @@ -373,7 +369,7 @@ mod test { let counts = Counts { user: users_limit + 1, user_device: 1, - wireguard_network: 1, + location: 1, network_device: 1, }; set_counts(counts); @@ -386,7 +382,7 @@ mod test { let counts = Counts { user: 1, user_device: devices_limit + 1, - wireguard_network: 1, + location: 1, network_device: 1, }; set_counts(counts); @@ -399,7 +395,7 @@ mod test { let counts = Counts { user: 1, user_device: 1, - wireguard_network: locations_limit + 1, + location: locations_limit + 1, network_device: 1, }; set_counts(counts); @@ -412,7 +408,7 @@ mod test { let counts = Counts { user: users_limit, user_device: devices_limit, - wireguard_network: locations_limit, + location: locations_limit, network_device: network_devices_limit, }; set_counts(counts); @@ -425,7 +421,7 @@ mod test { let counts = Counts { user: users_limit + 1, user_device: devices_limit + 1, - wireguard_network: locations_limit + 1, + location: locations_limit + 1, network_device: network_devices_limit + 1, }; set_counts(counts); @@ -442,6 +438,7 @@ mod test { Some(Utc::now() + TimeDelta::days(1)), None, None, + LicenseTier::Business, ); set_cached_license(Some(license)); @@ -450,7 +447,7 @@ mod test { let counts = Counts { user: u32::MAX, user_device: u32::MAX, - wireguard_network: u32::MAX, + location: u32::MAX, network_device: u32::MAX, }; set_counts(counts); @@ -469,7 +466,7 @@ mod test { let counts = Counts { user: exceed_user, user_device: 0, - wireguard_network: 0, + location: 0, network_device: 0, }; set_counts(counts); @@ -483,7 +480,7 @@ mod test { let counts = Counts { user: 0, user_device: exceed_device, - wireguard_network: 0, + location: 0, network_device: 0, }; set_counts(counts); @@ -497,7 +494,7 @@ mod test { let counts = Counts { user: 0, user_device: 0, - wireguard_network: exceed_wireguard_network, + location: exceed_wireguard_network, network_device: 0, }; set_counts(counts); @@ -510,7 +507,7 @@ mod test { let counts = Counts { user: 0, user_device: 0, - wireguard_network: 0, + location: 0, network_device: exceed_network_device, }; @@ -525,7 +522,7 @@ mod test { let counts = Counts { user: 0, user_device: 0, - wireguard_network: 0, + location: 0, network_device: 0, }; set_counts(counts); @@ -547,11 +544,12 @@ mod test { network_devices: Some(2), }), None, + LicenseTier::Business, ); let counts = Counts { user: 3, user_device: 3, - wireguard_network: 3, + location: 3, network_device: 3, }; set_counts(counts); @@ -568,11 +566,12 @@ mod test { Some(Utc::now() + TimeDelta::days(1)), None, None, + LicenseTier::Business, ); let counts = Counts { user: 300, user_device: 300, - wireguard_network: 300, + location: 300, network_device: 300, }; set_counts(counts); diff --git a/crates/defguard_core/src/enterprise/mod.rs b/crates/defguard_core/src/enterprise/mod.rs index 679296908e..4d16e6a429 100644 --- a/crates/defguard_core/src/enterprise/mod.rs +++ b/crates/defguard_core/src/enterprise/mod.rs @@ -13,17 +13,34 @@ mod utils; use license::{get_cached_license, validate_license}; use limits::get_counts; -pub(crate) fn is_enterprise_enabled() -> bool { - debug!("Checking if enterprise features should be enabled"); +use crate::enterprise::license::LicenseTier; + +/// Helper function to gate features which require a base license (Team or Business tier) +pub(crate) fn is_business_license_active() -> bool { + is_license_tier_active(LicenseTier::Business) +} + +/// Helper function to gate features which require an Enterprise tier license +pub(crate) fn is_enterprise_license_active() -> bool { + is_license_tier_active(LicenseTier::Enterprise) +} + +/// Shared logic for gating features to specific license tiers +fn is_license_tier_active(tier: LicenseTier) -> bool { + debug!("Checking if features for {tier} license tier should be enabled"); + + // get current object counts let counts = get_counts(); - if counts.needs_enterprise_license() { + + // only check license if object count exceed free limit + if counts.needs_paid_license() { debug!("User is over limit, checking his license"); let license = get_cached_license(); - let validation_result = validate_license(license.as_ref(), &counts); + let validation_result = validate_license(license.as_ref(), &counts, tier); debug!("License validation result: {:?}", validation_result); validation_result.is_ok() } else { - debug!("User is not over limit, allowing enterprise features"); + debug!("User is not over limit, allowing {tier} tier features"); true } } @@ -35,9 +52,9 @@ pub(crate) fn is_enterprise_free() -> bool { debug!("Checking if enterprise features are a part of the free version"); let counts = get_counts(); let license = get_cached_license(); - if validate_license(license.as_ref(), &counts).is_ok() { + if validate_license(license.as_ref(), &counts, LicenseTier::Business).is_ok() { false - } else if counts.needs_enterprise_license() { + } else if counts.needs_paid_license() { debug!("User is over limit, the enterprise features are not free"); false } else { @@ -45,3 +62,97 @@ pub(crate) fn is_enterprise_free() -> bool { true } } + +#[cfg(test)] +mod test { + use chrono::{TimeDelta, Utc}; + + use crate::{ + enterprise::{ + is_business_license_active, is_enterprise_free, is_enterprise_license_active, + license::{License, LicenseTier, set_cached_license}, + limits::{Counts, set_counts}, + }, + grpc::proto::enterprise::license::LicenseLimits, + }; + + #[test] + fn test_feature_gates_no_license() { + set_cached_license(None); + + // free limits are not exceeded + let counts = Counts::new(1, 1, 1, 1); + set_counts(counts); + + assert!(is_business_license_active()); + assert!(is_enterprise_license_active()); + assert!(is_enterprise_free()); + + // exceed free limits + let counts = Counts::new(1, 1, 5, 1); + set_counts(counts); + + assert!(!is_business_license_active()); + assert!(!is_enterprise_license_active()); + assert!(!is_enterprise_free()); + } + + #[test] + fn test_feature_gates_with_license() { + // exceed free limits + let counts = Counts::new(1, 1, 5, 1); + set_counts(counts); + + // set Business license + let users_limit = 15; + let devices_limit = 35; + let locations_limit = 5; + let network_devices_limit = 10; + + let limits = LicenseLimits { + users: users_limit, + devices: devices_limit, + locations: locations_limit, + network_devices: Some(network_devices_limit), + }; + let license = License::new( + "test".to_string(), + true, + Some(Utc::now() + TimeDelta::days(1)), + Some(limits), + None, + LicenseTier::Business, + ); + set_cached_license(Some(license)); + + assert!(is_business_license_active()); + assert!(!is_enterprise_license_active()); + assert!(!is_enterprise_free()); + + // set Enterprise license + let users_limit = 15; + let devices_limit = 35; + let locations_limit = 5; + let network_devices_limit = 10; + + let limits = LicenseLimits { + users: users_limit, + devices: devices_limit, + locations: locations_limit, + network_devices: Some(network_devices_limit), + }; + let license = License::new( + "test".to_string(), + true, + Some(Utc::now() + TimeDelta::days(1)), + Some(limits), + None, + LicenseTier::Enterprise, + ); + set_cached_license(Some(license)); + + assert!(is_business_license_active()); + assert!(is_enterprise_license_active()); + assert!(!is_enterprise_free()); + } +} diff --git a/crates/defguard_core/src/enterprise/proto/license.proto b/crates/defguard_core/src/enterprise/proto/license.proto index 098548a450..a8dfc170c5 100644 --- a/crates/defguard_core/src/enterprise/proto/license.proto +++ b/crates/defguard_core/src/enterprise/proto/license.proto @@ -8,12 +8,19 @@ message LicenseLimits { optional uint32 network_devices = 4; } +enum LicenseTier { + LICENSE_TIER_UNSPECIFIED = 0; + LICENSE_TIER_BUSINESS = 1; + LICENSE_TIER_ENTERPRISE = 2; +} + message LicenseMetadata { string customer_id = 1; bool subscription = 2; optional int64 valid_until = 3; LicenseLimits limits = 4; optional int64 version_date_limit = 5; + LicenseTier tier = 6; } message LicenseKey { diff --git a/crates/defguard_core/src/grpc/client_mfa.rs b/crates/defguard_core/src/grpc/client_mfa.rs index f688a41a48..4d38504c82 100644 --- a/crates/defguard_core/src/grpc/client_mfa.rs +++ b/crates/defguard_core/src/grpc/client_mfa.rs @@ -30,7 +30,7 @@ use crate::{ wireguard::LocationMfaMode, }, }, - enterprise::{db::models::openid_provider::OpenIdProvider, is_enterprise_enabled}, + enterprise::{db::models::openid_provider::OpenIdProvider, is_business_license_active}, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, grpc::utils::parse_client_ip_agent, handlers::mail::send_email_mfa_code_email, @@ -259,7 +259,7 @@ impl ClientMfaServer { })?; } MfaMethod::Oidc => { - if !is_enterprise_enabled() { + if !is_business_license_active() { error!("OIDC MFA method requires enterprise feature to be enabled"); return Err(Status::invalid_argument( "selected MFA method not available", diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index a4c4ba3dcf..4b82621ab4 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -59,7 +59,7 @@ use crate::{ handlers::openid_login::{ SELECT_ACCOUNT_SUPPORTED_PROVIDERS, build_state, make_oidc_client, user_from_claims, }, - is_enterprise_enabled, + is_business_license_active, ldap::utils::ldap_update_user_state, }, events::{BidiStreamEvent, GrpcEvent}, @@ -380,7 +380,7 @@ async fn handle_proxy_message_loop( } } Some(core_request::Payload::AuthInfo(request)) => { - if !is_enterprise_enabled() { + if !is_business_license_active() { warn!("Enterprise license required"); Some(core_response::Payload::CoreError(CoreError { status_code: Code::FailedPrecondition as i32, @@ -833,7 +833,7 @@ impl InstanceInfo { proxy_url: config.enrollment_url.clone(), username: username.into(), client_traffic_policy: enterprise_settings.client_traffic_policy, - enterprise_enabled: is_enterprise_enabled(), + enterprise_enabled: is_business_license_active(), openid_display_name, } } diff --git a/crates/defguard_core/src/handlers/app_info.rs b/crates/defguard_core/src/handlers/app_info.rs index 344ee41925..65ce7d954d 100644 --- a/crates/defguard_core/src/handlers/app_info.rs +++ b/crates/defguard_core/src/handlers/app_info.rs @@ -9,8 +9,8 @@ use crate::{ db::WireguardNetwork, enterprise::{ db::models::openid_provider::OpenIdProvider, - is_enterprise_enabled, is_enterprise_free, - license::get_cached_license, + is_business_license_active, is_enterprise_free, + license::{LicenseTier, get_cached_license}, limits::{LimitsExceeded, get_counts}, }, }; @@ -25,6 +25,8 @@ struct LicenseInfo { any_limit_exceeded: bool, /// Whether the enterprise features are used for free. is_enterprise_free: bool, + // Which license tier (if any) is active + tier: Option, } #[derive(Serialize)] @@ -55,11 +57,12 @@ pub(crate) async fn get_app_info( let external_openid_enabled = OpenIdProvider::get_current(&appstate.pool).await?.is_some(); let settings = Settings::get_current_settings(); - let enterprise = is_enterprise_enabled(); + let enterprise = is_business_license_active(); let license = get_cached_license(); let counts = get_counts(); let limits_exceeded = counts.get_exceeded_limits(license.as_ref()); let any_limit_exceeded = limits_exceeded.any(); + let tier = license.as_ref().map(|license| license.tier.clone()); let res = AppInfo { network_present: !networks.is_empty(), @@ -70,6 +73,7 @@ pub(crate) async fn get_app_info( limits_exceeded, any_limit_exceeded, is_enterprise_free: is_enterprise_free(), + tier, }, ldap_info: LdapInfo { enabled: settings.ldap_enabled, diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 9410134bdd..34a2af7ea3 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -40,7 +40,7 @@ use crate::{ enterprise::{ db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, handlers::CanManageDevices, - is_enterprise_enabled, + is_business_license_active, limits::update_counts, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, @@ -126,7 +126,7 @@ impl WireguardNetworkData { // if external MFA was chosen verify if enterprise features are enabled // and external OpenID provider is configured if self.location_mfa_mode == LocationMfaMode::External { - if !is_enterprise_enabled() { + if !is_business_license_active() { error!( "Unable to create location with external MFA. External OpenID provider is not configured" ); diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index 5a0de67b32..d49340a836 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -13,7 +13,7 @@ use crate::{ enterprise::{ db::models::acl::{AclRule, RuleState}, directory_sync::{do_directory_sync, get_directory_sync_interval}, - is_enterprise_enabled, + is_business_license_active, ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, limits::do_count_update, }, @@ -40,7 +40,7 @@ pub async fn run_utility_thread( let mut last_enterprise_status_check = Instant::now(); // helper variable which stores previous enterprise features status - let mut enterprise_enabled = is_enterprise_enabled(); + let mut enterprise_enabled = is_business_license_active(); let directory_sync_task = || async { if let Err(e) = Box::pin( @@ -129,7 +129,7 @@ pub async fn run_utility_thread( // Check if enterprise features got enabled or disabled if last_enterprise_status_check.elapsed().as_secs() >= ENTERPRISE_STATUS_CHECK_INTERVAL { - let new_enterprise_enabled = is_enterprise_enabled(); + let new_enterprise_enabled = is_business_license_active(); if let Err(err) = enterprise_status_check( pool, wireguard_tx.clone(), diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 1c4e222445..4a85f5205d 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -15,7 +15,7 @@ use defguard_core::{ auth::failed_login::FailedLoginMap, build_webapp, db::{AppEvent, Device, GatewayEvent, User, UserDetails, WireguardNetwork}, - enterprise::license::{License, set_cached_license}, + enterprise::license::{License, LicenseTier, set_cached_license}, events::ApiEvent, grpc::{WorkerState, gateway::map::GatewayMap}, handlers::Auth, @@ -95,6 +95,7 @@ pub(crate) async fn make_base_client( None, None, None, + LicenseTier::Business, ); set_cached_license(Some(license)); diff --git a/crates/defguard_core/tests/integration/api/openid_login.rs b/crates/defguard_core/tests/integration/api/openid_login.rs index 923633fe1a..027d2ae355 100644 --- a/crates/defguard_core/tests/integration/api/openid_login.rs +++ b/crates/defguard_core/tests/integration/api/openid_login.rs @@ -5,7 +5,7 @@ use defguard_core::{ enterprise::{ db::models::openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior}, handlers::openid_providers::AddProviderData, - license::{License, set_cached_license}, + license::{License, LicenseTier, set_cached_license}, }, handlers::Auth, }; @@ -93,6 +93,7 @@ async fn test_openid_providers(_: PgPoolOptions, options: PgConnectOptions) { Some(Utc::now() - Duration::days(1)), None, None, + LicenseTier::Business, ); set_cached_license(Some(new_license)); let response = client.get("/api/v1/openid/auth_info").send().await; diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 96609dbfa7..1d2d364bbd 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -5,7 +5,7 @@ use defguard_common::db::models::settings::initialize_current_settings; use defguard_core::{ auth::failed_login::FailedLoginMap, db::{AppEvent, GatewayEvent}, - enterprise::license::{License, set_cached_license}, + enterprise::license::{License, LicenseTier, set_cached_license}, events::GrpcEvent, grpc::{ WorkerState, build_grpc_service_router, @@ -147,6 +147,7 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { None, None, None, + LicenseTier::Business, ); set_cached_license(Some(license)); diff --git a/crates/defguard_version/src/lib.rs b/crates/defguard_version/src/lib.rs index 05f177b24f..8d6881440e 100644 --- a/crates/defguard_version/src/lib.rs +++ b/crates/defguard_version/src/lib.rs @@ -62,7 +62,7 @@ use std::{cmp::Ordering, fmt, str::FromStr}; -use ::tracing::{error, warn}; +use ::tracing::warn; pub use semver::{BuildMetadata, Error as SemverError, Prerelease, Version}; use serde::Serialize; use thiserror::Error; diff --git a/flake.lock b/flake.lock index 8f18766cd4..b2ef29d1e0 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1763966396, - "narHash": "sha256-6eeL1YPcY1MV3DDStIDIdy/zZCDKgHdkCmsrLJFiZf0=", + "lastModified": 1765779637, + "narHash": "sha256-KJ2wa/BLSrTqDjbfyNx70ov/HdgNBCBBSQP3BIzKnv4=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5ae3b07d8d6527c42f17c876e404993199144b6a", + "rev": "1306659b587dc277866c7b69eb97e5f07864d8c4", "type": "github" }, "original": { @@ -48,11 +48,11 @@ ] }, "locked": { - "lastModified": 1764124769, - "narHash": "sha256-vcoOEy3i8AGJi3Y2C48hrf6CuL2h8W1gLe1gNt72Kxg=", + "lastModified": 1765852971, + "narHash": "sha256-rQdOMqfQNhcfqvh1dFIVWh09mrIWwerUJqqBdhIsf8g=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "5da8c00313b4434f00aed6b4c94cd3b207bafdc5", + "rev": "5f98ccecc9f1bc1c19c0a350a659af1a04b3b319", "type": "github" }, "original": { diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index fa88c26ca5..9431378732 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -2066,6 +2066,7 @@ Licensing information: [https://docs.defguard.net/enterprise/license](https://do alwaysOn: 'Always-on - A VPN connection will always be active when the user device is on.', mfaWarning: "Service locations can't be used while location MFA is enabled.", + enterpriseTierWarning: "This feature requires an Enterprise-tier license. If you are interested in using it, please contact our sales team at: sales@defguard.net" }, }, sections: { diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index a2ce0a2129..c4fd696949 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -4949,6 +4949,10 @@ type RootTranslation = { * S​e​r​v​i​c​e​ ​l​o​c​a​t​i​o​n​s​ ​c​a​n​'​t​ ​b​e​ ​u​s​e​d​ ​w​h​i​l​e​ ​l​o​c​a​t​i​o​n​ ​M​F​A​ ​i​s​ ​e​n​a​b​l​e​d​. */ mfaWarning: string + /** + * T​h​i​s​ ​f​e​a​t​u​r​e​ ​r​e​q​u​i​r​e​s​ ​a​n​ ​E​n​t​e​r​p​r​i​s​e​-​t​i​e​r​ ​l​i​c​e​n​s​e​.​ ​I​f​ ​y​o​u​ ​a​r​e​ ​i​n​t​e​r​e​s​t​e​d​ ​i​n​ ​u​s​i​n​g​ ​i​t​,​ ​p​l​e​a​s​e​ ​c​o​n​t​a​c​t​ ​o​u​r​ ​s​a​l​e​s​ ​t​e​a​m​ ​a​t​:​ ​s​a​l​e​s​@​d​e​f​g​u​a​r​d​.​n​e​t + */ + enterpriseTierWarning: string } } sections: { @@ -11694,6 +11698,10 @@ export type TranslationFunctions = { * Service locations can't be used while location MFA is enabled. */ mfaWarning: () => LocalizedString + /** + * This feature requires an Enterprise-tier license. If you are interested in using it, please contact our sales team at: sales@defguard.net + */ + enterpriseTierWarning: () => LocalizedString } } sections: { diff --git a/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx b/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx index 8968939c1e..197115644c 100644 --- a/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx +++ b/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx @@ -25,6 +25,7 @@ import useApi from '../../../shared/hooks/useApi'; import { useToaster } from '../../../shared/hooks/useToaster'; import { QueryKeys } from '../../../shared/queries'; import { + LicenseTier, LocationMfaMode, type Network, ServiceLocationMode, @@ -53,7 +54,18 @@ export const NetworkEditForm = () => { ); const queryClient = useQueryClient(); const { LL } = useI18nContext(); - const enterpriseEnabled = useAppStore((s) => s.appInfo?.license_info.enterprise); + const [licenseEnabled, licenseTier, isFreeLicense] = useAppStore( + (s) => [ + s.appInfo?.license_info.enterprise, + s.appInfo?.license_info.tier, + s.appInfo?.license_info.is_enterprise_free, + ], + shallow, + ); + const enterpriseLicenseEnabled = useMemo( + () => Boolean(isFreeLicense || licenseTier === LicenseTier.ENTERPRISE), + [licenseTier, isFreeLicense], + ); const { mutate } = useMutation({ mutationFn: editNetwork, @@ -397,7 +409,7 @@ export const NetworkEditForm = () => { displayValue: titleCase(val), })} /> - {!enterpriseEnabled && ( + {!licenseEnabled && (

{LL.networkConfiguration.form.helpers.aclFeatureDisabled()}

@@ -406,7 +418,7 @@ export const NetworkEditForm = () => { controller={{ control, name: 'acl_enabled' }} label={LL.networkConfiguration.form.fields.acl_enabled.label()} labelPlacement="right" - disabled={!enterpriseEnabled} + disabled={!licenseEnabled} /> { - {!mfaDisabled && ( + {!enterpriseLicenseEnabled ? ( -

{LL.networkConfiguration.form.helpers.serviceLocation.mfaWarning()}

+

+ {LL.networkConfiguration.form.helpers.serviceLocation.enterpriseTierWarning()} +

+ ) : ( + !mfaDisabled && ( + +

{LL.networkConfiguration.form.helpers.serviceLocation.mfaWarning()}

+
+ ) )} diff --git a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx index 993154b893..5cf16313ba 100644 --- a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx +++ b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx @@ -23,7 +23,11 @@ import { useAppStore } from '../../../../shared/hooks/store/useAppStore.ts'; import useApi from '../../../../shared/hooks/useApi'; import { useToaster } from '../../../../shared/hooks/useToaster'; import { QueryKeys } from '../../../../shared/queries'; -import { LocationMfaMode, ServiceLocationMode } from '../../../../shared/types.ts'; +import { + LicenseTier, + LocationMfaMode, + ServiceLocationMode, +} from '../../../../shared/types.ts'; import { titleCase } from '../../../../shared/utils/titleCase'; import { trimObjectStrings } from '../../../../shared/utils/trimObjectStrings.ts'; import { Validate } from '../../../../shared/validators'; @@ -45,7 +49,18 @@ export const WizardNetworkConfiguration = () => { ); const wizardNetworkConfiguration = useWizardStore((state) => state.manualNetworkConfig); - const enterpriseEnabled = useAppStore((s) => s.appInfo?.license_info.enterprise); + const [licenseEnabled, licenseTier, isFreeLicense] = useAppStore( + (s) => [ + s.appInfo?.license_info.enterprise, + s.appInfo?.license_info.tier, + s.appInfo?.license_info.is_enterprise_free, + ], + shallow, + ); + const enterpriseLicenseEnabled = useMemo( + () => Boolean(isFreeLicense || licenseTier === LicenseTier.ENTERPRISE), + [licenseTier, isFreeLicense], + ); const toaster = useToaster(); const { LL } = useI18nContext(); @@ -290,7 +305,7 @@ export const WizardNetworkConfiguration = () => { displayValue: titleCase(group), })} /> - {!enterpriseEnabled && ( + {!licenseEnabled && (

{LL.networkConfiguration.form.helpers.aclFeatureDisabled()}

@@ -338,14 +353,22 @@ export const WizardNetworkConfiguration = () => { type="number" disabled={mfaDisabled} /> - {!mfaDisabled && ( + {!enterpriseLicenseEnabled ? ( -

{LL.networkConfiguration.form.helpers.serviceLocation.mfaWarning()}

+

+ {LL.networkConfiguration.form.helpers.serviceLocation.enterpriseTierWarning()} +

+ ) : ( + !mfaDisabled && ( + +

{LL.networkConfiguration.form.helpers.serviceLocation.mfaWarning()}

+
+ ) )} diff --git a/web/src/shared/types.ts b/web/src/shared/types.ts index 86bf0baca2..b438891d7f 100644 --- a/web/src/shared/types.ts +++ b/web/src/shared/types.ts @@ -1433,9 +1433,15 @@ export type LicenseLimits = { wireguard_network: boolean; }; +export enum LicenseTier { + BUSINESS = 'Business', + ENTERPRISE = 'Enterprise', +} + export type LicenseInfo = { enterprise: boolean; limits_exceeded: LicenseLimits; any_limit_exceeded: boolean; is_enterprise_free: boolean; + tier?: LicenseTier; };