diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index facd3810a..aa5189646 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -292,9 +292,9 @@ async fn main() -> Result<(), anyhow::Error> { settings.stats_purge_threshold() ), if settings.enable_stats_purge => bail!("Periodic stats purge task returned early: {res:?}"), - res = run_periodic_license_check(&pool, proxy_control_tx) => + res = run_periodic_license_check(&pool, proxy_control_tx.clone()) => bail!("Periodic license check task returned early: {res:?}"), - res = run_utility_thread(&pool, gateway_tx.clone()) => + res = run_utility_thread(&pool, gateway_tx.clone(), proxy_control_tx) => bail!("Utility thread returned early: {res:?}"), res = run_event_router( RouterReceiverSet::new( diff --git a/crates/defguard_common/src/db/models/settings.rs b/crates/defguard_common/src/db/models/settings.rs index a5273bd81..5efb492c0 100644 --- a/crates/defguard_common/src/db/models/settings.rs +++ b/crates/defguard_common/src/db/models/settings.rs @@ -86,6 +86,12 @@ pub enum SettingsUrlError { DefguardUrlUsesIpAddress(String), #[error("Invalid WebAuthn configuration for defguard_url `{0}`: {1}")] InvalidWebauthnConfiguration(String, String), + #[error("Public Edge URL is not configured")] + PublicEdgeUrlEmpty, + #[error("Unparsable Edge url: {0}")] + UnparsableEdgeUrl(String), + #[error("Edge url missing hostname: {0}")] + EdgeUrlMissingHostname(String), } #[derive(Error, Debug)] @@ -770,6 +776,21 @@ impl Settings { Url::parse(&self.public_proxy_url) } + pub fn proxy_hostname(&self) -> Result { + if self.public_proxy_url.trim().is_empty() { + return Err(SettingsUrlError::PublicEdgeUrlEmpty); + } + let url = self + .proxy_public_url() + .map_err(|_err| SettingsUrlError::UnparsableEdgeUrl(self.public_proxy_url.clone()))?; + let hostname = url + .host_str() + .ok_or_else(|| SettingsUrlError::EdgeUrlMissingHostname(self.public_proxy_url.clone()))? 
+ .to_string(); + + Ok(hostname) + } + #[allow(deprecated)] fn apply_from_config(&mut self, config: &DefGuardConfig) { let minute = 60; diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index 2c1a2bf22..90924344a 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -10,7 +10,6 @@ use axum::{ extract::{Path, Query}, response::sse::{Event, KeepAlive, Sse}, }; -use chrono::NaiveDateTime; use defguard_certs::der_to_pem; use defguard_common::{ VERSION, @@ -32,17 +31,14 @@ use defguard_common::{ use defguard_proto::{ common::{CertificateInfo, DerPayload}, gateway::gateway_setup_client::GatewaySetupClient, - proxy::{ - AcmeChallenge, AcmeLogs, AcmeStep, acme_issue_event, proxy_client::ProxyClient, - proxy_setup_client::ProxySetupClient, - }, + proxy::{AcmeStep, proxy_setup_client::ProxySetupClient}, }; use defguard_version::{Version, client::ClientVersionInterceptor}; use futures::Stream; use reqwest::Url; use serde::{Deserialize, Serialize}; use sqlx::PgPool; -use tokio::sync::mpsc::{Sender, UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tokio::sync::mpsc::{Sender, UnboundedReceiver, unbounded_channel}; use tokio_stream::StreamExt; use tonic::{ Request, Status, @@ -54,6 +50,7 @@ use tracing::Instrument; use crate::{ auth::{AdminOrSetupRole, SessionInfo}, enterprise::is_enterprise_license_active, + letsencrypt::{ACME_TIMEOUT, acme_step_name, call_proxy_trigger_acme, parse_cert_expiry}, setup_logs::scope_setup_logs, version::{MIN_GATEWAY_VERSION, MIN_PROXY_VERSION}, }; @@ -1059,9 +1056,6 @@ pub async fn setup_gateway_tls_stream( Sse::new(stream).keep_alive(KeepAlive::default()) } -/// Maximum time (seconds) allowed for the ACME flow to complete end-to-end. 
-const ACME_TIMEOUT_SECS: u64 = 300; - #[derive(Debug, Serialize)] struct AcmeSetupResponse { step: &'static str, @@ -1094,147 +1088,6 @@ fn acme_error_event(step: &'static str, message: String, logs: Option &'static str { - match step { - AcmeStep::Unspecified | AcmeStep::Connecting => "Connecting", - AcmeStep::CheckingDomain => "CheckingDomain", - AcmeStep::ValidatingDomain => "ValidatingDomain", - AcmeStep::IssuingCertificate => "IssuingCertificate", - } -} - -fn parse_cert_expiry(cert_pem: &str) -> Option { - let der = defguard_certs::parse_pem_certificate(cert_pem) - .map_err(|e| warn!("Failed to parse ACME cert PEM for expiry: {e}")) - .ok()?; - defguard_certs::CertificateInfo::from_der(&der) - .map(|info| info.not_after) - .map_err(|e| warn!("Failed to extract expiry from ACME cert: {e}")) - .ok() -} - -fn public_proxy_hostname() -> Result { - let public_proxy_url = Settings::get_current_settings().public_proxy_url; - let url = public_proxy_url.trim(); - - if url.is_empty() { - return Err( - "Public Edge URL is not configured. Please re-submit the external URL settings \ - with a Let's Encrypt domain." - .to_string(), - ); - } - - Url::parse(url) - .ok() - .and_then(|u| u.host_str().map(ToString::to_string)) - .filter(|host| !host.is_empty()) - .ok_or_else(|| { - "Public Edge URL is not configured with a valid hostname. Please re-submit the \ - external URL settings with a valid domain." - .to_string() - }) -} - -/// Connects to the proxy's permanent `Proxy` gRPC service and calls `TriggerAcme`. -/// -/// Returns `(cert_pem, key_pem, account_credentials_json)` on success, or -/// `(error_message, log_lines)` on failure where `log_lines` are the proxy log entries -/// collected during the ACME run (sent by the proxy via an [`AcmeLogs`] event). 
-async fn call_proxy_trigger_acme( - pool: &PgPool, - proxy_host: &str, - proxy_port: u16, - domain: String, - account_credentials_json: String, - progress_tx: UnboundedSender, -) -> Result<(String, String, String), (String, Vec)> { - let certs = Certificates::get_or_default(pool) - .await - .map_err(|e| (format!("Failed to load certificates: {e}"), Vec::new()))?; - let ca_cert_der = certs.ca_cert_der.ok_or_else(|| { - ( - "CA certificate not found in settings".to_string(), - Vec::new(), - ) - })?; - - let cert_pem = der_to_pem(&ca_cert_der, defguard_certs::PemLabel::Certificate) - .map_err(|e| (format!("Failed to convert CA cert to PEM: {e}"), Vec::new()))?; - - let endpoint_str = format!("https://{proxy_host}:{proxy_port}"); - let endpoint = Endpoint::from_shared(endpoint_str) - .map_err(|e| (format!("Failed to build Edge endpoint: {e}"), Vec::new()))? - .http2_keep_alive_interval(Duration::from_secs(5)) - .tcp_keepalive(Some(Duration::from_secs(5))) - .keep_alive_while_idle(true); - - let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(cert_pem)); - let endpoint = endpoint.tls_config(tls).map_err(|e| { - ( - format!("Failed to configure TLS for Edge endpoint: {e}"), - Vec::new(), - ) - })?; - - let version = Version::parse(VERSION) - .map_err(|e| (format!("Failed to parse core version: {e}"), Vec::new()))?; - let version_interceptor = ClientVersionInterceptor::new(version); - - let mut client = - ProxyClient::with_interceptor(endpoint.connect_lazy(), move |req: Request<()>| { - version_interceptor.clone().call(req) - }); - - let mut stream = client - .trigger_acme(AcmeChallenge { - domain: domain.clone(), - account_credentials_json, - }) - .await - .map_err(|e| (format!("TriggerAcme RPC failed: {e}"), Vec::new()))? 
- .into_inner(); - - let mut collected_logs: Vec = Vec::new(); - - loop { - match stream.message().await { - Ok(Some(event)) => match event.payload { - Some(acme_issue_event::Payload::Progress(p)) => { - if let Ok(step) = AcmeStep::try_from(p.step) { - let _ = progress_tx.send(step); - } - } - Some(acme_issue_event::Payload::Certificate(cert)) => { - return Ok((cert.cert_pem, cert.key_pem, cert.account_credentials_json)); - } - Some(acme_issue_event::Payload::Logs(AcmeLogs { lines })) => { - collected_logs = lines; - } - None => { - return Err(( - "TriggerAcme stream sent an event with no payload".to_string(), - collected_logs, - )); - } - }, - Ok(None) => { - return Err(( - "TriggerAcme stream ended without delivering a certificate".to_string(), - collected_logs, - )); - } - Err(e) => { - return Err(( - format!("Failed to read TriggerAcme response: {e}"), - collected_logs, - )); - } - } - } -} - /// Streams Let's Encrypt certificate issuance progress as Server-Sent Events. /// /// Delegates the ACME HTTP-01 process to the proxy component via the `TriggerAcme` @@ -1259,10 +1112,11 @@ pub async fn stream_proxy_acme( } }; - let domain = match public_proxy_hostname() { + let settings = Settings::get_current_settings(); + let domain = match settings.proxy_hostname() { Ok(domain) => domain, - Err(message) => { - yield Ok(acme_error_event("Connecting", message, None)); + Err(err) => { + yield Ok(acme_error_event("Connecting", err.to_string(), None)); return; } }; @@ -1321,8 +1175,7 @@ pub async fn stream_proxy_acme( }); let mut current_step: &'static str = "Connecting"; - let deadline = tokio::time::Instant::now() - + tokio::time::Duration::from_secs(ACME_TIMEOUT_SECS); + let deadline = tokio::time::Instant::now() + ACME_TIMEOUT; // Drain progress steps until the ACME task finishes (channel closed) or times out. loop { @@ -1345,7 +1198,7 @@ pub async fn stream_proxy_acme( current_step, format!( "ACME certificate issuance timed out after \ - {ACME_TIMEOUT_SECS} seconds." 
+ {} seconds.", ACME_TIMEOUT.as_secs() ), None, )); diff --git a/crates/defguard_core/src/letsencrypt.rs b/crates/defguard_core/src/letsencrypt.rs new file mode 100644 index 000000000..601cc3ff5 --- /dev/null +++ b/crates/defguard_core/src/letsencrypt.rs @@ -0,0 +1,795 @@ +use std::time::Duration; + +use chrono::{NaiveDateTime, TimeDelta, Utc}; +use defguard_certs::der_to_pem; +use defguard_common::{ + VERSION, + db::models::{Certificates, ProxyCertSource, Settings, User, proxy::Proxy}, + types::proxy::ProxyControlMessage, +}; +use defguard_mail::templates; +use defguard_proto::proxy::{ + AcmeChallenge, AcmeLogs, AcmeStep, acme_issue_event, proxy_client::ProxyClient, +}; +use defguard_version::{Version, client::ClientVersionInterceptor}; +use sqlx::PgPool; +use thiserror::Error; +use tokio::sync::mpsc::{self, UnboundedSender, unbounded_channel}; +use tonic::{ + Request, + service::Interceptor, + transport::{Certificate, ClientTlsConfig, Endpoint}, +}; + +/// Maximum time (seconds) allowed for the ACME flow to complete end-to-end. 
+#[cfg(not(test))] +pub const ACME_TIMEOUT: Duration = Duration::from_secs(300); +#[cfg(test)] +pub const ACME_TIMEOUT: Duration = Duration::from_secs(1); +const LETSENCRYPT_EXPIRY_THRESHOLD: TimeDelta = TimeDelta::days(14); + +#[derive(Debug, Error)] +pub(crate) enum LetsencryptError { + #[error("Failed to load certificates: {0}")] + CertificatesLoadFailed(sqlx::Error), + #[error("Failed to resolve proxy hostname: {0}")] + ProxyHostnameFailed(String), + #[error("Failed to load Edge list from DB: {0}")] + ProxyListLoadFailed(sqlx::Error), + #[error("No Edge found in database")] + NoProxyFound, + #[error("ACME certificate issuance timed out after {} seconds", timeout.as_secs())] + AcmeTimedOut { timeout: Duration }, + #[error("Failed to reload certificates for saving: {0}")] + CertificateReloadFailed(sqlx::Error), + #[error("Failed to save certificate: {0}")] + CertificateSaveFailed(sqlx::Error), + #[error("ACME issuance failed: {0}")] + AcmeIssuanceFailed(String), +} + +/// Refreshes the proxy HTTPS certificate through the Edge ACME flow when the +/// currently stored Let's Encrypt certificate is close to expiry. +/// +/// Returns `Ok(())` when refresh is not needed or when renewal completes +/// successfully. Returns [`LetsencryptError`] only for operational failures in +/// the refresh flow itself. +pub(crate) async fn do_letsencrypt_refresh( + pool: &PgPool, + proxy_control_tx: mpsc::Sender, +) -> Result<(), LetsencryptError> { + debug!("Performing letsencrypt cert validity check"); + let Some(certs) = Certificates::get(pool) + .await + .map_err(LetsencryptError::CertificatesLoadFailed)? 
+ else { + warn!("Missing certificates configuration, aborting letsencrypt expiry check"); + return Ok(()); + }; + + if certs.proxy_http_cert_source != ProxyCertSource::LetsEncrypt { + info!( + "Edge certificate source is {:?}, skipping Letsencrypt expiry check", + certs.proxy_http_cert_source + ); + return Ok(()); + } + + let Some(expiry) = certs.proxy_http_cert_expiry else { + info!( + "Edge certificate has no expiry date, skipping Letsencrypt refresh certificate refresh" + ); + return Ok(()); + }; + + let expire_in = expiry - Utc::now().naive_utc(); + if expire_in > LETSENCRYPT_EXPIRY_THRESHOLD { + info!( + "Letsencrypt certificate expires in {} days, skipping refresh", + expire_in.num_days() + ); + return Ok(()); + } + + info!( + "Letsencrypt certificate expires in {} days, performing certificate refresh", + expire_in.num_days() + ); + let settings = Settings::get_current_settings(); + let domain = settings + .proxy_hostname() + .map_err(|err| LetsencryptError::ProxyHostnameFailed(err.to_string()))?; + let account_credentials_json = certs.acme_account_credentials.clone().unwrap_or_default(); + let proxies = Proxy::list(pool) + .await + .map_err(LetsencryptError::ProxyListLoadFailed)?; + let Some(proxy) = proxies.into_iter().next() else { + warn!("No Edge found in database, aborting Letsencrypt expiry check"); + return Err(LetsencryptError::NoProxyFound); + }; + + let proxy_host = proxy.address.clone(); + let proxy_port = proxy.port as u16; + info!( + "Triggering ACME HTTP-01 via Edge gRPC TriggerAcme for domain: {domain} \ + Edge={proxy_host}:{proxy_port}" + ); + + let (progress_tx, _progress_rx) = unbounded_channel::(); + + match tokio::time::timeout( + ACME_TIMEOUT, + call_proxy_trigger_acme( + pool, + &proxy_host, + proxy_port, + domain.clone(), + account_credentials_json, + progress_tx, + ), + ) + .await + { + Ok(Ok((cert_pem, key_pem, new_account_credentials_json))) => { + let acme_cert_expiry = parse_cert_expiry(&cert_pem); + match 
Certificates::get_or_default(pool).await { + Ok(mut updated_certs) => { + updated_certs.acme_domain = Some(domain.clone()); + updated_certs.proxy_http_cert_pem = Some(cert_pem.clone()); + updated_certs.proxy_http_cert_key_pem = Some(key_pem.clone()); + updated_certs.proxy_http_cert_expiry = acme_cert_expiry; + updated_certs.acme_account_credentials = Some(new_account_credentials_json); + updated_certs.proxy_http_cert_source = ProxyCertSource::LetsEncrypt; + if let Err(e) = updated_certs.save(pool).await { + error!("Failed to save certificate: {e}"); + return Err(LetsencryptError::CertificateSaveFailed(e)); + } + } + Err(e) => { + error!("Failed to reload certificates for saving: {e}"); + return Err(LetsencryptError::CertificateReloadFailed(e)); + } + } + + // Broadcast certs to the proxy via bidi channel + let msg = ProxyControlMessage::BroadcastHttpsCerts { cert_pem, key_pem }; + if let Err(e) = proxy_control_tx.send(msg).await { + error!("Failed to broadcast HttpsCerts to Edge: {e}"); + } + + info!("ACME certificate issued and saved for domain: {domain}"); + } + Ok(Err((acme_err, logs))) => { + error!("ACME issuance failed: {acme_err}"); + if let Err(err) = send_le_refresh_failed_emails(pool, &acme_err, &logs).await { + error!("Sending letsencrypt refresh email notification failed: {err}"); + } + return Err(LetsencryptError::AcmeIssuanceFailed(acme_err)); + } + Err(_) => { + error!( + "ACME certificate issuance timed out after \ + {}.", + ACME_TIMEOUT.as_secs(), + ); + return Err(LetsencryptError::AcmeTimedOut { + timeout: ACME_TIMEOUT, + }); + } + } + + Ok(()) +} + +/// Sends a failed Let's Encrypt refresh notification email to all active +/// administrators. +/// +/// The provided log lines are joined into a single text attachment and sent +/// with the notification email. 
+async fn send_le_refresh_failed_emails(
+    pool: &PgPool,
+    error_message: &str,
+    logs: &[String],
+) -> Result<(), anyhow::Error> {
+    // NOTE(review): this transaction is never committed, so anything the mail
+    // helper writes through it would be rolled back on drop — confirm the helper
+    // only reads from the connection.
+    let mut conn = pool.begin().await?;
+    let admin_users = User::find_admins(&mut *conn).await?;
+    // The attachment text is the same for every recipient; build it once.
+    let joined_logs = logs.join("\n");
+    for user in admin_users {
+        templates::letsencrypt_cert_refresh_failed_mail(
+            &user.email,
+            &mut conn,
+            error_message,
+            &joined_logs,
+        )
+        .await?;
+    }
+
+    Ok(())
+}
+
+/// Parses the expiry timestamp from a PEM-encoded certificate.
+///
+/// Returns the certificate `not_after` value, or `None` if the PEM cannot be
+/// parsed or the expiry cannot be extracted. Either failure is logged as a
+/// warning.
+pub(crate) fn parse_cert_expiry(cert_pem: &str) -> Option<NaiveDateTime> {
+    let der = match defguard_certs::parse_pem_certificate(cert_pem) {
+        Ok(der) => der,
+        Err(e) => {
+            warn!("Failed to parse ACME cert PEM for expiry: {e}");
+            return None;
+        }
+    };
+    match defguard_certs::CertificateInfo::from_der(&der) {
+        Ok(info) => Some(info.not_after),
+        Err(e) => {
+            warn!("Failed to extract expiry from ACME cert: {e}");
+            None
+        }
+    }
+}
+
+/// Maps a proto [`AcmeStep`] to the SSE step string expected by the frontend.
+///
+/// `Unspecified` is reported as `"Connecting"`, the same as `Connecting`.
+pub(crate) fn acme_step_name(step: AcmeStep) -> &'static str {
+    match step {
+        AcmeStep::Unspecified | AcmeStep::Connecting => "Connecting",
+        AcmeStep::CheckingDomain => "CheckingDomain",
+        AcmeStep::ValidatingDomain => "ValidatingDomain",
+        AcmeStep::IssuingCertificate => "IssuingCertificate",
+    }
+}
+
+/// Connects to the proxy's permanent `Proxy` gRPC service and calls `TriggerAcme`.
+///
+/// Returns `(cert_pem, key_pem, account_credentials_json)` on success, or
+/// `(error_message, log_lines)` on failure where `log_lines` are the proxy log entries
+/// collected during the ACME run (sent by the proxy via an [`AcmeLogs`] event).
+pub(crate) async fn call_proxy_trigger_acme( + pool: &PgPool, + proxy_host: &str, + proxy_port: u16, + domain: String, + account_credentials_json: String, + progress_tx: UnboundedSender, +) -> Result<(String, String, String), (String, Vec)> { + let certs = Certificates::get_or_default(pool) + .await + .map_err(|e| (format!("Failed to load certificates: {e}"), Vec::new()))?; + let ca_cert_der = certs.ca_cert_der.ok_or_else(|| { + ( + "CA certificate not found in settings".to_string(), + Vec::new(), + ) + })?; + + let cert_pem = der_to_pem(&ca_cert_der, defguard_certs::PemLabel::Certificate) + .map_err(|e| (format!("Failed to convert CA cert to PEM: {e}"), Vec::new()))?; + + let endpoint_str = format!("https://{proxy_host}:{proxy_port}"); + let endpoint = Endpoint::from_shared(endpoint_str) + .map_err(|e| (format!("Failed to build Edge endpoint: {e}"), Vec::new()))? + .http2_keep_alive_interval(Duration::from_secs(5)) + .tcp_keepalive(Some(Duration::from_secs(5))) + .keep_alive_while_idle(true); + + let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(cert_pem)); + let endpoint = endpoint.tls_config(tls).map_err(|e| { + ( + format!("Failed to configure TLS for Edge endpoint: {e}"), + Vec::new(), + ) + })?; + + let version = Version::parse(VERSION) + .map_err(|e| (format!("Failed to parse core version: {e}"), Vec::new()))?; + let version_interceptor = ClientVersionInterceptor::new(version); + + let mut client = + ProxyClient::with_interceptor(endpoint.connect_lazy(), move |req: Request<()>| { + version_interceptor.clone().call(req) + }); + + let mut stream = client + .trigger_acme(AcmeChallenge { + domain: domain.clone(), + account_credentials_json, + }) + .await + .map_err(|e| (format!("TriggerAcme RPC failed: {e}"), Vec::new()))? 
+ .into_inner(); + + let mut collected_logs: Vec = Vec::new(); + + loop { + match stream.message().await { + Ok(Some(event)) => match event.payload { + Some(acme_issue_event::Payload::Progress(p)) => { + if let Ok(step) = AcmeStep::try_from(p.step) { + let _ = progress_tx.send(step); + } + } + Some(acme_issue_event::Payload::Certificate(cert)) => { + return Ok((cert.cert_pem, cert.key_pem, cert.account_credentials_json)); + } + Some(acme_issue_event::Payload::Logs(AcmeLogs { lines })) => { + collected_logs = lines; + } + None => { + return Err(( + "TriggerAcme stream sent an event with no payload".to_string(), + collected_logs, + )); + } + }, + Ok(None) => { + return Err(( + "TriggerAcme stream ended without delivering a certificate".to_string(), + collected_logs, + )); + } + Err(e) => { + return Err(( + format!("Failed to read TriggerAcme response: {e}"), + collected_logs, + )); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + pin::Pin, + sync::Arc, + sync::Once, + time::Duration, + }; + + use defguard_certs::{CertificateAuthority, Csr, DnType, PemLabel, generate_key_pair}; + use defguard_common::{ + db::{ + models::{Certificates, ProxyCertSource, Settings, User, proxy::Proxy}, + setup_pool, + }, + secret::SecretStringWrapper, + types::proxy::ProxyControlMessage, + }; + use defguard_proto::proxy::{ + AcmeCertificate, AcmeIssueEvent, AcmeLogs, AcmeProgress, AcmeStep, proxy_server, + }; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + use std::str::FromStr; + use tokio::{ + net::TcpListener, + sync::{Mutex, mpsc}, + task::JoinHandle, + time::{sleep, timeout}, + }; + use tokio_stream::{self as stream}; + use tonic::{ + Request, Response, Status, Streaming, + transport::{Identity, Server, ServerTlsConfig}, + }; + + use super::{ACME_TIMEOUT, LetsencryptError, do_letsencrypt_refresh}; + + const TEST_ACCOUNT_JSON: &str = r#"{"account_url":"https://acme.example/account/1"}"#; + + enum MockAcmeBehavior { + Success 
{ + cert_pem: String, + key_pem: String, + account_credentials_json: String, + logs: Vec, + }, + RpcError(Status), + Hang, + } + + struct MockProxyService { + behavior: Arc>, + } + + #[tonic::async_trait] + impl proxy_server::Proxy for MockProxyService { + type BidiStream = Pin< + Box< + dyn tokio_stream::Stream> + + Send, + >, + >; + type TriggerAcmeStream = + Pin> + Send>>; + + async fn bidi( + &self, + _request: Request>, + ) -> Result, Status> { + Ok(Response::new(Box::pin(stream::empty()))) + } + + async fn purge(&self, _request: Request<()>) -> Result, Status> { + Ok(Response::new(())) + } + + async fn trigger_acme( + &self, + _request: Request, + ) -> Result, Status> { + let behavior = self.behavior.lock().await; + match &*behavior { + MockAcmeBehavior::Success { + cert_pem, + key_pem, + account_credentials_json, + logs, + } => { + let mut events = vec![Ok(AcmeIssueEvent { + payload: Some(defguard_proto::proxy::acme_issue_event::Payload::Progress( + AcmeProgress { + step: AcmeStep::CheckingDomain as i32, + }, + )), + })]; + if !logs.is_empty() { + events.push(Ok(AcmeIssueEvent { + payload: Some(defguard_proto::proxy::acme_issue_event::Payload::Logs( + AcmeLogs { + lines: logs.clone(), + }, + )), + })); + } + events.push(Ok(AcmeIssueEvent { + payload: Some( + defguard_proto::proxy::acme_issue_event::Payload::Certificate( + AcmeCertificate { + cert_pem: cert_pem.clone(), + key_pem: key_pem.clone(), + account_credentials_json: account_credentials_json.clone(), + }, + ), + ), + })); + Ok(Response::new(Box::pin(stream::iter(events)))) + } + MockAcmeBehavior::RpcError(status) => Err(status.clone()), + MockAcmeBehavior::Hang => Ok(Response::new(Box::pin(stream::pending::< + Result, + >()))), + } + } + } + + struct MockAcmeServer { + port: u16, + task: JoinHandle<()>, + } + + impl MockAcmeServer { + async fn start( + ca: &CertificateAuthority<'_>, + common_name: &str, + behavior: MockAcmeBehavior, + ) -> Self { + init_rustls_crypto_provider(); + let identity = 
make_server_identity(ca, common_name); + let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)) + .await + .expect("failed to bind mock ACME server"); + let port = listener.local_addr().expect("missing local addr").port(); + let service = MockProxyService { + behavior: Arc::new(Mutex::new(behavior)), + }; + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + let task = tokio::spawn(async move { + Server::builder() + .tls_config(ServerTlsConfig::new().identity(identity)) + .expect("failed to configure TLS for mock ACME server") + .add_service(proxy_server::ProxyServer::new(service)) + .serve_with_incoming(incoming) + .await + .expect("mock ACME server failed"); + }); + + tokio::task::yield_now().await; + + Self { port, task } + } + } + + impl Drop for MockAcmeServer { + fn drop(&mut self) { + self.task.abort(); + } + } + + fn make_server_identity(ca: &CertificateAuthority<'_>, common_name: &str) -> Identity { + let key_pair = generate_key_pair().expect("failed to generate key pair"); + let san = vec![common_name.to_string()]; + let dn = vec![(DnType::CommonName, common_name)]; + let csr = Csr::new(&key_pair, &san, dn).expect("failed to create CSR"); + let cert = ca.sign_csr(&csr).expect("failed to sign server cert"); + let cert_pem = + defguard_certs::der_to_pem(cert.der(), PemLabel::Certificate).expect("cert PEM"); + let key_pem = + defguard_certs::der_to_pem(key_pair.serialize_der().as_slice(), PemLabel::PrivateKey) + .expect("key PEM"); + Identity::from_pem(cert_pem, key_pem) + } + + fn init_rustls_crypto_provider() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + rustls::crypto::ring::default_provider() + .install_default() + .ok(); + }); + } + + async fn seed_settings(pool: &sqlx::PgPool, hostname: &str) { + defguard_common::db::models::settings::initialize_current_settings(pool) + .await + .expect("failed to initialize settings"); + let mut settings = Settings::get_current_settings(); + 
settings.public_proxy_url = format!("https://{hostname}"); + settings.smtp_server = Some("smtp.example.com".into()); + settings.smtp_port = Some(587); + settings.smtp_sender = Some("noreply@example.com".into()); + settings.smtp_user = Some(String::new()); + settings.smtp_password = Some(SecretStringWrapper::from_str("").unwrap()); + defguard_common::db::models::settings::set_settings(Some(settings)); + } + + async fn seed_admin(pool: &sqlx::PgPool) { + let _ = User::new("admin", None, "Admin", "User", "admin@example.com", None) + .save(pool) + .await + .expect("failed to save admin user"); + } + + fn make_ca() -> CertificateAuthority<'static> { + CertificateAuthority::new("Test CA", "test@example.com", 365).expect("failed to create CA") + } + + async fn seed_ca(pool: &sqlx::PgPool, ca: &CertificateAuthority<'_>) { + Certificates { + ca_cert_der: Some(ca.cert_der().to_vec()), + ca_key_der: Some(ca.key_pair_der().to_vec()), + ca_expiry: Some(ca.expiry().expect("missing CA expiry")), + ..Default::default() + } + .save(pool) + .await + .expect("failed to save CA certs"); + } + + async fn seed_letsencrypt_cert( + pool: &sqlx::PgPool, + ca: &CertificateAuthority<'_>, + common_name: &str, + valid_for_days: i64, + ) { + let key_pair = generate_key_pair().expect("failed to generate key pair"); + let san = vec![common_name.to_string()]; + let dn = vec![(DnType::CommonName, common_name)]; + let csr = Csr::new(&key_pair, &san, dn).expect("failed to create CSR"); + let cert = ca + .sign_csr_with_validity(&csr, valid_for_days) + .expect("failed to sign cert"); + let cert_pem = + defguard_certs::der_to_pem(cert.der(), PemLabel::Certificate).expect("cert PEM"); + let key_pem = + defguard_certs::der_to_pem(key_pair.serialize_der().as_slice(), PemLabel::PrivateKey) + .expect("key PEM"); + let expiry = super::parse_cert_expiry(&cert_pem).expect("expected cert expiry"); + + let mut certs = Certificates::get_or_default(pool) + .await + .expect("failed to load certificates"); + 
certs.proxy_http_cert_source = ProxyCertSource::LetsEncrypt; + certs.proxy_http_cert_pem = Some(cert_pem); + certs.proxy_http_cert_key_pem = Some(key_pem); + certs.proxy_http_cert_expiry = Some(expiry); + certs.acme_account_credentials = Some(TEST_ACCOUNT_JSON.to_string()); + certs.save(pool).await.expect("failed to save LE certs"); + } + + async fn create_proxy(pool: &sqlx::PgPool, address: &str, port: u16) { + let mut proxy = Proxy::new("test-proxy", address, i32::from(port), "tester"); + proxy.enabled = true; + proxy.save(pool).await.expect("failed to save proxy"); + } + + async fn drain_broadcasts( + rx: &mut mpsc::Receiver, + ) -> Vec<(String, String)> { + sleep(Duration::from_millis(50)).await; + let mut broadcasts = Vec::new(); + while let Ok(message) = rx.try_recv() { + if let ProxyControlMessage::BroadcastHttpsCerts { cert_pem, key_pem } = message { + broadcasts.push((cert_pem, key_pem)); + } + } + broadcasts + } + + #[sqlx::test] + async fn letsencrypt_refresh_skips_when_certificate_not_due( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let ca = make_ca(); + seed_settings(&pool, "refresh.example.com").await; + seed_ca(&pool, &ca).await; + seed_letsencrypt_cert(&pool, &ca, "refresh.example.com", 89).await; + + let certs_before = Certificates::get_or_default(&pool) + .await + .expect("failed to load certificates"); + + let (proxy_control_tx, mut proxy_control_rx) = mpsc::channel(8); + let result = do_letsencrypt_refresh(&pool, proxy_control_tx).await; + + assert!(result.is_ok(), "expected skip to succeed, got {result:?}"); + + let certs_after = Certificates::get_or_default(&pool) + .await + .expect("failed to reload certificates"); + assert_eq!( + certs_after.proxy_http_cert_pem, + certs_before.proxy_http_cert_pem + ); + assert_eq!( + certs_after.proxy_http_cert_key_pem, + certs_before.proxy_http_cert_key_pem + ); + assert!(drain_broadcasts(&mut proxy_control_rx).await.is_empty()); + } + + #[sqlx::test] + 
async fn letsencrypt_refresh_returns_no_proxy_found_when_due( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let ca = make_ca(); + seed_settings(&pool, "refresh.example.com").await; + seed_ca(&pool, &ca).await; + seed_letsencrypt_cert(&pool, &ca, "refresh.example.com", 1).await; + + let (proxy_control_tx, _proxy_control_rx) = mpsc::channel(8); + let result = do_letsencrypt_refresh(&pool, proxy_control_tx).await; + + assert!(matches!(result, Err(LetsencryptError::NoProxyFound))); + } + + #[sqlx::test] + async fn letsencrypt_refresh_success_persists_certificate_and_broadcasts( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let ca = make_ca(); + seed_settings(&pool, "localhost").await; + seed_ca(&pool, &ca).await; + seed_letsencrypt_cert(&pool, &ca, "localhost", 1).await; + + let (new_cert_pem, new_key_pem) = { + let key_pair = generate_key_pair().expect("failed to generate key pair"); + let san = vec!["localhost".to_string()]; + let dn = vec![(DnType::CommonName, "localhost")]; + let csr = Csr::new(&key_pair, &san, dn).expect("failed to create CSR"); + let cert = ca.sign_csr(&csr).expect("failed to sign cert"); + ( + defguard_certs::der_to_pem(cert.der(), PemLabel::Certificate).expect("cert PEM"), + defguard_certs::der_to_pem( + key_pair.serialize_der().as_slice(), + PemLabel::PrivateKey, + ) + .expect("key PEM"), + ) + }; + + let mock_server = MockAcmeServer::start( + &ca, + "localhost", + MockAcmeBehavior::Success { + cert_pem: new_cert_pem.clone(), + key_pem: new_key_pem.clone(), + account_credentials_json: r#"{"account_url":"https://acme.example/account/2"}"# + .to_string(), + logs: vec!["proxy log line".to_string()], + }, + ) + .await; + create_proxy(&pool, "localhost", mock_server.port).await; + + let (proxy_control_tx, mut proxy_control_rx) = mpsc::channel(8); + let result = do_letsencrypt_refresh(&pool, proxy_control_tx).await; + + assert!( + result.is_ok(), + 
"expected successful refresh, got {result:?}" + ); + + let certs = Certificates::get_or_default(&pool) + .await + .expect("failed to reload certificates"); + assert_eq!( + certs.proxy_http_cert_pem.as_deref(), + Some(new_cert_pem.as_str()) + ); + assert_eq!( + certs.proxy_http_cert_key_pem.as_deref(), + Some(new_key_pem.as_str()) + ); + assert_eq!( + certs.acme_account_credentials.as_deref(), + Some(r#"{"account_url":"https://acme.example/account/2"}"#) + ); + assert_eq!(certs.acme_domain.as_deref(), Some("localhost")); + assert_eq!(certs.proxy_http_cert_source, ProxyCertSource::LetsEncrypt); + assert!(certs.proxy_http_cert_expiry.is_some()); + + let broadcasts = drain_broadcasts(&mut proxy_control_rx).await; + assert_eq!(broadcasts.len(), 1); + assert_eq!(broadcasts[0].0, new_cert_pem); + assert_eq!(broadcasts[0].1, new_key_pem); + } + + #[sqlx::test] + async fn letsencrypt_refresh_returns_acme_issuance_failed_on_rpc_error( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let ca = make_ca(); + seed_settings(&pool, "localhost").await; + seed_ca(&pool, &ca).await; + seed_admin(&pool).await; + seed_letsencrypt_cert(&pool, &ca, "localhost", 1).await; + + let mock_server = MockAcmeServer::start( + &ca, + "localhost", + MockAcmeBehavior::RpcError(Status::unavailable("rpc unavailable")), + ) + .await; + create_proxy(&pool, "localhost", mock_server.port).await; + + let (proxy_control_tx, _proxy_control_rx) = mpsc::channel(8); + let result = do_letsencrypt_refresh(&pool, proxy_control_tx).await; + + assert!(matches!( + result, + Err(LetsencryptError::AcmeIssuanceFailed(message)) if message.contains("TriggerAcme RPC failed") + )); + } + + #[sqlx::test] + async fn letsencrypt_refresh_returns_acme_timed_out_when_stream_hangs( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + let ca = make_ca(); + seed_settings(&pool, "localhost").await; + seed_ca(&pool, &ca).await; + 
seed_admin(&pool).await; + seed_letsencrypt_cert(&pool, &ca, "localhost", 1).await; + + let mock_server = MockAcmeServer::start(&ca, "localhost", MockAcmeBehavior::Hang).await; + create_proxy(&pool, "localhost", mock_server.port).await; + + let (proxy_control_tx, _proxy_control_rx) = mpsc::channel(8); + let result = timeout( + ACME_TIMEOUT + Duration::from_secs(5), + do_letsencrypt_refresh(&pool, proxy_control_tx), + ) + .await + .expect("refresh should finish before outer timeout"); + + assert!(matches!( + result, + Err(LetsencryptError::AcmeTimedOut { timeout }) if timeout == ACME_TIMEOUT + )); + + drop(mock_server); + sleep(Duration::from_millis(50)).await; + } +} diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 5089f4ce2..204dc7a96 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -186,6 +186,7 @@ pub mod events; pub mod grpc; pub mod handlers; pub mod headers; +pub mod letsencrypt; pub mod location_management; pub mod setup_logs; pub mod support; diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index 7f5c8cfd1..83ad4c28a 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -1,14 +1,17 @@ use std::{collections::HashSet, time::Duration}; use chrono::{NaiveDateTime, TimeDelta, Utc}; -use defguard_common::db::models::{ - Certificates, CoreCertSource, ProxyCertSource, User, WireguardNetwork, - wireguard::ServiceLocationMode, +use defguard_common::{ + db::models::{ + Certificates, CoreCertSource, ProxyCertSource, User, WireguardNetwork, + wireguard::ServiceLocationMode, + }, + types::proxy::ProxyControlMessage, }; use defguard_mail::templates; use sqlx::{PgConnection, PgPool, query_as}; use tokio::{ - sync::broadcast::Sender, + sync::{broadcast::Sender, mpsc}, time::{Instant, sleep}, }; use tracing::Instrument; @@ -23,6 +26,7 @@ use crate::{ limits::update_counts, }, grpc::GatewayEvent, + 
letsencrypt::do_letsencrypt_refresh, location_management::allowed_peers::get_location_allowed_peers, updates::do_new_version_check, }; @@ -33,6 +37,7 @@ const COUNT_UPDATE_INTERVAL: u64 = 60 * 60; const UPDATES_CHECK_INTERVAL: u64 = 60 * 60 * 6; const EXPIRED_ACL_RULES_CHECK_INTERVAL: u64 = 60 * 5; const ENTERPRISE_STATUS_CHECK_INTERVAL: u64 = 60 * 5; +const LETSENCRYPT_EXPIRY_CHECK_INTERVAL: u64 = 60 * 60 * 24; const CERTIFICATE_EXPIRY_CHECK_INTERVAL: u64 = 60 * 60 * 24; // 1 day const ACL_EXPIRY_SYSTEM_ACTOR: &str = "system:acl-expiry"; @@ -40,6 +45,7 @@ const ACL_EXPIRY_SYSTEM_ACTOR: &str = "system:acl-expiry"; pub async fn run_utility_thread( pool: &PgPool, wireguard_tx: Sender<GatewayEvent>, + proxy_control_tx: mpsc::Sender<ProxyControlMessage>, ) -> Result<(), anyhow::Error> { let mut last_count_update = Instant::now(); let mut last_directory_sync = Instant::now(); @@ -47,6 +53,7 @@ pub async fn run_utility_thread( let mut last_ldap_sync = Instant::now(); let mut last_expired_acl_rules_check = Instant::now(); let mut last_enterprise_status_check = Instant::now(); + let mut last_letsencrypt_expiry_check = Instant::now(); let mut last_certificate_check = Instant::now(); // helper variable which stores previous enterprise features status @@ -98,11 +105,21 @@ pub async fn run_utility_thread( } }; + let letsencrypt_refresh_task = || async { + if let Err(e) = do_letsencrypt_refresh(pool, proxy_control_tx.clone()) + .instrument(info_span!("letsencrypt_refresh_task")) + .await + { + error!("There was an error while performing letsencrypt refresh task: {e}"); + } + }; + directory_sync_task().await; count_update_task().await; updates_check_task().await; ldap_sync_task().await; expired_acl_rules_task().await; + letsencrypt_refresh_task().await; check_certificates(pool).await; loop { @@ -138,27 +155,32 @@ pub async fn run_utility_thread( last_expired_acl_rules_check = Instant::now(); } + // Check LE cert expiry dates and refresh if necessary + if last_letsencrypt_expiry_check.elapsed().as_secs() >=
LETSENCRYPT_EXPIRY_CHECK_INTERVAL { + letsencrypt_refresh_task().await; + last_letsencrypt_expiry_check = Instant::now(); + } + // Check if enterprise features got enabled or disabled if last_enterprise_status_check.elapsed().as_secs() >= ENTERPRISE_STATUS_CHECK_INTERVAL { let new_enterprise_enabled = is_business_license_active(); - if new_enterprise_enabled == enterprise_enabled { - continue; - } - debug!( - "Enterprise feature status changed from {enterprise_enabled} to \ - {new_enterprise_enabled}" - ); - if let Err(err) = - enterprise_status_check(pool, wireguard_tx.clone(), new_enterprise_enabled) - .instrument(info_span!("enterprise_status_check")) - .await - { - error!("Failed to check enterprise status: {err}"); - } else { - // update status - enterprise_enabled = new_enterprise_enabled; - } last_enterprise_status_check = Instant::now(); + if new_enterprise_enabled != enterprise_enabled { + debug!( + "Enterprise feature status changed from {enterprise_enabled} to \ + {new_enterprise_enabled}" + ); + if let Err(err) = + enterprise_status_check(pool, wireguard_tx.clone(), new_enterprise_enabled) + .instrument(info_span!("enterprise_status_check")) + .await + { + error!("Failed to check enterprise status: {err}"); + } else { + // update status + enterprise_enabled = new_enterprise_enabled; + } + } } // Check certificates. diff --git a/crates/defguard_mail/src/mail.rs b/crates/defguard_mail/src/mail.rs index 407d9636e..1d4f2d1fe 100644 --- a/crates/defguard_mail/src/mail.rs +++ b/crates/defguard_mail/src/mail.rs @@ -300,6 +300,8 @@ pub enum MailMessage { UserImportBlocked, /// Enrollment notification for admins. EnrollmentNotification, + /// Letsencrypt certificate refresh failed. 
+ LetsencryptCertRefreshFailed, CertificateExpiration, CertificateExpired, } @@ -334,6 +336,9 @@ impl MailMessage { Self::PasswordResetDone => "Defguard: Password reset success".to_string(), Self::UserImportBlocked => "User import blocked".to_string(), Self::EnrollmentNotification => "Defguard: User enrollment completed".to_string(), + Self::LetsencryptCertRefreshFailed => { + "Defguard: automatic Let's Encrypt certificate refresh failed".to_string() + } Self::CertificateExpiration => "Defguard: Certificate expiration".to_string(), Self::CertificateExpired => "Defguard: Certificate has expired".to_string(), } @@ -358,6 +363,7 @@ impl MailMessage { Self::PasswordResetDone => "password-reset-done", Self::UserImportBlocked => "user-import-blocked", Self::EnrollmentNotification => "enrollment-admin-notification", + Self::LetsencryptCertRefreshFailed => "letsencrypt-cert-refresh-failed", Self::CertificateExpiration => "certificate-expiration", Self::CertificateExpired => "certificate-expired", } @@ -384,6 +390,9 @@ impl MailMessage { Self::EnrollmentNotification => { include_str!("../templates/enrollment-admin-notification.mjml") } + Self::LetsencryptCertRefreshFailed => { + include_str!("../templates/letsencrypt-cert-refresh-failed.mjml") + } Self::CertificateExpiration | Self::CertificateExpired => { include_str!("../templates/certificate-expiration.mjml") } @@ -411,6 +420,9 @@ impl MailMessage { Self::EnrollmentNotification => { include_str!("../templates/enrollment-admin-notification.text") } + Self::LetsencryptCertRefreshFailed => { + include_str!("../templates/letsencrypt-cert-refresh-failed.text") + } Self::CertificateExpiration | Self::CertificateExpired => { include_str!("../templates/certificate-expiration.text") } diff --git a/crates/defguard_mail/src/templates.rs b/crates/defguard_mail/src/templates.rs index 5ce18eef7..b383364ae 100644 --- a/crates/defguard_mail/src/templates.rs +++ b/crates/defguard_mail/src/templates.rs @@ -363,6 +363,31 @@ pub async fn 
gateway_disconnected_mail( Ok(()) } +/// Notification about failed Letsencrypt cert refresh process. +pub async fn letsencrypt_cert_refresh_failed_mail( + to: &str, + conn: &mut PgConnection, + error_message: &str, + logs: &str, +) -> Result<(), TemplateError> { + let (mut tera, mut context) = get_base_tera_mjml(Context::new(), None, None, None)?; + context.insert("error_message", error_message); + + let now = Utc::now(); + let attachment = Attachment::new( + format!("defguard-letsencrypt-refresh-logs-{now}.txt"), + logs.into(), + ); + let message = MailMessage::LetsencryptCertRefreshFailed; + message.fill_context(conn, &mut context).await?; + message + .mail(&mut tera, &context, to)? + .set_attachments(vec![attachment]) + .send_and_forget(); + + Ok(()) +} + /// Notification about reconnected Gateway. pub async fn gateway_reconnected_mail( to: &str, diff --git a/crates/defguard_mail/templates/letsencrypt-cert-refresh-failed.mjml b/crates/defguard_mail/templates/letsencrypt-cert-refresh-failed.mjml new file mode 100644 index 000000000..d009ce228 --- /dev/null +++ b/crates/defguard_mail/templates/letsencrypt-cert-refresh-failed.mjml @@ -0,0 +1,23 @@ +{% import "macros.mjml" as macros %} +{% extends "base.mjml" %} +{% block content %} + +{{ macros::email_header() }} + + + +

+ {{ content }} +

+ {% if error_message %} +

+ Error: {{ error_message }} +

+ {% endif %} +
+
+
+ +{{ macros::footer_divider() }} + +{% endblock content %} diff --git a/crates/defguard_mail/templates/letsencrypt-cert-refresh-failed.text b/crates/defguard_mail/templates/letsencrypt-cert-refresh-failed.text new file mode 100644 index 000000000..a027aa331 --- /dev/null +++ b/crates/defguard_mail/templates/letsencrypt-cert-refresh-failed.text @@ -0,0 +1,5 @@ +{{ title }} +{% if content %}{{ content }}{% endif %} +{% if error_message %} +Error: {{ error_message }} +{% endif %} diff --git a/migrations/20260417073540_[2.0.0]_letsencrypt_cert_refresh.down.sql b/migrations/20260417073540_[2.0.0]_letsencrypt_cert_refresh.down.sql new file mode 100644 index 000000000..f409fb742 --- /dev/null +++ b/migrations/20260417073540_[2.0.0]_letsencrypt_cert_refresh.down.sql @@ -0,0 +1 @@ +DELETE FROM mail_context WHERE "template" = 'letsencrypt-cert-refresh-failed'; diff --git a/migrations/20260417073540_[2.0.0]_letsencrypt_cert_refresh.up.sql b/migrations/20260417073540_[2.0.0]_letsencrypt_cert_refresh.up.sql new file mode 100644 index 000000000..7594b1072 --- /dev/null +++ b/migrations/20260417073540_[2.0.0]_letsencrypt_cert_refresh.up.sql @@ -0,0 +1,3 @@ +INSERT INTO mail_context (template, section, language_tag, text) VALUES + ('letsencrypt-cert-refresh-failed', 'title', 'en_US', 'Let''s Encrypt certificate refresh failed'), + ('letsencrypt-cert-refresh-failed', 'content', 'en_US', 'Automatic Let''s Encrypt certificate refresh has failed. Please verify your Edge setup.');