From f9e12bc39df887202399e6b3950d8b01cb8e3aa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Mon, 9 Mar 2026 12:36:45 +0100 Subject: [PATCH 1/2] Introduced license_check endpoint --- Cargo.lock | 4 +- .../defguard_common/src/db/models/settings.rs | 12 ++-- .../defguard_core/src/enterprise/license.rs | 25 ++++----- crates/defguard_core/src/enterprise/limits.rs | 4 +- crates/defguard_core/src/handlers/license.rs | 56 +++++++++++++++++++ crates/defguard_core/src/handlers/mod.rs | 3 +- crates/defguard_core/src/handlers/settings.rs | 10 ++-- crates/defguard_core/src/lib.rs | 4 +- crates/defguard_core/src/openapi.rs | 6 +- crates/model_derive/src/lib.rs | 10 ++-- 10 files changed, 97 insertions(+), 37 deletions(-) create mode 100644 crates/defguard_core/src/handlers/license.rs diff --git a/Cargo.lock b/Cargo.lock index c233448fc9..78cb133bb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4756,9 +4756,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ "bytes", "getrandom 0.3.4", diff --git a/crates/defguard_common/src/db/models/settings.rs b/crates/defguard_common/src/db/models/settings.rs index 936b0575eb..d55de1f48d 100644 --- a/crates/defguard_common/src/db/models/settings.rs +++ b/crates/defguard_common/src/db/models/settings.rs @@ -20,7 +20,7 @@ use crate::{ global_value!(SETTINGS, Option, None, set_settings, get_settings); /// Initializes global `SETTINGS` struct at program startup -pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Error> { +pub async fn initialize_current_settings(pool: &PgPool) -> sqlx::Result<()> { debug!("Initializing global settings struct"); if let Some(settings) = Settings::get(pool).await? { set_settings(Some(settings)); @@ -38,7 +38,7 @@ pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Erro pub async fn update_current_settings<'e, E: sqlx::PgExecutor<'e>>( executor: E, new_settings: Settings, -) -> Result<(), sqlx::Error> { +) -> sqlx::Result<()> { debug!("Updating current settings to: {new_settings:?}"); new_settings.save(executor).await?; set_settings(Some(new_settings)); @@ -303,7 +303,7 @@ impl Settings { BASE64_STANDARD.encode(bytes) } - pub async fn get<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn get<'e, E>(executor: E) -> sqlx::Result> where E: PgExecutor<'e>, { @@ -358,7 +358,7 @@ impl Settings { Ok(()) } - pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -707,7 +707,7 @@ impl Settings { &mut self, executor: E, config: &DefGuardConfig, - ) -> Result<(), sqlx::Error> + ) -> sqlx::Result<()> where E: PgExecutor<'e>, { @@ -746,7 +746,7 @@ pub struct SettingsEssentials { } impl SettingsEssentials { - pub async fn get_settings_essentials<'e, E>(executor: E) -> Result + pub async fn get_settings_essentials<'e, E>(executor: E) -> sqlx::Result where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/enterprise/license.rs b/crates/defguard_core/src/enterprise/license.rs index 471e64574d..c0b3227b3f 100644 --- a/crates/defguard_core/src/enterprise/license.rs +++ b/crates/defguard_core/src/enterprise/license.rs @@ -40,8 +40,6 @@ pub(crate) const PUBLIC_KEY: &[u8] = include_bytes!("public_key.asc"); #[derive(Debug, Error)] pub enum LicenseError { - #[error("Provided license is invalid: {0}")] - InvalidLicense(String), #[error("Provided signature does not match the license")] SignatureMismatch, #[error("Provided signature is invalid")] @@ -49,7 +47,7 @@ pub enum LicenseError { #[error("Database error")] DbError(#[from] SqlxError), #[error("License decoding error: {0}")] - DecodeError(String), + DecodeError(&'static str), #[error( "License is expired and has reached its maximum overdue time, please contact salesdefguard.net" )] @@ -123,8 +121,7 @@ impl License { fn decode(bytes: &[u8]) -> Result, LicenseError> { let bytes = BASE64_STANDARD.decode(bytes).map_err(|_| { LicenseError::DecodeError( - "Failed to decode the license key, check if the provided key is correct." - .to_string(), + "Failed to decode the license key, check if the provided key is correct.", ) })?; Ok(bytes) @@ -164,7 +161,7 @@ impl License { /// Deserialize the license object from a base64 encoded string. /// Also verifies the signature of the license - pub fn from_base64(key: &str) -> Result { + pub(crate) fn from_base64(key: &str) -> Result { debug!("Decoding the license key from a provided base64 string..."); let bytes = key.as_bytes(); let decoded = Self::decode(bytes)?; @@ -173,7 +170,7 @@ impl License { let license_key = LicenseKey::decode(slice).map_err(|_| { LicenseError::DecodeError( - "The license key is malformed, check if the provided key is correct.".to_string(), + "The license key is malformed, check if the provided key is correct.", ) })?; let metadata_bytes: &[u8] = &license_key.metadata; @@ -184,7 +181,7 @@ impl License { Ok(()) => { info!("Successfully decoded the license and validated the license signature"); let metadata = LicenseMetadata::decode(metadata_bytes).map_err(|_| { - LicenseError::DecodeError("Failed to decode the license metadata".to_string()) + LicenseError::DecodeError("Failed to decode the license metadata") })?; let valid_until = match metadata.valid_until { @@ -206,7 +203,7 @@ impl License { Err(err) => { error!("Failed to read license tier from license metadata: {err}"); return Err(LicenseError::DecodeError( - "Failed to decode license tier metadata".into(), + "Failed to decode license tier metadata", )); } }; @@ -257,7 +254,7 @@ impl License { /// Create the license object based on the license key stored in the database. /// Automatically decodes and deserializes the keys and verifies the signature. - pub fn load() -> Result, LicenseError> { + pub(crate) fn load() -> Result, LicenseError> { if let Some(key) = Self::get_key() { Ok(Some(Self::from_base64(&key)?)) } else { @@ -307,7 +304,7 @@ impl License { /// NOTE: license should be considered valid for an additional period of `MAX_OVERDUE_TIME`. /// If you want to check if the license reached this point, use `is_max_overdue` instead. #[must_use] - pub fn is_expired(&self) -> bool { + pub(crate) fn is_expired(&self) -> bool { match self.valid_until { Some(time) => time < Utc::now(), None => false, @@ -323,7 +320,7 @@ impl License { /// Gets the time the license is past its expiry date. /// If the license doesn't have a `valid_until` field, it will return 0. #[must_use] - pub fn time_overdue(&self) -> TimeDelta { + pub(crate) fn time_overdue(&self) -> TimeDelta { match self.valid_until { Some(time) => { let delta = Utc::now() - time; @@ -339,7 +336,7 @@ impl License { /// Checks whether we should try to renew the license. #[must_use] - pub fn requires_renewal(&self) -> bool { + pub(crate) fn requires_renewal(&self) -> bool { if self.subscription { if let Some(remaining) = self.time_left() { remaining <= RENEWAL_TIME @@ -353,7 +350,7 @@ impl License { /// Checks if the license has reached its maximum overdue time. #[must_use] - pub fn is_max_overdue(&self) -> bool { + pub(crate) fn is_max_overdue(&self) -> bool { if self.subscription { self.time_overdue() > MAX_OVERDUE_TIME } else { diff --git a/crates/defguard_core/src/enterprise/limits.rs b/crates/defguard_core/src/enterprise/limits.rs index f95628113e..46f3a43007 100644 --- a/crates/defguard_core/src/enterprise/limits.rs +++ b/crates/defguard_core/src/enterprise/limits.rs @@ -1,12 +1,12 @@ use defguard_common::global_value; +use serde::Serialize; use sqlx::{error::Error as SqlxError, query}; use super::license::License; #[cfg(test)] use super::license::get_cached_license; -#[derive(Debug)] -#[cfg_attr(test, derive(Clone))] +#[derive(Clone, Debug, Serialize)] pub struct Counts { user: u32, user_device: u32, diff --git a/crates/defguard_core/src/handlers/license.rs b/crates/defguard_core/src/handlers/license.rs new file mode 100644 index 0000000000..e568086de6 --- /dev/null +++ b/crates/defguard_core/src/handlers/license.rs @@ -0,0 +1,56 @@ +use axum::{Json, http::StatusCode}; +use utoipa::ToSchema; + +use super::{ApiResponse, ApiResult}; +use crate::{ + enterprise::{ + license::License, + limits::{Counts, get_counts}, + }, + grpc::proto::enterprise::license::LicenseLimits, +}; + +#[derive(Deserialize, ToSchema)] +pub struct CheckParams { + license: String, +} + +#[derive(Serialize)] +pub struct CheckResult { + limits: Option, + counts: Counts, +} + +/// Check given license. Return [`LicenseLimits`]. +#[utoipa::path( + post, + path = "/api/v1/license/check", + request_body = CheckParams, + responses( + ( + status = 200, + description = "Decoded license limits.", + // TODO: uncomment when LicenseLimits and Counts implement ToSchema. + // body = CheckResult, + example = json!({ + "users": 100, + "devices": 250, + "locations": 10, + "network_devices": 50 + }) + ), + (status = 400, description = "Invalid license key.", body = ApiResponse, example = json!({"msg": "License signature doesn't match its content"})), + (status = 404, description = "License not found.", body = ApiResponse, example = json!({"msg": "License not found"})) + ) +)] +pub(crate) async fn license_check(Json(params): Json) -> ApiResult { + let license = License::from_base64(params.license.trim())?; + + Ok(ApiResponse::json( + CheckResult { + limits: license.limits, + counts: get_counts().clone(), + }, + StatusCode::OK, + )) +} diff --git a/crates/defguard_core/src/handlers/mod.rs b/crates/defguard_core/src/handlers/mod.rs index da040d548e..270caeeda7 100644 --- a/crates/defguard_core/src/handlers/mod.rs +++ b/crates/defguard_core/src/handlers/mod.rs @@ -35,6 +35,7 @@ pub mod component_setup; pub(crate) mod forward_auth; pub mod gateway; pub(crate) mod group; +pub mod license; pub(crate) mod location_stats; pub mod mail; pub mod network_devices; @@ -230,7 +231,7 @@ impl From for ApiResponse { ) } WebError::LicenseError(err) => match err { - LicenseError::DecodeError(msg) | LicenseError::InvalidLicense(msg) => { + LicenseError::DecodeError(msg) => { warn!(msg); ApiResponse::new(json!({"msg": msg}), StatusCode::BAD_REQUEST) } diff --git a/crates/defguard_core/src/handlers/settings.rs b/crates/defguard_core/src/handlers/settings.rs index bb43703dfe..1bae0ea2ba 100644 --- a/crates/defguard_core/src/handlers/settings.rs +++ b/crates/defguard_core/src/handlers/settings.rs @@ -40,7 +40,7 @@ pub async fn get_settings(_admin: AdminRole, State(appstate): State) - Ok(ApiResponse::default()) } -pub async fn update_settings( +pub(crate) async fn update_settings( _admin: AdminRole, session: SessionInfo, context: ApiRequestContext, @@ -84,7 +84,7 @@ pub async fn get_settings_essentials(Extension(pool): Extension) -> ApiR Ok(ApiResponse::json(settings, StatusCode::OK)) } -pub async fn set_default_branding( +pub(crate) async fn set_default_branding( _admin: AdminRole, State(appstate): State, Path(_id): Path, // TODO: check with front-end and remove. @@ -123,7 +123,7 @@ pub async fn patch_settings( context: ApiRequestContext, Json(data): Json, ) -> ApiResult { - debug!("Admin {} patching settings", session.user.username); + debug!("Admin {} is patching settings", session.user.username); let mut settings = Settings::get_current_settings(); // prepare clone for emitting an event let before = settings.clone(); @@ -158,7 +158,7 @@ pub async fn patch_settings( let after = settings.clone(); update_current_settings(&appstate.pool, settings).await?; - info!("Admin {} patched settings.", session.user.username); + info!("Admin {} patched settings", session.user.username); appstate.emit_event(ApiEvent { context, event: Box::new(ApiEventType::SettingsUpdatedPartial { before, after }), @@ -166,7 +166,7 @@ pub async fn patch_settings( Ok(ApiResponse::default()) } -pub async fn test_ldap_settings(_admin: AdminRole, _license: LicenseInfo) -> ApiResult { +pub(crate) async fn test_ldap_settings(_admin: AdminRole, _license: LicenseInfo) -> ApiResult { debug!("Testing LDAP connection"); match LDAPConnection::create().await { Ok(_) => { diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 753c01040f..f8a9c8077c 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -127,6 +127,7 @@ use crate::{ add_group_member, create_group, delete_group, get_group, list_groups, modify_group, remove_group_member, }, + license::license_check, location_stats::{ location_connected_network_devices, location_connected_user_devices, location_connected_users, location_stats, locations_overview_stats, @@ -575,7 +576,8 @@ pub fn build_webapp( "/network/{location_id}/snat/{user_id}", put(modify_snat_binding).delete(delete_snat_binding), ) - .route("/outdated", get(outdated_components)), + .route("/outdated", get(outdated_components)) + .route("/license/check", post(license_check)), ); let webapp = webapp.nest( diff --git a/crates/defguard_core/src/openapi.rs b/crates/defguard_core/src/openapi.rs index 6017ec6399..5c3483856a 100644 --- a/crates/defguard_core/src/openapi.rs +++ b/crates/defguard_core/src/openapi.rs @@ -20,6 +20,7 @@ use super::{ ApiResponse, EditGroupInfo, GroupInfo, PasswordChange, PasswordChangeSelf, SESSION_COOKIE_NAME, StartEnrollmentRequest, Username, auth, group::{self, BulkAssignToGroupsRequest, Groups}, + license, user::{self, UserDetails}, wireguard as device, wireguard as network, wireguard::AddDeviceResult, @@ -71,6 +72,8 @@ use super::{ network::delete_network, network::list_networks, network::network_details, + // /license + license::license_check, // /network/{location_id}/snat snat::list_snat_bindings, snat::create_snat_binding, @@ -111,7 +114,8 @@ use super::{ schemas( ApiResponse, UserInfo, UserDetails, UserDevice, Groups, Username, StartEnrollmentRequest, PasswordChangeSelf, PasswordChange, AddDevice, AddDeviceResult, - Device, ModifyDevice, BulkAssignToGroupsRequest, GroupInfo, EditGroupInfo, WebError + Device, ModifyDevice, BulkAssignToGroupsRequest, GroupInfo, EditGroupInfo, WebError, + license::CheckParams ), ), tags( diff --git a/crates/model_derive/src/lib.rs b/crates/model_derive/src/lib.rs index c3c477278b..a40f0fe861 100644 --- a/crates/model_derive/src/lib.rs +++ b/crates/model_derive/src/lib.rs @@ -220,7 +220,7 @@ pub fn derive(input: TokenStream) -> TokenStream { // TODO: add limit and offset for all(). quote! { impl #name { - pub async fn save<'e, E>(self, executor: E) -> Result<#name, sqlx::Error> + pub async fn save<'e, E>(self, executor: E) -> sqlx::Result<#name> where E: sqlx::PgExecutor<'e> { @@ -235,21 +235,21 @@ pub fn derive(input: TokenStream) -> TokenStream { } impl #name { - pub async fn find_by_id<'e, E>(executor: E, id: Id) -> Result, sqlx::Error> + pub async fn find_by_id<'e, E>(executor: E, id: Id) -> sqlx::Result> where E: sqlx::PgExecutor<'e> { sqlx::query_as!(Self, #find_by_id_query, id).fetch_optional(executor).await } - pub async fn all<'e, E>(executor: E) -> Result, sqlx::Error> + pub async fn all<'e, E>(executor: E) -> sqlx::Result> where E: sqlx::PgExecutor<'e> { sqlx::query_as!(Self, #all_query).fetch_all(executor).await } - pub async fn delete<'e, E>(self, executor: E) -> Result<(), sqlx::Error> + pub async fn delete<'e, E>(self, executor: E) -> sqlx::Result<()> where E: sqlx::PgExecutor<'e> { @@ -258,7 +258,7 @@ pub fn derive(input: TokenStream) -> TokenStream { Ok(()) } - pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error> + pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()> where E: sqlx::PgExecutor<'e> { From b68281301a842f9778023ca36aa8678dd6ccf495 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Mon, 9 Mar 2026 13:36:47 +0100 Subject: [PATCH 2/2] Add web interface for licence check --- .../SettingsLicenseModal.tsx | 112 ++++++++++++++++-- web/src/routes/_authorized.tsx | 2 + web/src/shared/api/api.ts | 3 + web/src/shared/api/types.ts | 22 ++++ .../LicenseLimitConflictModal.tsx | 93 +++++++++++++++ .../shared/hooks/modalControls/modalTypes.ts | 6 + web/src/shared/hooks/modalControls/types.ts | 8 ++ 7 files changed, 236 insertions(+), 10 deletions(-) create mode 100644 web/src/shared/components/modals/license/LicenseLimitConflictModal/LicenseLimitConflictModal.tsx diff --git a/web/src/pages/settings/SettingsIndexPage/tabs/SettingsLicenseTab/modals/SettingsLicenseModal/SettingsLicenseModal.tsx b/web/src/pages/settings/SettingsIndexPage/tabs/SettingsLicenseTab/modals/SettingsLicenseModal/SettingsLicenseModal.tsx index af696ad018..389c38725c 100644 --- a/web/src/pages/settings/SettingsIndexPage/tabs/SettingsLicenseTab/modals/SettingsLicenseModal/SettingsLicenseModal.tsx +++ b/web/src/pages/settings/SettingsIndexPage/tabs/SettingsLicenseTab/modals/SettingsLicenseModal/SettingsLicenseModal.tsx @@ -5,7 +5,10 @@ import { useEffect, useMemo, useState } from 'react'; import z from 'zod'; import { m } from '../../../../../../../paraglide/messages'; import api from '../../../../../../../shared/api/api'; -import type { ApiError } from '../../../../../../../shared/api/types'; +import type { + ApiError, + LicenseCheckResponse, +} from '../../../../../../../shared/api/types'; import { CopyButton } from '../../../../../../../shared/components/CopyButton/CopyButton'; import { Modal } from '../../../../../../../shared/defguard-ui/components/Modal/Modal'; import { ModalControls } from '../../../../../../../shared/defguard-ui/components/ModalControls/ModalControls'; @@ -14,6 +17,7 @@ import { useAppForm } from '../../../../../../../shared/form'; import { formChangeLogic } from '../../../../../../../shared/formLogic'; import { closeModal, + openModal, subscribeCloseModal, subscribeOpenModal, } from '../../../../../../../shared/hooks/modalControls/modalsSubjects'; @@ -65,6 +69,64 @@ const formSchema = z.object({ type FormFields = z.infer; +type LicenseLimitConflict = { + label: string; + current: number; + limit: number; +}; + +const sanitizeLicense = (license: string | null | undefined) => + license?.replaceAll('\n', '').trim() ?? ''; + +const getLicenseLimitConflicts = ({ + counts, + limits, +}: LicenseCheckResponse): LicenseLimitConflict[] => { + if (!limits) { + return []; + } + + const conflicts: LicenseLimitConflict[] = []; + + if (counts.user > limits.users) { + conflicts.push({ + label: 'Users', + current: counts.user, + limit: limits.users, + }); + } + + if (counts.location > limits.locations) { + conflicts.push({ + label: 'Locations', + current: counts.location, + limit: limits.locations, + }); + } + + const currentDevices = counts.user_device + counts.network_device; + if (currentDevices > limits.devices) { + conflicts.push({ + label: 'Devices', + current: currentDevices, + limit: limits.devices, + }); + } + + if ( + isPresent(limits.network_devices) && + counts.network_device > limits.network_devices + ) { + conflicts.push({ + label: 'Network devices', + current: counts.network_device, + limit: limits.network_devices, + }); + } + + return conflicts; +}; + const ModalContent = ({ license: initialLicense }: ModalData) => { const defaultValues: FormFields = useMemo( () => ({ @@ -91,17 +153,47 @@ const ModalContent = ({ license: initialLicense }: ModalData) => { onChange: formSchema, }, onSubmit: async ({ value, formApi }) => { - await patchSettings({ - license: value.license?.replaceAll('\n', '').trim() ?? '', - }).catch((e: AxiosError) => { - if (e.status && e.status >= 400 && e.status < 500) { - formApi.setErrorMap({ - onSubmit: { - fields: { - license: m.form_error_license(), - }, + const license = sanitizeLicense(value.license); + + const setLicenseError = () => { + formApi.setErrorMap({ + onSubmit: { + fields: { + license: m.form_error_license(), }, + }, + }); + }; + + if (license.length > 0) { + const checkResult = await api + .checkLicense({ license }) + .catch((e: AxiosError) => { + const status = e.status ?? e.response?.status; + if (status && status >= 400 && status < 500) { + setLicenseError(); + } + return null; }); + + if (!checkResult) { + return; + } + + const conflicts = getLicenseLimitConflicts(checkResult.data); + if (conflicts.length > 0) { + closeModal(modalNameValue); + openModal(ModalName.LicenseLimitConflict, { conflicts }); + return; + } + } + + await patchSettings({ + license, + }).catch((e: AxiosError) => { + const status = e.status ?? e.response?.status; + if (status && status >= 400 && status < 500) { + setLicenseError(); } }); }, diff --git a/web/src/routes/_authorized.tsx b/web/src/routes/_authorized.tsx index bb2eb8c64f..5414bcfa4b 100644 --- a/web/src/routes/_authorized.tsx +++ b/web/src/routes/_authorized.tsx @@ -1,6 +1,7 @@ import { createFileRoute, Outlet, redirect } from '@tanstack/react-router'; import { DisplayListModal } from '../shared/components/DisplayListModal/DisplayListModal'; import { LicenseExpiredModal } from '../shared/components/modals/license/LicenseExpiredModal/LicenseExpiredModal'; +import { LicenseLimitConflictModal } from '../shared/components/modals/license/LicenseLimitConflictModal/LicenseLimitConflictModal'; import { LimitReachedModal } from '../shared/components/modals/license/LimitReachedModal/LimitReachedModal'; import { UpgradeBusinessModal } from '../shared/components/modals/license/UpgradeBusinessModal/UpgradeBusinessModal'; import { UpgradeEnterpriseModal } from '../shared/components/modals/license/UpgradeEnterpriseModal/UpgradeEnterpriseModal'; @@ -35,6 +36,7 @@ function RouteComponent() { + diff --git a/web/src/shared/api/api.ts b/web/src/shared/api/api.ts index e540ec9a52..eb7f013508 100644 --- a/web/src/shared/api/api.ts +++ b/web/src/shared/api/api.ts @@ -59,6 +59,7 @@ import type { GroupInfo, GroupsResponse, IpValidation, + LicenseCheckResponse, LicenseInfoResponse, LocationConnectedNetworkDevice, LocationConnectedNetworkDevicesRequest, @@ -532,6 +533,8 @@ const api = { setGeneralConfig: (data: MigrationGeneralConfigRequest) => client.post(`/migration/general_config`, data), }, + checkLicense: (data: { license: string }) => + client.post('/license/check', data), getSessionInfo: () => client.get(`/session-info`), getActivityLog: (data?: ActivityLogRequestParams) => client diff --git a/web/src/shared/api/types.ts b/web/src/shared/api/types.ts index 691af31f1b..d5e44fadd8 100644 --- a/web/src/shared/api/types.ts +++ b/web/src/shared/api/types.ts @@ -318,6 +318,28 @@ export interface LimitInfo { export interface LicenseLimitsInfo { locations: LimitInfo; users: LimitInfo; + devices?: LimitInfo | null; + user_devices?: LimitInfo | null; + network_devices?: LimitInfo | null; +} + +export interface LicenseCheckCounts { + user: number; + user_device: number; + network_device: number; + location: number; +} + +export interface LicenseCheckLimits { + users: number; + devices: number; + locations: number; + network_devices?: number | null; +} + +export interface LicenseCheckResponse { + limits: LicenseCheckLimits | null; + counts: LicenseCheckCounts; } export const LicenseTier = { diff --git a/web/src/shared/components/modals/license/LicenseLimitConflictModal/LicenseLimitConflictModal.tsx b/web/src/shared/components/modals/license/LicenseLimitConflictModal/LicenseLimitConflictModal.tsx new file mode 100644 index 0000000000..6240115194 --- /dev/null +++ b/web/src/shared/components/modals/license/LicenseLimitConflictModal/LicenseLimitConflictModal.tsx @@ -0,0 +1,93 @@ +import { useEffect, useState } from 'react'; +import { AppText } from '../../../../defguard-ui/components/AppText/AppText'; +import { Divider } from '../../../../defguard-ui/components/Divider/Divider'; +import { SizedBox } from '../../../../defguard-ui/components/SizedBox/SizedBox'; +import { TextStyle, ThemeSpacing, ThemeVariable } from '../../../../defguard-ui/types'; +import { isPresent } from '../../../../defguard-ui/utils/isPresent'; +import { + subscribeCloseModal, + subscribeOpenModal, +} from '../../../../hooks/modalControls/modalsSubjects'; +import { ModalName } from '../../../../hooks/modalControls/modalTypes'; +import type { OpenLicenseLimitConflictModal } from '../../../../hooks/modalControls/types'; +import { LicenseModal } from '../../LicenseModal/LicenseModal'; +import { LicenseModalControls } from '../LicenseModalControls'; + +const modalNameKey = ModalName.LicenseLimitConflict; + +export const LicenseLimitConflictModal = () => { + const [isOpen, setOpen] = useState(false); + const [modalData, setModalData] = useState(null); + + useEffect(() => { + const openSub = subscribeOpenModal(modalNameKey, (data) => { + setModalData(data); + setOpen(true); + }); + const closeSub = subscribeCloseModal(modalNameKey, () => setOpen(false)); + return () => { + openSub.unsubscribe(); + closeSub.unsubscribe(); + }; + }, []); + + return ( + setOpen(false)} + afterClose={() => { + setModalData(null); + }} + > + {isPresent(modalData) && } + + ); +}; + +const ModalContent = ({ conflicts }: OpenLicenseLimitConflictModal) => { + return ( + <> + {`Plan limits don’t match`} + + {`License cannot be applied`} + + {`The license you’re trying to use allows fewer resources that your current setup is using.`} + + + {`To apply this license, first reduce your usage so it fits within the license limits.`} + + + + {`You can also upgrade your plan to the one with higher limits such as:`} +
+ {`• 30 users or more`} +
+ {`• 5 locations or more`} +
+ + + {`No changes were made to your current configuration.`} + +
+ {conflicts.map((conflict) => ( + {`${conflict.label}: ${conflict.current} used, ${conflict.limit} allowed`} + ))} +
+ + + ); +}; diff --git a/web/src/shared/hooks/modalControls/modalTypes.ts b/web/src/shared/hooks/modalControls/modalTypes.ts index 8b5c92ce01..a0b972e958 100644 --- a/web/src/shared/hooks/modalControls/modalTypes.ts +++ b/web/src/shared/hooks/modalControls/modalTypes.ts @@ -21,6 +21,7 @@ import type { OpenEditUserModal, OpenEnrollmentTokenModal, OpenLicenseExpiredModal, + OpenLicenseLimitConflictModal, OpenNetworkDeviceConfigModal, OpenNetworkDeviceTokenModal, OpenRenameApiTokenModal, @@ -29,6 +30,7 @@ import type { export const ModalName = { LicenseExpired: 'licenseExpired', + LicenseLimitConflict: 'licenseLimitConflict', UpgradeBusiness: 'upgradeBusiness', UpgradeEnterprise: 'upgradeEnterprise', LimitReached: 'limitReached', @@ -199,6 +201,10 @@ const modalOpenArgsSchema = z.discriminatedUnion('name', [ z.object({ name: z.literal(ModalName.LimitReached), }), + z.object({ + name: z.literal(ModalName.LicenseLimitConflict), + data: z.custom(), + }), z.object({ name: z.literal(ModalName.UpgradeBusiness), }), diff --git a/web/src/shared/hooks/modalControls/types.ts b/web/src/shared/hooks/modalControls/types.ts index f2e63e6116..55f675a690 100644 --- a/web/src/shared/hooks/modalControls/types.ts +++ b/web/src/shared/hooks/modalControls/types.ts @@ -121,6 +121,14 @@ export interface OpenSettingsLicenseModal { license?: string | null; } +export interface OpenLicenseLimitConflictModal { + conflicts: Array<{ + label: string; + current: number; + limit: number; + }>; +} + export interface OpenLicenseExpiredModal { licenseTier: LicenseTierValue; }