From 6bcb9079dde83a6602a62b8808aeb9c0efe6a61b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Thu, 26 Feb 2026 12:06:19 +0100 Subject: [PATCH 1/4] Cleanup certificate handling --- Cargo.lock | 12 +- Cargo.toml | 4 +- crates/defguard_certs/src/lib.rs | 76 ++-- crates/defguard_core/src/enterprise/limits.rs | 7 +- .../src/handlers/component_setup.rs | 347 ++++++++++-------- crates/defguard_core/src/utility_thread.rs | 141 ++++--- crates/defguard_setup/src/handlers.rs | 6 +- 7 files changed, 310 insertions(+), 283 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7fcdf032cf..c27e0c4f87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5086,9 +5086,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.52" +version = "0.8.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6a884d2998352bb4daf0183589aec883f16a6da1f4dde84d8e2e9a5409a1ce" +checksum = "47b34b781b31e5d73e9fbc8689c70551fd1ade9a19e3e28cfec8580a79290cc4" [[package]] name = "ring" @@ -6101,18 +6101,18 @@ dependencies = [ [[package]] name = "strum" -version = "0.27.2" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +checksum = "9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" dependencies = [ "strum_macros", ] [[package]] name = "strum_macros" -version = "0.27.2" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +checksum = "ab85eea0270ee17587ed4156089e10b9e6880ee688791d45a905f5b1ca36f664" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 2315ef7223..2ac0c71e00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,8 +102,8 @@ sqlx = { version = "0.8", features = [ ] } ssh-key = "0.6" struct-patch = "0.10" -strum = { version = "0.27", features = ["derive"] } -strum_macros = "0.27" +strum = { version = "0.28", features = ["derive"] } +strum_macros = "0.28" tera = "1.20" thiserror = "2.0" # match axum-extra -> cookies diff --git a/crates/defguard_certs/src/lib.rs b/crates/defguard_certs/src/lib.rs index 30abea7844..1f458ac396 100644 --- a/crates/defguard_certs/src/lib.rs +++ b/crates/defguard_certs/src/lib.rs @@ -137,7 +137,7 @@ impl CertificateAuthority<'_> { } pub fn expiry(&self) -> Result { - let CertificateInfo { not_after, .. } = parse_certificate_info(&self.cert_der)?; + let CertificateInfo { not_after, .. } = CertificateInfo::from_der(&self.cert_der)?; Ok(not_after) } } @@ -149,44 +149,48 @@ pub struct CertificateInfo { pub serial: String, } -pub fn parse_certificate_info(cert_der: &[u8]) -> Result { - let (_, parsed) = parse_x509_certificate(cert_der) - .map_err(|e| CertificateError::ParsingError(format!("Failed to parse certificate: {e}")))?; +impl CertificateInfo { + /// Parse certificate from DER-encoded bytes. + pub fn from_der(cert_der: &[u8]) -> Result { + let (_, parsed) = parse_x509_certificate(cert_der).map_err(|e| { + CertificateError::ParsingError(format!("Failed to parse certificate: {e}")) + })?; - let subject = &parsed.tbs_certificate.subject; - let serial = parsed.raw_serial_as_string(); + let subject = &parsed.tbs_certificate.subject; + let serial = parsed.raw_serial_as_string(); - let cn = subject - .iter_common_name() - .next() - .ok_or_else(|| CertificateError::ParsingError("Common Name not found".to_string()))? - .as_str() - .map_err(|e| { - CertificateError::ParsingError(format!("Failed to parse CN as string: {e}")) - })?; + let cn = subject + .iter_common_name() + .next() + .ok_or_else(|| CertificateError::ParsingError("Common Name not found".to_string()))? + .as_str() + .map_err(|e| { + CertificateError::ParsingError(format!("Failed to parse CN as string: {e}")) + })?; - let validity = &parsed.tbs_certificate.validity; - let not_before = validity.not_before.to_datetime(); - let not_after = validity.not_after.to_datetime(); - - Ok(CertificateInfo { - subject_common_name: cn.to_string(), - not_before: chrono::DateTime::from_timestamp(not_before.unix_timestamp(), 0) - .ok_or_else(|| { - CertificateError::ParsingError(format!( - "Failed to convert certificate not_before {not_before} to NaiveDateTime", - )) - })? - .naive_utc(), - not_after: chrono::DateTime::from_timestamp(not_after.unix_timestamp(), 0) - .ok_or_else(|| { - CertificateError::ParsingError(format!( - "Failed to convert certificate not_after {not_after} to NaiveDateTime", - )) - })? - .naive_utc(), - serial, - }) + let validity = &parsed.tbs_certificate.validity; + let not_before = validity.not_before.to_datetime(); + let not_after = validity.not_after.to_datetime(); + + Ok(Self { + subject_common_name: cn.to_string(), + not_before: chrono::DateTime::from_timestamp(not_before.unix_timestamp(), 0) + .ok_or_else(|| { + CertificateError::ParsingError(format!( + "Failed to convert certificate not_before {not_before} to NaiveDateTime", + )) + })? + .naive_utc(), + not_after: chrono::DateTime::from_timestamp(not_after.unix_timestamp(), 0) + .ok_or_else(|| { + CertificateError::ParsingError(format!( + "Failed to convert certificate not_after {not_after} to NaiveDateTime", + )) + })? + .naive_utc(), + serial, + }) + } } pub struct Csr<'a> { diff --git a/crates/defguard_core/src/enterprise/limits.rs b/crates/defguard_core/src/enterprise/limits.rs index c0c47c8592..f95628113e 100644 --- a/crates/defguard_core/src/enterprise/limits.rs +++ b/crates/defguard_core/src/enterprise/limits.rs @@ -1,5 +1,5 @@ use defguard_common::global_value; -use sqlx::{PgPool, error::Error as SqlxError, query}; +use sqlx::{error::Error as SqlxError, query}; use super::license::License; #[cfg(test)] @@ -60,11 +60,6 @@ pub async fn update_counts<'e, E: sqlx::PgExecutor<'e>>(executor: E) -> Result<( Ok(()) } -pub async fn do_count_update(pool: &PgPool) -> Result<(), SqlxError> { - update_counts(pool).await?; - Ok(()) -} - impl Counts { pub(crate) const fn default() -> Self { Self { diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index c1c46b8f51..6a6afb7211 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -5,7 +5,7 @@ use axum::{ extract::{Path, Query}, response::sse::{Event, KeepAlive, Sse}, }; -use defguard_certs::{der_to_pem, parse_certificate_info}; +use defguard_certs::der_to_pem; use defguard_common::{ VERSION, auth::claims::Claims, @@ -85,7 +85,8 @@ pub enum SetupStep { pub struct SetupResponse { #[serde(flatten)] pub step: SetupStep, - pub proxy_version: Option, + /// Gateway or Edge version. + pub version: Option, pub message: Option, pub logs: Option>, pub error: bool, @@ -114,14 +115,15 @@ impl Interceptor for AuthInterceptor { fn fallback_message(err: &str, last_step: SetupStep) -> String { format!( - r#"{{"step":"{last_step:?}","message":"Failed to serialize error response: {err}","error":true}}"#, + r#"{{"step":"{last_step:?}","message":"Failed to serialize error response: \ + {err}","error":true}}"#, ) } fn error_message(message: &str, last_step: SetupStep, logs: Option>) -> Event { let response = SetupResponse { step: last_step, - proxy_version: None, + version: None, message: Some(message.to_string()), logs, error: true, @@ -136,7 +138,7 @@ fn error_message(message: &str, last_step: SetupStep, logs: Option>) fn set_step_message(next_step: SetupStep) -> Event { let response = SetupResponse { step: next_step, - proxy_version: None, + version: None, message: None, logs: None, error: false, @@ -200,10 +202,13 @@ pub async fn setup_proxy_tls_stream( match Proxy::list(&pool).await { Ok(current_proxies) => { if !current_proxies.is_empty() { - yield Ok(flow.error("Enterprise license is required for connecting more then one edge.")); + yield Ok(flow.error( + "Enterprise license is required for connecting more \ + then one Edge.", + )); return; } - }, + } Err(e) => { yield Ok(flow.error(&format!("Failed to query existing proxies: {e}"))); return; @@ -212,21 +217,32 @@ pub async fn setup_proxy_tls_stream( } // Step 1: Check configuration - yield Ok( - flow.step(SetupStep::CheckingConfiguration) - ); + yield Ok(flow.step(SetupStep::CheckingConfiguration)); - match Proxy::find_by_address_port(&pool, &request.ip_or_domain, i32::from(request.grpc_port)).await { + match Proxy::find_by_address_port( + &pool, + &request.ip_or_domain, + i32::from(request.grpc_port), + ) + .await + { Ok(Some(proxy)) => { - yield Ok(flow.error(&format!("An edge Proxy with address {}:{} is already registered with name \"{}\".", request.ip_or_domain, request.grpc_port, proxy.name))); - return; + yield Ok(flow.error(&format!( + "An edge Proxy with address {}:{} is already \ + registered with name \"{}\".", + request.ip_or_domain, request.grpc_port, proxy.name + ))); + return; } Ok(None) => { - debug!("Verified no existing proxy registration for {}:{}", request.ip_or_domain, request.grpc_port); - }, + debug!( + "Verified no existing proxy registration for {}:{}", + request.ip_or_domain, request.grpc_port + ); + } Err(e) => { - yield Ok(flow.error(&format!("Failed to query existing proxy: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to query existing proxy: {e}"))); + return; } } @@ -240,7 +256,7 @@ pub async fn setup_proxy_tls_stream( } }; - debug!("Successfully validated proxy address: {}", url_str); + debug!("Successfully validated Edge address: {url_str}",); let endpoint = match Endpoint::from_shared(url_str) { Ok(e) => e, @@ -282,7 +298,10 @@ pub async fn setup_proxy_tls_stream( } }; - debug!("Prepared secure connection endpoint for proxy at {}:{}", request.ip_or_domain, request.grpc_port); + debug!( + "Prepared secure connection endpoint for Edge at {}:{}", + request.ip_or_domain, request.grpc_port + ); let version = match Version::parse(VERSION) { Ok(v) => v, @@ -293,10 +312,7 @@ pub async fn setup_proxy_tls_stream( }; // Step 2: Check availability - yield Ok( - flow.step(SetupStep::CheckingAvailability) - ); - + yield Ok(flow.step(SetupStep::CheckingAvailability)); let version_clone = version.clone(); @@ -315,7 +331,7 @@ pub async fn setup_proxy_tls_stream( } }; - debug!("Generated secure setup token for proxy authentication"); + debug!("Generated secure setup token for Edge authentication"); let version_interceptor = ClientVersionInterceptor::new(version); let auth_interceptor = AuthInterceptor::new(token); @@ -323,55 +339,59 @@ pub async fn setup_proxy_tls_stream( let mut client = ProxySetupClient::with_interceptor( endpoint.connect_lazy(), move |mut req: Request<()>| { - req = version_interceptor.clone().call(req)?; - auth_interceptor.clone().call(req) - } + req = version_interceptor.clone().call(req)?; + auth_interceptor.clone().call(req) + }, ); - debug!("Initiating connection to edge proxy at {}:{}", request.ip_or_domain, request.grpc_port); + debug!( + "Initiating connection to Edge at {}:{}", + request.ip_or_domain, request.grpc_port + ); - let response_with_metadata = match tokio::time::timeout( - CONNECTION_TIMEOUT, - client.start(()) - ).await { - Ok(Ok(r)) => r, - Ok(Err(e)) => { - match e.code() { - tonic::Code::Unavailable => { - let error_msg = e.to_string(); - if error_msg.contains("h2 protocol error") || error_msg.contains("http2 error") { - yield Ok(flow.error(&format!( - "Failed to connect to edge proxy at {}:{}: {}. This may indicate that the proxy is already configured with TLS. Please check if the proxy has already been set up.", + let response_with_metadata = + match tokio::time::timeout(CONNECTION_TIMEOUT, client.start(())).await { + Ok(Ok(r)) => r, + Ok(Err(e)) => { + match e.code() { + tonic::Code::Unavailable => { + let error_msg = e.to_string(); + if error_msg.contains("h2 protocol error") + || error_msg.contains("http2 error") + { + yield Ok(flow.error(&format!( + "Failed to connect to Edge at {}:{}: {}. This may indicate that \ + the Edge is already configured with TLS. Please check if the Edge \ + has already been set up.", request.ip_or_domain, request.grpc_port, e ))); - } else { - yield Ok(flow.error(&format!( - "Failed to connect to edge proxy at {}:{}. Please ensure the address and port are correct and that the edge component is running.", + } else { + yield Ok(flow.error(&format!( + "Failed to connect to Edge at {}:{}. Please ensure the address \ + and port are correct and that the Edge component is running.", request.ip_or_domain, request.grpc_port ))); + } + } + _ => { + yield Ok(flow.error(&format!("Failed to connect to Edge: {e}"))); } } - _ => { - yield Ok(flow.error(&format!("Failed to connect to edge proxy: {e}"))); - } + return; } - return; - } - Err(_) => { - yield Ok(flow.error(&format!( - "Connection to edge proxy at {}:{} timed out after 10 seconds.", - request.ip_or_domain, request.grpc_port - ))); - return; - } - }; + Err(_) => { + yield Ok(flow.error(&format!( + "Connection to Edge at {}:{} timed out after 10 seconds.", + request.ip_or_domain, request.grpc_port + ))); + return; + } + }; - debug!("Successfully connected to edge proxy"); + debug!("Successfully connected to Edge"); // Step 3: Check version - yield Ok( - flow.step(SetupStep::CheckingVersion) - ); + yield Ok(flow.step(SetupStep::CheckingVersion)); let proxy_version = response_with_metadata .metadata() @@ -382,21 +402,25 @@ pub async fn setup_proxy_tls_stream( .unwrap_or(None); debug!("Proxy metadata: {:?}", response_with_metadata.metadata()); - debug!("Proxy version: {:?}", proxy_version); + debug!("Proxy version: {proxy_version:?}"); if let Some(proxy_version) = proxy_version { if proxy_version < MIN_PROXY_VERSION { yield Ok(flow.error(&format!( - "Edge proxy version {proxy_version} is older than core version {version_clone}. Please update the edge component.", + "Edge version {proxy_version} is older than core version \ + {version_clone}. Please update the edge component.", ))); return; } - debug!("Edge proxy version {} is compatible with core version {}", proxy_version, version_clone); + debug!( + "Edge version {} is compatible with core version {}", + proxy_version, version_clone + ); let response = SetupResponse { step: SetupStep::CheckingVersion, - proxy_version: Some(proxy_version.to_string()), + version: Some(proxy_version.to_string()), message: None, logs: None, error: false, @@ -404,17 +428,15 @@ pub async fn setup_proxy_tls_stream( match serde_json::to_string(&response) { Ok(body) => { - yield Ok( - Event::default().data(body) - ); - }, + yield Ok(Event::default().data(body)); + } Err(e) => { yield Ok(flow.error(&format!("Failed to serialize version response: {e}"))); return; } } } else { - yield Ok(flow.error("Failed to determine edge proxy version")); + yield Ok(flow.error("Failed to determine Edge version")); return; } @@ -423,26 +445,23 @@ pub async fn setup_proxy_tls_stream( let log_reader_task = tokio::spawn(async move { while let Some(log_entry) = response.next().await { match log_entry { - Ok(entry) => { - let level = entry.level - .strip_prefix("Level(") - .and_then(|s| s.strip_suffix(")")) - .unwrap_or(&entry.level) - .to_uppercase(); - - - let formatted = format!( - "{} {} {}: message={}", - entry.timestamp, - level, - entry.target, - entry.message - ); - if log_tx.send(formatted).is_err() { - break; + Ok(entry) => { + let level = entry + .level + .strip_prefix("Level(") + .and_then(|s| s.strip_suffix(")")) + .unwrap_or(&entry.level) + .to_uppercase(); + + let formatted = format!( + "{} {} {}: message={}", + entry.timestamp, level, entry.target, entry.message + ); + if log_tx.send(formatted).is_err() { + break; + } } - } - Err(e) => { + Err(e) => { let _ = log_tx.send(format!("Error reading log: {e}")); break; } @@ -451,7 +470,7 @@ pub async fn setup_proxy_tls_stream( }); // Create guard to ensure task is aborted on all exit paths - let _log_task_guard = TaskGuard(log_reader_task); + let _ = TaskGuard(log_reader_task); // Step 4: Obtain CSR yield Ok(flow.step(SetupStep::ObtainingCsr)); @@ -463,26 +482,28 @@ pub async fn setup_proxy_tls_stream( let csr_response = match client .get_csr(CertificateInfo { - cert_hostname: hostname.to_string(), + cert_hostname: hostname.to_string(), }) .await { Ok(r) => r.into_inner(), Err(e) => { - yield Ok(flow.error(&format!("Failed to obtain CSR: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to obtain CSR: {e}"))); + return; } }; let csr = match defguard_certs::Csr::from_der(&csr_response.der_data) { Ok(c) => c, Err(e) => { - yield Ok(flow.error(&format!("Failed to parse CSR: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to parse CSR: {e}"))); + return; } }; - debug!("Received certificate signing request from edge proxy for hostname: {}", hostname); + debug!( + "Received certificate signing request from Edge for hostname: {hostname}" + ); // Step 5: Sign certificate yield Ok(flow.step(SetupStep::SigningCertificate)); @@ -505,8 +526,8 @@ pub async fn setup_proxy_tls_stream( ) { Ok(c) => c, Err(e) => { - yield Ok(flow.error(&format!("Failed to create CA: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to create CA: {e}"))); + return; } }; @@ -515,12 +536,12 @@ pub async fn setup_proxy_tls_stream( let cert = match ca.sign_csr(&csr) { Ok(c) => c, Err(e) => { - yield Ok(flow.error(&format!("Failed to sign CSR: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to sign CSR: {e}"))); + return; } }; - debug!("Successfully signed certificate for edge proxy"); + debug!("Successfully signed certificate for Edge"); // Step 6: Configure TLS yield Ok(flow.step(SetupStep::ConfiguringTls)); @@ -534,23 +555,21 @@ pub async fn setup_proxy_tls_stream( return; } - debug!("Certificate successfully delivered to edge proxy"); + debug!("Certificate successfully delivered to Edge"); let defguard_certs::CertificateInfo { not_after: expiry, serial, .. - } = match parse_certificate_info(cert.der()) { - Ok(dt) => { - dt - }, + } = match defguard_certs::CertificateInfo::from_der(cert.der()) { + Ok(dt) => dt, Err(err) => { - yield Ok(flow.error(&format!("Failed to get certificate expiry: {err}"))); - return; + yield Ok(flow.error(&format!("Failed to get certificate expiry: {err}"))); + return; } }; - debug!("Certificate expiry date determined: {}", expiry); + debug!("Certificate expiry date determined: {expiry}"); let mut proxy = Proxy::new( &request.common_name, @@ -562,27 +581,34 @@ pub async fn setup_proxy_tls_stream( proxy.certificate = Some(serial); proxy.certificate_expiry = Some(expiry); - let proxy = match proxy.save(&pool).await { Ok(p) => p, Err(err) => { - yield Ok(flow.error(&format!("Failed to save proxy to database: {err}"))); - return; + yield Ok(flow.error(&format!("Failed to save Edge to database: {err}"))); + return; } }; - debug!("Edge proxy '{}' registered successfully with ID: {}", request.common_name, proxy.id); - debug!("Establishing connection to newly configured edge proxy"); + debug!( + "Edge '{}' registered successfully with ID: {}", + request.common_name, proxy.id + ); + debug!("Establishing connection to newly configured Edge"); if let Some(proxy_control_tx) = proxy_control_tx { - if let Err(err) = proxy_control_tx.send(ProxyControlMessage::StartConnection(proxy.id)).await { - yield Ok(flow.error(&format!("Failed send message to connect to proxy after setup: {err}"))); + if let Err(err) = proxy_control_tx + .send(ProxyControlMessage::StartConnection(proxy.id)) + .await + { + yield Ok(flow.error(&format!( + "Failed send message to connect to Edge after setup: {err}" + ))); return; } } else { - debug!("Proxy control channel not available; skipping connection initiation"); + debug!("Edge control channel not available; skipping connection initiation"); } - debug!("Edge proxy setup completed successfully"); + debug!("Edge setup completed successfully"); let mut settings = Settings::get_current_settings(); if !settings.initial_setup_completed { @@ -641,15 +667,17 @@ pub async fn setup_gateway_tls_stream( match Gateway::find_by_url(&pool, &request.ip_or_domain, request.grpc_port).await { Ok(Some(gateway)) => { - yield Ok(flow.error(&format!("A Gateway with URL {}:{} is already registered with name \"{}\".", request.ip_or_domain, request.grpc_port, gateway.name))); + yield Ok(flow.error(&format!("A Gateway with URL {}:{} is already registered with \ + name \"{}\".", request.ip_or_domain, request.grpc_port, gateway.name))); return; } Ok(None) => { - debug!("Verified no existing Gateway registration for {}:{}", request.ip_or_domain, request.grpc_port); + debug!("Verified no existing Gateway registration for {}:{}", request.ip_or_domain, + request.grpc_port); }, Err(e) => { - yield Ok(flow.error(&format!("Failed to query existing Gateway: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to query existing Gateway: {e}"))); + return; } } @@ -704,7 +732,8 @@ pub async fn setup_gateway_tls_stream( } }; - debug!("Prepared secure connection endpoint for Gateway at {}:{}", request.ip_or_domain, request.grpc_port); + debug!("Prepared secure connection endpoint for Gateway at {}:{}", request.ip_or_domain, + request.grpc_port); let version = match Version::parse(VERSION) { Ok(v) => v, @@ -749,7 +778,8 @@ pub async fn setup_gateway_tls_stream( } ); - debug!("Initiating connection to edge Gateway at {}:{}", request.ip_or_domain, request.grpc_port); + debug!("Initiating connection to edge Gateway at {}:{}", request.ip_or_domain, + request.grpc_port); let response_with_metadata = match tokio::time::timeout( CONNECTION_TIMEOUT, @@ -762,12 +792,15 @@ pub async fn setup_gateway_tls_stream( let error_msg = e.to_string(); if error_msg.contains("h2 protocol error") || error_msg.contains("http2 error") { yield Ok(flow.error(&format!( - "Failed to connect to Gateway at {}:{}: {}. This may indicate that the Gateway is already configured with TLS. Please check if the Gateway has already been set up.", - request.ip_or_domain, request.grpc_port, e + "Failed to connect to Gateway at {}:{}: {e}. This may indicate that \ + the Gateway is already configured with TLS. Please check if the \ + Gateway has already been set up.", + request.ip_or_domain, request.grpc_port, ))); } else { yield Ok(flow.error(&format!( - "Failed to connect to Gateway at {}:{}. Please ensure the address and port are correct and that the Gateway is running.", + "Failed to connect to Gateway at {}:{}. Please ensure the address and \ + port are correct and that the Gateway is running.", request.ip_or_domain, request.grpc_port ))); } @@ -794,7 +827,7 @@ pub async fn setup_gateway_tls_stream( flow.step(SetupStep::CheckingVersion) ); - let proxy_version = response_with_metadata + let gateway_version = response_with_metadata .metadata() .get(defguard_version::VERSION_HEADER) .and_then(|v| v.to_str().ok()) @@ -802,22 +835,24 @@ pub async fn setup_gateway_tls_stream( .transpose() .unwrap_or(None); - debug!("Proxy metadata: {:?}", response_with_metadata.metadata()); - debug!("Proxy version: {:?}", proxy_version); + debug!("Gateway metadata: {:?}", response_with_metadata.metadata()); + debug!("Gateway version: {gateway_version:?}"); - if let Some(proxy_version) = proxy_version { - if proxy_version < MIN_GATEWAY_VERSION { + if let Some(gateway_version) = gateway_version { + if gateway_version < MIN_GATEWAY_VERSION { yield Ok(flow.error(&format!( - "Gateway version {proxy_version} is older than core version {version_clone}. Please update the edge component.", + "Gateway version {gateway_version} is older than core version {version_clone}. \ + Please update the Edge component.", ))); return; } - debug!("Gateway version {} is compatible with core version {}", proxy_version, version_clone); + debug!("Gateway version {gateway_version} is compatible with Core version \ + {version_clone}"); let response = SetupResponse { step: SetupStep::CheckingVersion, - proxy_version: Some(proxy_version.to_string()), + version: Some(gateway_version.to_string()), message: None, logs: None, error: false, @@ -844,26 +879,24 @@ pub async fn setup_gateway_tls_stream( let log_reader_task = tokio::spawn(async move { while let Some(log_entry) = response.next().await { match log_entry { - Ok(entry) => { - let level = entry.level - .strip_prefix("Level(") - .and_then(|s| s.strip_suffix(")")) - .unwrap_or(&entry.level) - .to_uppercase(); - - - let formatted = format!( - "{} {} {}: message={}", - entry.timestamp, - level, - entry.target, - entry.message - ); - if log_tx.send(formatted).is_err() { - break; + Ok(entry) => { + let level = entry.level + .strip_prefix("Level(") + .and_then(|s| s.strip_suffix(")")) + .unwrap_or(&entry.level) + .to_uppercase(); + + let formatted = format!( + "{} {level} {}: message={}", + entry.timestamp, + entry.target, + entry.message + ); + if log_tx.send(formatted).is_err() { + break; + } } - } - Err(e) => { + Err(e) => { let _ = log_tx.send(format!("Error reading log: {e}")); break; } @@ -872,7 +905,7 @@ pub async fn setup_gateway_tls_stream( }); // Create guard to ensure task is aborted on all exit paths - let _log_task_guard = TaskGuard(log_reader_task); + let _ = TaskGuard(log_reader_task); // Step 4: Obtain CSR yield Ok(flow.step(SetupStep::ObtainingCsr)); @@ -903,7 +936,7 @@ pub async fn setup_gateway_tls_stream( } }; - debug!("Received certificate signing request from Gateway for hostname: {}", hostname); + debug!("Received certificate signing request from Gateway for hostname: {hostname}"); // Step 5: Sign certificate yield Ok(flow.step(SetupStep::SigningCertificate)); @@ -961,7 +994,7 @@ pub async fn setup_gateway_tls_stream( not_after: expiry, serial, .. - } = match parse_certificate_info(cert.der()) { + } = match defguard_certs::CertificateInfo::from_der(cert.der()) { Ok(dt) => { dt }, @@ -971,7 +1004,7 @@ pub async fn setup_gateway_tls_stream( } }; - debug!("Certificate expiry date determined: {}", expiry); + debug!("Certificate expiry date determined: {expiry}"); let mut gateway = Gateway::new( network_id, diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index cf071fe5da..9bc8b4af37 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -18,7 +18,7 @@ use crate::{ firewall::try_get_location_firewall_config, is_business_license_active, ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, - limits::do_count_update, + limits::update_counts, }, grpc::GatewayEvent, location_management::allowed_peers::get_location_allowed_peers, @@ -26,7 +26,7 @@ use crate::{ }; // Times in seconds -const UTILITY_THREAD_MAIN_SLEEP_TIME: u64 = 5; +const UTILITY_THREAD_MAIN_SLEEP_TIME: Duration = Duration::from_secs(5); const COUNT_UPDATE_INTERVAL: u64 = 60 * 60; const UPDATES_CHECK_INTERVAL: u64 = 60 * 60 * 6; const EXPIRED_ACL_RULES_CHECK_INTERVAL: u64 = 60 * 5; @@ -58,7 +58,7 @@ pub async fn run_utility_thread( }; let count_update_task = || async { - if let Err(e) = do_count_update(pool) + if let Err(e) = update_counts(pool) .instrument(info_span!("count_update_task")) .await { @@ -100,7 +100,7 @@ pub async fn run_utility_thread( expired_acl_rules_task().await; loop { - sleep(Duration::from_secs(UTILITY_THREAD_MAIN_SLEEP_TIME)).await; + sleep(UTILITY_THREAD_MAIN_SLEEP_TIME).await; // Count update job for updating device/user/network counts if last_count_update.elapsed().as_secs() >= COUNT_UPDATE_INTERVAL { @@ -135,14 +135,17 @@ pub async fn run_utility_thread( // Check if enterprise features got enabled or disabled if last_enterprise_status_check.elapsed().as_secs() >= ENTERPRISE_STATUS_CHECK_INTERVAL { let new_enterprise_enabled = is_business_license_active(); - if let Err(err) = enterprise_status_check( - pool, - wireguard_tx.clone(), - enterprise_enabled, - new_enterprise_enabled, - ) - .instrument(info_span!("enterprise_status_check")) - .await + if new_enterprise_enabled == enterprise_enabled { + continue; + } + debug!( + "Enterprise feature status changed from {enterprise_enabled} to \ + {new_enterprise_enabled}" + ); + if let Err(err) = + enterprise_status_check(pool, wireguard_tx.clone(), new_enterprise_enabled) + .instrument(info_span!("enterprise_status_check")) + .await { error!("Failed to check enterprise status: {err}"); } else { @@ -158,70 +161,61 @@ pub async fn run_utility_thread( async fn enterprise_status_check( pool: &PgPool, wireguard_tx: Sender, - current_enterprise_enabled: bool, - new_enterprise_enabled: bool, + enable_enterprise: bool, ) -> Result<(), anyhow::Error> { - if new_enterprise_enabled != current_enterprise_enabled { - debug!( - "Enterprise feature status changed from {current_enterprise_enabled} to \ - {new_enterprise_enabled}" - ); - - // fetch all ACL-enabled networks - let locations: Vec> = WireguardNetwork::all(pool) - .await? - .into_iter() - .filter(|location| location.acl_enabled) - .collect(); + // fetch all ACL-enabled networks + let locations = WireguardNetwork::all(pool) + .await? + .into_iter() + .filter(|location| location.acl_enabled) + .collect::>(); - if new_enterprise_enabled { - // handle switch from disabled -> enabled - debug!("Re-enabling gateway firewall configuration for ACL-enabled locations"); - let mut transaction = pool.begin().await?; - for location in locations { - debug!("Re-enabling gateway firewall configuration for location {location:?}"); - let firewall_config = try_get_location_firewall_config(&location, &mut transaction) - .await? - .expect("ACL-enabled location must have firewall config"); + if enable_enterprise { + // handle switch from disabled -> enabled + debug!("Re-enabling gateway firewall configuration for ACL-enabled locations"); + let mut transaction = pool.begin().await?; + for location in locations { + debug!("Re-enabling gateway firewall configuration for location {location:?}"); + let firewall_config = try_get_location_firewall_config(&location, &mut transaction) + .await? + .expect("ACL-enabled location must have firewall config"); - // Handle service location update or just update the firewall - if location.service_location_mode == ServiceLocationMode::Disabled { - wireguard_tx.send(GatewayEvent::FirewallConfigChanged( - location.id, - firewall_config, - ))?; - } else { - let new_peers = - get_location_allowed_peers(&location, &mut *transaction).await?; - wireguard_tx.send(GatewayEvent::NetworkModified( - location.id, - location, - new_peers, - Some(firewall_config), - ))?; - } + // Handle service location update or just update the firewall + if location.service_location_mode == ServiceLocationMode::Disabled { + wireguard_tx.send(GatewayEvent::FirewallConfigChanged( + location.id, + firewall_config, + ))?; + } else { + let new_peers = get_location_allowed_peers(&location, &mut *transaction).await?; + wireguard_tx.send(GatewayEvent::NetworkModified( + location.id, + location, + new_peers, + Some(firewall_config), + ))?; } - transaction.commit().await?; - } else { - // handle switch from enabled -> disabled - debug!("Disabling gateway firewall configuration for ACL-enabled locations"); - for location in locations { - if location.service_location_mode == ServiceLocationMode::Disabled { - debug!("Disabling gateway firewall configuration for location {location:?}"); - wireguard_tx.send(GatewayEvent::FirewallDisabled(location.id))?; - } else { - debug!( - "Disabling gateway firewall configuration and service location client \ - connections for location {location}" - ); - wireguard_tx.send(GatewayEvent::NetworkModified( - location.id, - location, - // Send empty peer list, we are disabling the service location - Vec::new(), - None, - ))?; - } + } + transaction.commit().await?; + } else { + // handle switch from enabled -> disabled + debug!("Disabling gateway firewall configuration for ACL-enabled locations"); + for location in locations { + if location.service_location_mode == ServiceLocationMode::Disabled { + debug!("Disabling gateway firewall configuration for location {location:?}"); + wireguard_tx.send(GatewayEvent::FirewallDisabled(location.id))?; + } else { + debug!( + "Disabling gateway firewall configuration and service location client \ + connections for location {location}" + ); + wireguard_tx.send(GatewayEvent::NetworkModified( + location.id, + location, + // Send empty peer list, we are disabling the service location + Vec::new(), + None, + ))?; } } } @@ -240,7 +234,8 @@ async fn expired_acl_rules_check( "UPDATE aclrule SET state = 'expired'::aclrule_state \ WHERE state = 'applied'::aclrule_state AND expires < NOW() \ RETURNING id, parent_id, state AS \"state: RuleState\", name, allow_all_users, \ - deny_all_users, allow_all_groups, deny_all_groups, allow_all_network_devices, deny_all_network_devices, all_locations, \ + deny_all_users, allow_all_groups, deny_all_groups, allow_all_network_devices, \ + deny_all_network_devices, all_locations, \ addresses, ports, protocols, enabled, expires, any_address, any_port, \ any_protocol, use_manual_destination_settings" ) diff --git a/crates/defguard_setup/src/handlers.rs b/crates/defguard_setup/src/handlers.rs index 4b7b952dc9..7bdd883cf5 100644 --- a/crates/defguard_setup/src/handlers.rs +++ b/crates/defguard_setup/src/handlers.rs @@ -10,7 +10,7 @@ use axum_extra::{ }, headers::UserAgent, }; -use defguard_certs::{der_to_pem, parse_certificate_info, parse_pem_certificate}; +use defguard_certs::{CertificateInfo, der_to_pem, parse_pem_certificate}; use defguard_common::db::models::{ Session, SessionState, Settings, User, group::Group, @@ -321,7 +321,7 @@ pub async fn get_ca(_: AdminOrSetupRole, Extension(pool): Extension) -> let settings = Settings::get_current_settings(); if let Some(ca_cert_der) = settings.ca_cert_der { let ca_pem = der_to_pem(&ca_cert_der, defguard_certs::PemLabel::Certificate)?; - let info = parse_certificate_info(&ca_cert_der)?; + let info = CertificateInfo::from_der(&ca_cert_der)?; let valid_for_days = (info.not_after.and_utc() - chrono::Utc::now()).num_days(); debug!( @@ -354,7 +354,7 @@ pub async fn upload_ca( ) -> ApiResult { info!("Uploading existing certificate authority"); let cert_der = parse_pem_certificate(&ca_info.cert_file)?; - let expiry = parse_certificate_info(&cert_der)?.not_after; + let expiry = CertificateInfo::from_der(&cert_der)?.not_after; let mut settings = Settings::get_current_settings(); settings.ca_cert_der = Some(cert_der.to_vec()); From 12f3c134f6d814779e8c3ad6d21eeb3c5f999f7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Thu, 26 Feb 2026 12:09:04 +0100 Subject: [PATCH 2/4] Make clippy happy --- .../src/db/models/vpn_session_stats.rs | 1 + crates/defguard_common/src/utils.rs | 1 + crates/defguard_proxy_manager/src/handler.rs | 14 +++++++------- crates/defguard_session_manager/src/lib.rs | 1 + 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/crates/defguard_common/src/db/models/vpn_session_stats.rs b/crates/defguard_common/src/db/models/vpn_session_stats.rs index 794fe3f426..3225780393 100644 --- a/crates/defguard_common/src/db/models/vpn_session_stats.rs +++ b/crates/defguard_common/src/db/models/vpn_session_stats.rs @@ -86,6 +86,7 @@ impl VpnSessionStats { } } +#[must_use] pub fn endpoint_without_port(endpoint: &str) -> Option { // Remove port part let mut addr = endpoint.rsplit_once(':')?.0; diff --git a/crates/defguard_common/src/utils.rs b/crates/defguard_common/src/utils.rs index ae5965921a..73c9e0ff6e 100644 --- a/crates/defguard_common/src/utils.rs +++ b/crates/defguard_common/src/utils.rs @@ -44,6 +44,7 @@ pub struct SplitIp { /// If they are not equal, we found the first modifiable segment (one of the segments of an address that may change between hosts in the same network), /// append the rest of the segments to the modifiable part. /// 3. Join the segments with the delimiter and return the network part, modifiable part and the network prefix +#[must_use] pub fn split_ip(ip: &IpAddr, network: &IpNetwork) -> SplitIp { let network_addr = network.network(); let network_prefix = network.prefix(); diff --git a/crates/defguard_proxy_manager/src/handler.rs b/crates/defguard_proxy_manager/src/handler.rs index f16fb182c6..d8c045c675 100644 --- a/crates/defguard_proxy_manager/src/handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -619,13 +619,7 @@ impl ProxyHandler { } } Some(core_request::Payload::AuthInfo(request)) => { - if !is_business_license_active() { - warn!("Enterprise license required"); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::FailedPrecondition as i32, - message: "no valid license".into(), - })) - } else { + if is_business_license_active() { let redirect_url = match request.auth_flow_type() { ProtoAuthFlowType::Enrollment => { let settings = Settings::get_current_settings(); @@ -703,6 +697,12 @@ impl ProxyHandler { message: "invalid redirect URL".into(), })) } + } else { + warn!("Enterprise license required"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::FailedPrecondition as i32, + message: "no valid license".into(), + })) } } Some(core_request::Payload::AuthCallback(request)) => { diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 48cce5a1d7..9ee8e4a463 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -109,6 +109,7 @@ pub struct SessionManager { } impl SessionManager { + #[must_use] pub fn new( pool: PgPool, session_manager_event_tx: UnboundedSender, From b3ff47601a945effe518bd8e78e83f3c7d3e98a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Thu, 26 Feb 2026 13:26:05 +0100 Subject: [PATCH 3/4] More refinements --- .../src/handlers/component_setup.rs | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index 6a6afb7211..cc90dd2b5e 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -115,8 +115,7 @@ impl Interceptor for AuthInterceptor { fn fallback_message(err: &str, last_step: SetupStep) -> String { format!( - r#"{{"step":"{last_step:?}","message":"Failed to serialize error response: \ - {err}","error":true}}"#, + r#"{{"step":"{last_step:?}","message":"Failed to serialize error response: {err}","error":true}}"#, ) } @@ -401,20 +400,20 @@ pub async fn setup_proxy_tls_stream( .transpose() .unwrap_or(None); - debug!("Proxy metadata: {:?}", response_with_metadata.metadata()); - debug!("Proxy version: {proxy_version:?}"); + debug!("Edge metadata: {:?}", response_with_metadata.metadata()); + debug!("Edge version: {proxy_version:?}"); if let Some(proxy_version) = proxy_version { if proxy_version < MIN_PROXY_VERSION { yield Ok(flow.error(&format!( - "Edge version {proxy_version} is older than core version \ - {version_clone}. Please update the edge component.", + "Edge version {proxy_version} is older than Core version \ + {version_clone}. Please update the Edge component.", ))); return; } debug!( - "Edge version {} is compatible with core version {}", + "Edge version {} is compatible with Core version {}", proxy_version, version_clone ); @@ -470,7 +469,7 @@ pub async fn setup_proxy_tls_stream( }); // Create guard to ensure task is aborted on all exit paths - let _ = TaskGuard(log_reader_task); + let _log_task_guard = TaskGuard(log_reader_task); // Step 4: Obtain CSR yield Ok(flow.step(SetupStep::ObtainingCsr)); @@ -778,7 +777,7 @@ pub async fn setup_gateway_tls_stream( } ); - debug!("Initiating connection to edge Gateway at {}:{}", request.ip_or_domain, + debug!("Initiating connection to Gateway at {}:{}", request.ip_or_domain, request.grpc_port); let response_with_metadata = match tokio::time::timeout( @@ -841,8 +840,8 @@ pub async fn setup_gateway_tls_stream( if let Some(gateway_version) = gateway_version { if gateway_version < MIN_GATEWAY_VERSION { yield Ok(flow.error(&format!( - "Gateway version {gateway_version} is older than core version {version_clone}. \ - Please update the Edge component.", + "Gateway version {gateway_version} is older than Core version {version_clone}. \ + Please update the Gateway component.", ))); return; } @@ -905,7 +904,7 @@ pub async fn setup_gateway_tls_stream( }); // Create guard to ensure task is aborted on all exit paths - let _ = TaskGuard(log_reader_task); + let _log_task_guard = TaskGuard(log_reader_task); // Step 4: Obtain CSR yield Ok(flow.step(SetupStep::ObtainingCsr)); From 0049ba5ba0f50ca668846c39df0f23cac72364f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Thu, 26 Feb 2026 13:29:00 +0100 Subject: [PATCH 4/4] Re-format --- .../src/handlers/component_setup.rs | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index cc90dd2b5e..070331c853 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -689,7 +689,7 @@ pub async fn setup_gateway_tls_stream( } }; - debug!("Successfully validated Gateway address: {}", url_str); + debug!("Successfully validated Gateway address: {url_str}"); let endpoint = match Endpoint::from_shared(url.to_string()) { Ok(e) => e, @@ -790,12 +790,12 @@ pub async fn setup_gateway_tls_stream( tonic::Code::Unavailable => { let error_msg = e.to_string(); if error_msg.contains("h2 protocol error") || error_msg.contains("http2 error") { - yield Ok(flow.error(&format!( - "Failed to connect to Gateway at {}:{}: {e}. This may indicate that \ - the Gateway is already configured with TLS. Please check if the \ - Gateway has already been set up.", - request.ip_or_domain, request.grpc_port, - ))); + yield Ok(flow.error(&format!( + "Failed to connect to Gateway at {}:{}: {e}. This may indicate \ + that the Gateway is already configured with TLS. Please check if \ + the Gateway has already been set up.", + request.ip_or_domain, request.grpc_port, + ))); } else { yield Ok(flow.error(&format!( "Failed to connect to Gateway at {}:{}. Please ensure the address and \ @@ -922,16 +922,16 @@ pub async fn setup_gateway_tls_stream( { Ok(r) => r.into_inner(), Err(e) => { - yield Ok(flow.error(&format!("Failed to obtain CSR: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to obtain CSR: {e}"))); + return; } }; let csr = match defguard_certs::Csr::from_der(&csr_response.der_data) { Ok(c) => c, Err(e) => { - yield Ok(flow.error(&format!("Failed to parse CSR: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to parse CSR: {e}"))); + return; } }; @@ -958,8 +958,8 @@ pub async fn setup_gateway_tls_stream( ) { Ok(c) => c, Err(e) => { - yield Ok(flow.error(&format!("Failed to create CA: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to create CA: {e}"))); + return; } }; @@ -968,8 +968,8 @@ pub async fn setup_gateway_tls_stream( let cert = match ca.sign_csr(&csr) { Ok(c) => c, Err(e) => { - yield Ok(flow.error(&format!("Failed to sign CSR: {e}"))); - return; + yield Ok(flow.error(&format!("Failed to sign CSR: {e}"))); + return; } };