Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions crates/defguard_common/src/db/models/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
global_value!(SETTINGS, Option<Settings>, None, set_settings, get_settings);

/// Initializes global `SETTINGS` struct at program startup
pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Error> {
pub async fn initialize_current_settings(pool: &PgPool) -> sqlx::Result<()> {
debug!("Initializing global settings struct");
if let Some(settings) = Settings::get(pool).await? {
set_settings(Some(settings));
Expand All @@ -38,7 +38,7 @@ pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Erro
pub async fn update_current_settings<'e, E: sqlx::PgExecutor<'e>>(
executor: E,
new_settings: Settings,
) -> Result<(), sqlx::Error> {
) -> sqlx::Result<()> {
debug!("Updating current settings to: {new_settings:?}");
new_settings.save(executor).await?;
set_settings(Some(new_settings));
Expand Down Expand Up @@ -303,7 +303,7 @@ impl Settings {
BASE64_STANDARD.encode(bytes)
}

pub async fn get<'e, E>(executor: E) -> Result<Option<Self>, sqlx::Error>
pub async fn get<'e, E>(executor: E) -> sqlx::Result<Option<Self>>
where
E: PgExecutor<'e>,
{
Expand Down Expand Up @@ -358,7 +358,7 @@ impl Settings {
Ok(())
}

pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error>
pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()>
where
E: PgExecutor<'e>,
{
Expand Down Expand Up @@ -707,7 +707,7 @@ impl Settings {
&mut self,
executor: E,
config: &DefGuardConfig,
) -> Result<(), sqlx::Error>
) -> sqlx::Result<()>
where
E: PgExecutor<'e>,
{
Expand Down Expand Up @@ -746,7 +746,7 @@ pub struct SettingsEssentials {
}

impl SettingsEssentials {
pub async fn get_settings_essentials<'e, E>(executor: E) -> Result<Self, sqlx::Error>
pub async fn get_settings_essentials<'e, E>(executor: E) -> sqlx::Result<Self>
where
E: PgExecutor<'e>,
{
Expand Down
25 changes: 11 additions & 14 deletions crates/defguard_core/src/enterprise/license.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,14 @@ pub(crate) const PUBLIC_KEY: &[u8] = include_bytes!("public_key.asc");

#[derive(Debug, Error)]
pub enum LicenseError {
#[error("Provided license is invalid: {0}")]
InvalidLicense(String),
#[error("Provided signature does not match the license")]
SignatureMismatch,
#[error("Provided signature is invalid")]
InvalidSignature,
#[error("Database error")]
DbError(#[from] SqlxError),
#[error("License decoding error: {0}")]
DecodeError(String),
DecodeError(&'static str),
#[error(
"License is expired and has reached its maximum overdue time, please contact sales<at>defguard.net"
)]
Expand Down Expand Up @@ -123,8 +121,7 @@ impl License {
fn decode(bytes: &[u8]) -> Result<Vec<u8>, LicenseError> {
let bytes = BASE64_STANDARD.decode(bytes).map_err(|_| {
LicenseError::DecodeError(
"Failed to decode the license key, check if the provided key is correct."
.to_string(),
"Failed to decode the license key, check if the provided key is correct.",
)
})?;
Ok(bytes)
Expand Down Expand Up @@ -164,7 +161,7 @@ impl License {

/// Deserialize the license object from a base64 encoded string.
/// Also verifies the signature of the license
pub fn from_base64(key: &str) -> Result<License, LicenseError> {
pub(crate) fn from_base64(key: &str) -> Result<License, LicenseError> {
debug!("Decoding the license key from a provided base64 string...");
let bytes = key.as_bytes();
let decoded = Self::decode(bytes)?;
Expand All @@ -173,7 +170,7 @@ impl License {

let license_key = LicenseKey::decode(slice).map_err(|_| {
LicenseError::DecodeError(
"The license key is malformed, check if the provided key is correct.".to_string(),
"The license key is malformed, check if the provided key is correct.",
)
})?;
let metadata_bytes: &[u8] = &license_key.metadata;
Expand All @@ -184,7 +181,7 @@ impl License {
Ok(()) => {
info!("Successfully decoded the license and validated the license signature");
let metadata = LicenseMetadata::decode(metadata_bytes).map_err(|_| {
LicenseError::DecodeError("Failed to decode the license metadata".to_string())
LicenseError::DecodeError("Failed to decode the license metadata")
})?;

let valid_until = match metadata.valid_until {
Expand All @@ -206,7 +203,7 @@ impl License {
Err(err) => {
error!("Failed to read license tier from license metadata: {err}");
return Err(LicenseError::DecodeError(
"Failed to decode license tier metadata".into(),
"Failed to decode license tier metadata",
));
}
};
Expand Down Expand Up @@ -257,7 +254,7 @@ impl License {

/// Create the license object based on the license key stored in the database.
/// Automatically decodes and deserializes the keys and verifies the signature.
pub fn load() -> Result<Option<License>, LicenseError> {
pub(crate) fn load() -> Result<Option<License>, LicenseError> {
if let Some(key) = Self::get_key() {
Ok(Some(Self::from_base64(&key)?))
} else {
Expand Down Expand Up @@ -307,7 +304,7 @@ impl License {
/// NOTE: license should be considered valid for an additional period of `MAX_OVERDUE_TIME`.
/// If you want to check if the license reached this point, use `is_max_overdue` instead.
#[must_use]
pub fn is_expired(&self) -> bool {
pub(crate) fn is_expired(&self) -> bool {
match self.valid_until {
Some(time) => time < Utc::now(),
None => false,
Expand All @@ -323,7 +320,7 @@ impl License {
/// Gets the time the license is past its expiry date.
/// If the license doesn't have a `valid_until` field, it will return 0.
#[must_use]
pub fn time_overdue(&self) -> TimeDelta {
pub(crate) fn time_overdue(&self) -> TimeDelta {
match self.valid_until {
Some(time) => {
let delta = Utc::now() - time;
Expand All @@ -339,7 +336,7 @@ impl License {

/// Checks whether we should try to renew the license.
#[must_use]
pub fn requires_renewal(&self) -> bool {
pub(crate) fn requires_renewal(&self) -> bool {
if self.subscription {
if let Some(remaining) = self.time_left() {
remaining <= RENEWAL_TIME
Expand All @@ -353,7 +350,7 @@ impl License {

/// Checks if the license has reached its maximum overdue time.
#[must_use]
pub fn is_max_overdue(&self) -> bool {
pub(crate) fn is_max_overdue(&self) -> bool {
if self.subscription {
self.time_overdue() > MAX_OVERDUE_TIME
} else {
Expand Down
4 changes: 2 additions & 2 deletions crates/defguard_core/src/enterprise/limits.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use defguard_common::global_value;
use serde::Serialize;
use sqlx::{error::Error as SqlxError, query};

use super::license::License;
#[cfg(test)]
use super::license::get_cached_license;

#[derive(Debug)]
#[cfg_attr(test, derive(Clone))]
#[derive(Clone, Debug, Serialize)]
pub struct Counts {
user: u32,
user_device: u32,
Expand Down
56 changes: 56 additions & 0 deletions crates/defguard_core/src/handlers/license.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use axum::{Json, http::StatusCode};
use utoipa::ToSchema;

use super::{ApiResponse, ApiResult};
use crate::{
enterprise::{
license::License,
limits::{Counts, get_counts},
},
grpc::proto::enterprise::license::LicenseLimits,
};

#[derive(Deserialize, ToSchema)]
pub struct CheckParams {
license: String,
}

#[derive(Serialize)]
pub struct CheckResult {
limits: Option<LicenseLimits>,
counts: Counts,
}

/// Check given license. Return [`LicenseLimits`].
#[utoipa::path(
post,
path = "/api/v1/license/check",
request_body = CheckParams,
responses(
(
status = 200,
description = "Decoded license limits.",
// TODO: uncomment when LicenseLimits and Counts implement ToSchema.
// body = CheckResult,
example = json!({
"users": 100,
"devices": 250,
"locations": 10,
"network_devices": 50
})
),
(status = 400, description = "Invalid license key.", body = ApiResponse, example = json!({"msg": "License signature doesn't match its content"})),
(status = 404, description = "License not found.", body = ApiResponse, example = json!({"msg": "License not found"}))
)
)]
pub(crate) async fn license_check(Json(params): Json<CheckParams>) -> ApiResult {
let license = License::from_base64(params.license.trim())?;

Ok(ApiResponse::json(
CheckResult {
limits: license.limits,
counts: get_counts().clone(),
},
StatusCode::OK,
))
}
3 changes: 2 additions & 1 deletion crates/defguard_core/src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub mod component_setup;
pub(crate) mod forward_auth;
pub mod gateway;
pub(crate) mod group;
pub mod license;
pub(crate) mod location_stats;
pub mod mail;
pub mod network_devices;
Expand Down Expand Up @@ -230,7 +231,7 @@ impl From<WebError> for ApiResponse {
)
}
WebError::LicenseError(err) => match err {
LicenseError::DecodeError(msg) | LicenseError::InvalidLicense(msg) => {
LicenseError::DecodeError(msg) => {
warn!(msg);
ApiResponse::new(json!({"msg": msg}), StatusCode::BAD_REQUEST)
}
Expand Down
10 changes: 5 additions & 5 deletions crates/defguard_core/src/handlers/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub async fn get_settings(_admin: AdminRole, State(appstate): State<AppState>) -
Ok(ApiResponse::default())
}

pub async fn update_settings(
pub(crate) async fn update_settings(
_admin: AdminRole,
session: SessionInfo,
context: ApiRequestContext,
Expand Down Expand Up @@ -84,7 +84,7 @@ pub async fn get_settings_essentials(Extension(pool): Extension<PgPool>) -> ApiR
Ok(ApiResponse::json(settings, StatusCode::OK))
}

pub async fn set_default_branding(
pub(crate) async fn set_default_branding(
_admin: AdminRole,
State(appstate): State<AppState>,
Path(_id): Path<Id>, // TODO: check with front-end and remove.
Expand Down Expand Up @@ -123,7 +123,7 @@ pub async fn patch_settings(
context: ApiRequestContext,
Json(data): Json<SettingsPatch>,
) -> ApiResult {
debug!("Admin {} patching settings", session.user.username);
debug!("Admin {} is patching settings", session.user.username);
let mut settings = Settings::get_current_settings();
// prepare clone for emitting an event
let before = settings.clone();
Expand Down Expand Up @@ -158,15 +158,15 @@ pub async fn patch_settings(
let after = settings.clone();
update_current_settings(&appstate.pool, settings).await?;

info!("Admin {} patched settings.", session.user.username);
info!("Admin {} patched settings", session.user.username);
appstate.emit_event(ApiEvent {
context,
event: Box::new(ApiEventType::SettingsUpdatedPartial { before, after }),
})?;
Ok(ApiResponse::default())
}

pub async fn test_ldap_settings(_admin: AdminRole, _license: LicenseInfo) -> ApiResult {
pub(crate) async fn test_ldap_settings(_admin: AdminRole, _license: LicenseInfo) -> ApiResult {
debug!("Testing LDAP connection");
match LDAPConnection::create().await {
Ok(_) => {
Expand Down
4 changes: 3 additions & 1 deletion crates/defguard_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ use crate::{
add_group_member, create_group, delete_group, get_group, list_groups, modify_group,
remove_group_member,
},
license::license_check,
location_stats::{
location_connected_network_devices, location_connected_user_devices,
location_connected_users, location_stats, locations_overview_stats,
Expand Down Expand Up @@ -575,7 +576,8 @@ pub fn build_webapp(
"/network/{location_id}/snat/{user_id}",
put(modify_snat_binding).delete(delete_snat_binding),
)
.route("/outdated", get(outdated_components)),
.route("/outdated", get(outdated_components))
.route("/license/check", post(license_check)),
);

let webapp = webapp.nest(
Expand Down
6 changes: 5 additions & 1 deletion crates/defguard_core/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use super::{
ApiResponse, EditGroupInfo, GroupInfo, PasswordChange, PasswordChangeSelf,
SESSION_COOKIE_NAME, StartEnrollmentRequest, Username, auth,
group::{self, BulkAssignToGroupsRequest, Groups},
license,
user::{self, UserDetails},
wireguard as device, wireguard as network,
wireguard::AddDeviceResult,
Expand Down Expand Up @@ -71,6 +72,8 @@ use super::{
network::delete_network,
network::list_networks,
network::network_details,
// /license
license::license_check,
// /network/{location_id}/snat
snat::list_snat_bindings,
snat::create_snat_binding,
Expand Down Expand Up @@ -111,7 +114,8 @@ use super::{
schemas(
ApiResponse, UserInfo, UserDetails, UserDevice, Groups, Username,
StartEnrollmentRequest, PasswordChangeSelf, PasswordChange, AddDevice, AddDeviceResult,
Device, ModifyDevice, BulkAssignToGroupsRequest, GroupInfo, EditGroupInfo, WebError
Device, ModifyDevice, BulkAssignToGroupsRequest, GroupInfo, EditGroupInfo, WebError,
license::CheckParams
),
),
tags(
Expand Down
10 changes: 5 additions & 5 deletions crates/model_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ pub fn derive(input: TokenStream) -> TokenStream {
// TODO: add limit and offset for all().
quote! {
impl #name<NoId> {
pub async fn save<'e, E>(self, executor: E) -> Result<#name<Id>, sqlx::Error>
pub async fn save<'e, E>(self, executor: E) -> sqlx::Result<#name<Id>>
where
E: sqlx::PgExecutor<'e>
{
Expand All @@ -235,21 +235,21 @@ pub fn derive(input: TokenStream) -> TokenStream {
}

impl #name<Id> {
pub async fn find_by_id<'e, E>(executor: E, id: Id) -> Result<Option<Self>, sqlx::Error>
pub async fn find_by_id<'e, E>(executor: E, id: Id) -> sqlx::Result<Option<Self>>
where
E: sqlx::PgExecutor<'e>
{
sqlx::query_as!(Self, #find_by_id_query, id).fetch_optional(executor).await
}

pub async fn all<'e, E>(executor: E) -> Result<Vec<Self>, sqlx::Error>
pub async fn all<'e, E>(executor: E) -> sqlx::Result<Vec<Self>>
where
E: sqlx::PgExecutor<'e>
{
sqlx::query_as!(Self, #all_query).fetch_all(executor).await
}

pub async fn delete<'e, E>(self, executor: E) -> Result<(), sqlx::Error>
pub async fn delete<'e, E>(self, executor: E) -> sqlx::Result<()>
where
E: sqlx::PgExecutor<'e>
{
Expand All @@ -258,7 +258,7 @@ pub fn derive(input: TokenStream) -> TokenStream {
Ok(())
}

pub async fn save<'e, E>(&self, executor: E) -> Result<(), sqlx::Error>
pub async fn save<'e, E>(&self, executor: E) -> sqlx::Result<()>
where
E: sqlx::PgExecutor<'e>
{
Expand Down
Loading
Loading