diff --git a/.sqlx/query-3f8076c26bfcad6ee3a9fbfd29c0f8da0acececc3d0442ee0762f489c95b9b71.json b/.sqlx/query-10a410ef8113994c19f1fd3c2c77fcb43d3be0bc795a3839207a11c9f24bf670.json similarity index 63% rename from .sqlx/query-3f8076c26bfcad6ee3a9fbfd29c0f8da0acececc3d0442ee0762f489c95b9b71.json rename to .sqlx/query-10a410ef8113994c19f1fd3c2c77fcb43d3be0bc795a3839207a11c9f24bf670.json index f4fd9b8235..996b13d997 100644 --- a/.sqlx/query-3f8076c26bfcad6ee3a9fbfd29c0f8da0acececc3d0442ee0762f489c95b9b71.json +++ b/.sqlx/query-10a410ef8113994c19f1fd3c2c77fcb43d3be0bc795a3839207a11c9f24bf670.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO \"proxy\" (\"name\",\"address\",\"port\",\"connected_at\",\"disconnected_at\",\"version\",\"has_certificate\",\"certificate_expiry\",\"modified_at\",\"modified_by\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) RETURNING id", + "query": "INSERT INTO \"proxy\" (\"name\",\"address\",\"port\",\"connected_at\",\"disconnected_at\",\"version\",\"certificate\",\"certificate_expiry\",\"modified_at\",\"modified_by\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) RETURNING id", "describe": { "columns": [ { @@ -17,7 +17,7 @@ "Timestamp", "Timestamp", "Text", - "Bool", + "Text", "Timestamp", "Timestamp", "Int8" @@ -27,5 +27,5 @@ false ] }, - "hash": "3f8076c26bfcad6ee3a9fbfd29c0f8da0acececc3d0442ee0762f489c95b9b71" + "hash": "10a410ef8113994c19f1fd3c2c77fcb43d3be0bc795a3839207a11c9f24bf670" } diff --git a/.sqlx/query-1e48e6b87c058b8dbc54b86c704ee3ecbfdfbd69f3d44e16d9a6e9ef25069614.json b/.sqlx/query-1e48e6b87c058b8dbc54b86c704ee3ecbfdfbd69f3d44e16d9a6e9ef25069614.json index cfc0b64f12..48914cbf1d 100644 --- a/.sqlx/query-1e48e6b87c058b8dbc54b86c704ee3ecbfdfbd69f3d44e16d9a6e9ef25069614.json +++ b/.sqlx/query-1e48e6b87c058b8dbc54b86c704ee3ecbfdfbd69f3d44e16d9a6e9ef25069614.json @@ -35,29 +35,29 @@ }, { "ordinal": 6, - "name": "has_certificate", - "type_info": "Bool" - }, - { - "ordinal": 7, "name": "certificate_expiry", "type_info": "Timestamp" }, { - "ordinal": 8, + "ordinal": 7, "name": "version", "type_info": "Text" }, { - "ordinal": 9, + "ordinal": 8, "name": "modified_at", "type_info": "Timestamp" }, { - "ordinal": 10, + "ordinal": 9, "name": "modified_by", "type_info": "Int8" }, + { + "ordinal": 10, + "name": "certificate", + "type_info": "Text" + }, { "ordinal": 11, "name": "modified_by_firstname", @@ -79,11 +79,11 @@ false, true, true, - false, true, true, false, false, + true, false, false ] diff --git a/.sqlx/query-3f3bd3e155ad5d4a2dae25ea4dda093f59342d6a6a7f7ac52650dade4a0a4f3e.json b/.sqlx/query-2ca67025b051148efdb9e00e4bb48b883d72bc6c8f481ae2734b8b6fd25977ac.json similarity index 82% rename from .sqlx/query-3f3bd3e155ad5d4a2dae25ea4dda093f59342d6a6a7f7ac52650dade4a0a4f3e.json rename to .sqlx/query-2ca67025b051148efdb9e00e4bb48b883d72bc6c8f481ae2734b8b6fd25977ac.json index 25df208cda..057c695d95 100644 --- a/.sqlx/query-3f3bd3e155ad5d4a2dae25ea4dda093f59342d6a6a7f7ac52650dade4a0a4f3e.json +++ b/.sqlx/query-2ca67025b051148efdb9e00e4bb48b883d72bc6c8f481ae2734b8b6fd25977ac.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"name\",\"address\",\"port\",\"connected_at\",\"disconnected_at\",\"version\",\"has_certificate\",\"certificate_expiry\",\"modified_at\",\"modified_by\" FROM \"proxy\"", + "query": "SELECT id, \"name\",\"address\",\"port\",\"connected_at\",\"disconnected_at\",\"version\",\"certificate\",\"certificate_expiry\",\"modified_at\",\"modified_by\" FROM \"proxy\"", "describe": { "columns": [ { @@ -40,8 +40,8 @@ }, { "ordinal": 7, - "name": "has_certificate", - "type_info": "Bool" + "name": "certificate", + "type_info": "Text" }, { "ordinal": 8, @@ -70,11 +70,11 @@ true, true, true, - false, + true, true, false, false ] }, - "hash": "3f3bd3e155ad5d4a2dae25ea4dda093f59342d6a6a7f7ac52650dade4a0a4f3e" + "hash": "2ca67025b051148efdb9e00e4bb48b883d72bc6c8f481ae2734b8b6fd25977ac" } diff --git a/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json b/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json index 074f3b32ac..4c6244cf54 100644 --- a/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json +++ b/.sqlx/query-a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877.json @@ -35,28 +35,28 @@ }, { "ordinal": 6, - "name": "has_certificate", - "type_info": "Bool" - }, - { - "ordinal": 7, "name": "certificate_expiry", "type_info": "Timestamp" }, { - "ordinal": 8, + "ordinal": 7, "name": "version", "type_info": "Text" }, { - "ordinal": 9, + "ordinal": 8, "name": "modified_at", "type_info": "Timestamp" }, { - "ordinal": 10, + "ordinal": 9, "name": "modified_by", "type_info": "Int8" + }, + { + "ordinal": 10, + "name": "certificate", + "type_info": "Text" } ], "parameters": { @@ -72,11 +72,11 @@ false, true, true, - false, true, true, false, - false + false, + true ] }, "hash": "a41787c8c8307414165ab23ef96d82a34d3bfa4364cbe9b8368e71445bc20877" diff --git a/.sqlx/query-74c90a33b88a0e69f222147d29e83af76508654b4d965e9f0f3862fe36a9aa6c.json b/.sqlx/query-d92e63295aa9d2302d5e4a8d205b35ab98de051c9aa9e5932b2c03a6118c2587.json similarity index 67% rename from .sqlx/query-74c90a33b88a0e69f222147d29e83af76508654b4d965e9f0f3862fe36a9aa6c.json rename to .sqlx/query-d92e63295aa9d2302d5e4a8d205b35ab98de051c9aa9e5932b2c03a6118c2587.json index 49a1da45f4..b459ef3b3f 100644 --- a/.sqlx/query-74c90a33b88a0e69f222147d29e83af76508654b4d965e9f0f3862fe36a9aa6c.json +++ b/.sqlx/query-d92e63295aa9d2302d5e4a8d205b35ab98de051c9aa9e5932b2c03a6118c2587.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE \"proxy\" SET \"name\" = $2,\"address\" = $3,\"port\" = $4,\"connected_at\" = $5,\"disconnected_at\" = $6,\"version\" = $7,\"has_certificate\" = $8,\"certificate_expiry\" = $9,\"modified_at\" = $10,\"modified_by\" = $11 WHERE id = $1", + "query": "UPDATE \"proxy\" SET \"name\" = $2,\"address\" = $3,\"port\" = $4,\"connected_at\" = $5,\"disconnected_at\" = $6,\"version\" = $7,\"certificate\" = $8,\"certificate_expiry\" = $9,\"modified_at\" = $10,\"modified_by\" = $11 WHERE id = $1", "describe": { "columns": [], "parameters": { @@ -12,7 +12,7 @@ "Timestamp", "Timestamp", "Text", - "Bool", + "Text", "Timestamp", "Timestamp", "Int8" @@ -20,5 +20,5 @@ }, "nullable": [] }, - "hash": "74c90a33b88a0e69f222147d29e83af76508654b4d965e9f0f3862fe36a9aa6c" + "hash": "d92e63295aa9d2302d5e4a8d205b35ab98de051c9aa9e5932b2c03a6118c2587" } diff --git a/.sqlx/query-1762150f43613af5bda08175ff300da6c2fa0ffc2c48df0eae6584dbb38ed7cf.json b/.sqlx/query-fd05345c81860068b5013a07ca9187c2b96d0319ba0604c0313055eb1b2eea31.json similarity index 82% rename from .sqlx/query-1762150f43613af5bda08175ff300da6c2fa0ffc2c48df0eae6584dbb38ed7cf.json rename to .sqlx/query-fd05345c81860068b5013a07ca9187c2b96d0319ba0604c0313055eb1b2eea31.json index befce31b56..023832a352 100644 --- a/.sqlx/query-1762150f43613af5bda08175ff300da6c2fa0ffc2c48df0eae6584dbb38ed7cf.json +++ b/.sqlx/query-fd05345c81860068b5013a07ca9187c2b96d0319ba0604c0313055eb1b2eea31.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"name\",\"address\",\"port\",\"connected_at\",\"disconnected_at\",\"version\",\"has_certificate\",\"certificate_expiry\",\"modified_at\",\"modified_by\" FROM \"proxy\" WHERE id = $1", + "query": "SELECT id, \"name\",\"address\",\"port\",\"connected_at\",\"disconnected_at\",\"version\",\"certificate\",\"certificate_expiry\",\"modified_at\",\"modified_by\" FROM \"proxy\" WHERE id = $1", "describe": { "columns": [ { @@ -40,8 +40,8 @@ }, { "ordinal": 7, - "name": "has_certificate", - "type_info": "Bool" + "name": "certificate", + "type_info": "Text" }, { "ordinal": 8, @@ -72,11 +72,11 @@ true, true, true, - false, + true, true, false, false ] }, - "hash": "1762150f43613af5bda08175ff300da6c2fa0ffc2c48df0eae6584dbb38ed7cf" + "hash": "fd05345c81860068b5013a07ca9187c2b96d0319ba0604c0313055eb1b2eea31" } diff --git a/Cargo.lock b/Cargo.lock index ecf19e698d..42be2db151 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,6 +310,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a88aab2464f1f25453baa7a07c84c5b7684e274054ba06817f382357f77a288" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45afffdee1e7c9126814751f88dddc747f41d91da16c9551a0f1e8a11e788a1" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.8.8" @@ -740,6 +762,15 @@ dependencies = [ "digest", ] +[[package]] +name = "cmake" +version = "0.1.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586" +dependencies = [ + "cc", +] + [[package]] name = "colorchoice" version = "1.0.4" @@ -1126,7 +1157,6 @@ version = "0.0.0" dependencies = [ "anyhow", "bytes", - "defguard_certs", "defguard_common", "defguard_core", "defguard_event_logger", @@ -1152,7 +1182,6 @@ dependencies = [ "chrono", "rcgen", "rustls-pki-types", - "serde", "sqlx", "thiserror 2.0.18", "time", @@ -1260,7 +1289,6 @@ dependencies = [ "tokio-util", "tonic", "tonic-health", - "tonic-prost", "tonic-prost-build", "totp-lite", "tower", @@ -1365,8 +1393,12 @@ dependencies = [ "defguard_mail", "defguard_proto", "defguard_version", + "http", + "hyper", + "hyper-rustls", "openidconnect", "reqwest", + "rustls", "secrecy", "semver", "sqlx", @@ -1374,7 +1406,9 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tower-service", "tracing", + "x509-parser 0.18.1", ] [[package]] @@ -1434,7 +1468,6 @@ name = "defguard_vpn_stats_purge" version = "0.0.0" dependencies = [ "chrono", - "defguard_common", "humantime", "sqlx", "tokio", @@ -1680,6 +1713,12 @@ dependencies = [ "dtoa", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dyn-clone" version = "1.0.20" @@ -1938,6 +1977,12 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -2414,7 +2459,9 @@ dependencies = [ "http", "hyper", "hyper-util", + "log", "rustls", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls", @@ -4599,6 +4646,7 @@ version = "0.23.36" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ + "aws-lc-rs", "log", "once_cell", "ring", @@ -4636,6 +4684,7 @@ version = "0.103.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", diff --git a/crates/defguard/Cargo.toml b/crates/defguard/Cargo.toml index c48794e088..54964ba30f 100644 --- a/crates/defguard/Cargo.toml +++ b/crates/defguard/Cargo.toml @@ -18,7 +18,6 @@ defguard_proxy_manager = { workspace = true } defguard_session_manager = { workspace = true } defguard_version = { workspace = true } defguard_vpn_stats_purge = { workspace = true } -defguard_certs = { workspace = true } defguard_setup = { workspace = true } # external dependencies diff --git a/crates/defguard_certs/Cargo.toml b/crates/defguard_certs/Cargo.toml index 9207838d3c..6ad44d1759 100644 --- a/crates/defguard_certs/Cargo.toml +++ b/crates/defguard_certs/Cargo.toml @@ -10,7 +10,6 @@ rust-version.workspace = true [dependencies] base64.workspace = true rcgen.workspace = true -serde.workspace = true sqlx.workspace = true thiserror.workspace = true rustls-pki-types.workspace = true diff --git a/crates/defguard_certs/src/lib.rs b/crates/defguard_certs/src/lib.rs index af7c825367..30abea7844 100644 --- a/crates/defguard_certs/src/lib.rs +++ b/crates/defguard_certs/src/lib.rs @@ -146,6 +146,7 @@ pub struct CertificateInfo { pub subject_common_name: String, pub not_before: NaiveDateTime, pub not_after: NaiveDateTime, + pub serial: String, } pub fn parse_certificate_info(cert_der: &[u8]) -> Result { @@ -153,6 +154,7 @@ pub fn parse_certificate_info(cert_der: &[u8]) -> Result Result, - // path to certificate `.pem` file used if connecting to proxy over HTTPS #[arg(long, env = "DEFGUARD_PROXY_GRPC_CA")] pub proxy_grpc_ca: Option, diff --git a/crates/defguard_common/src/db/models/proxy.rs b/crates/defguard_common/src/db/models/proxy.rs index 323d46af81..7685a65389 100644 --- a/crates/defguard_common/src/db/models/proxy.rs +++ b/crates/defguard_common/src/db/models/proxy.rs @@ -20,7 +20,7 @@ pub struct Proxy { pub connected_at: Option, pub disconnected_at: Option, pub version: Option, - pub has_certificate: bool, + pub certificate: Option, pub certificate_expiry: Option, pub modified_at: NaiveDateTime, pub modified_by: Id, @@ -55,7 +55,7 @@ impl Proxy { port, connected_at: None, disconnected_at: None, - has_certificate: false, + certificate: None, certificate_expiry: None, version: None, modified_by, diff --git a/crates/defguard_common/src/types/proxy.rs b/crates/defguard_common/src/types/proxy.rs index 41a155c858..0a4d1b5d0a 100644 --- a/crates/defguard_common/src/types/proxy.rs +++ b/crates/defguard_common/src/types/proxy.rs @@ -8,6 +8,7 @@ use crate::db::Id; pub enum ProxyControlMessage { StartConnection(Id), ShutdownConnection(Id), + Purge(Id), } #[derive(ToSchema, Serialize)] @@ -19,7 +20,7 @@ pub struct ProxyInfo { pub connected_at: Option, pub disconnected_at: Option, pub version: Option, - pub has_certificate: bool, + pub certificate: Option, pub certificate_expiry: Option, pub modified_at: NaiveDateTime, pub modified_by: Id, diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 83a4706456..a8761e1093 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -66,7 +66,6 @@ tokio-stream = { workspace = true } tokio-util = { workspace = true } tonic = { workspace = true } tonic-health = { workspace = true } -tonic-prost.workspace = true totp-lite = { workspace = true } tower-http = { workspace = true } tracing = { workspace = true } diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index 54fbfec9d4..44dbdaf05e 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -521,6 +521,7 @@ pub async fn setup_proxy_tls_stream( let defguard_certs::CertificateInfo { not_after: expiry, + serial, .. } = match parse_certificate_info(cert.der()) { Ok(dt) => { @@ -541,7 +542,7 @@ pub async fn setup_proxy_tls_stream( session.user.id, ); - proxy.has_certificate = true; + proxy.certificate = Some(serial); proxy.certificate_expiry = Some(expiry); diff --git a/crates/defguard_core/src/handlers/proxy.rs b/crates/defguard_core/src/handlers/proxy.rs index 3cf8c3c55e..01810f4738 100644 --- a/crates/defguard_core/src/handlers/proxy.rs +++ b/crates/defguard_core/src/handlers/proxy.rs @@ -169,10 +169,10 @@ pub(crate) async fn delete_proxy( return Ok(ApiResponse::json(Value::Null, StatusCode::NOT_FOUND)); }; - // Disconnect the proxy + // Disconnect and purge the proxy if let Err(err) = appstate .proxy_control_tx - .send(ProxyControlMessage::ShutdownConnection(proxy.id)) + .send(ProxyControlMessage::Purge(proxy.id)) .await { error!( @@ -181,9 +181,6 @@ pub(crate) async fn delete_proxy( ); } - // TODO - // 1. Add proxy cert to CRL - // 2. Remove cert files on deleted proxy proxy.clone().delete(&appstate.pool).await?; info!("User {} deleted proxy {proxy_id}", session.user.username); diff --git a/crates/defguard_proxy_manager/Cargo.toml b/crates/defguard_proxy_manager/Cargo.toml index eec4a56c65..6719869278 100644 --- a/crates/defguard_proxy_manager/Cargo.toml +++ b/crates/defguard_proxy_manager/Cargo.toml @@ -28,3 +28,9 @@ thiserror.workspace = true tokio.workspace = true tonic.workspace = true tracing.workspace = true +hyper-rustls = { version = "0.27", features = ["http2"] } +rustls = { version = "0.23", features = ["ring"] } +x509-parser = "0.18" +http = "1.1" +hyper = "1.4" +tower-service = "0.3" diff --git a/crates/defguard_proxy_manager/src/certs.rs b/crates/defguard_proxy_manager/src/certs.rs new file mode 100644 index 0000000000..a1010dc4cf --- /dev/null +++ b/crates/defguard_proxy_manager/src/certs.rs @@ -0,0 +1,301 @@ +//! Custom TLS verification for proxy connections. +//! +//! Motivation: +//! - tonic/rustls does not fetch or enforce CRL distribution points, so revocation +//! has to be enforced by the application. +//! - We pin each proxy to its expected certificate serial and reject mismatches at +//! the TLS layer, before any gRPC requests are processed. +//! - A lightweight in-memory cache (refreshed periodically) avoids database access +//! during the handshake and keeps verification synchronous. + +use std::{collections::HashMap, sync::Arc}; + +use defguard_common::db::{Id, models::proxy::Proxy}; +use rustls::{ + CertificateError, DistinguishedName, Error as RustlsError, RootCertStore, SignatureScheme, + client::{ + WebPkiServerVerifier, + danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + }, + crypto, + pki_types::{CertificateDer, ServerName, UnixTime}, +}; +use sqlx::PgPool; +use tokio::sync::watch; +use x509_parser::parse_x509_certificate; + +use crate::error::ProxyError; + +/// Wraps WebPKI verification to enforce proxy-specific certificate serials. +#[derive(Debug)] +struct CertVerifier { + inner: Arc, + certs_rx: watch::Receiver>>, + proxy_id: Id, +} + +impl CertVerifier { + fn new( + inner: Arc, + certs_rx: watch::Receiver>>, + proxy_id: Id, + ) -> Self { + Self { + inner, + certs_rx, + proxy_id, + } + } + + /// Validate the peer certificate serial against the expected proxy serial. + fn verify(&self, end_entity: &CertificateDer<'_>) -> Result<(), RustlsError> { + let (_, cert) = parse_x509_certificate(end_entity.as_ref()) + .map_err(|_| RustlsError::InvalidCertificate(CertificateError::BadEncoding))?; + let serial = cert.tbs_certificate.raw_serial_as_string(); + let certs = self.certs_rx.borrow(); + let Some(expected) = certs.get(&self.proxy_id) else { + error!("Missing expected certificate for proxy: {}", self.proxy_id); + return Err(RustlsError::InvalidCertificate(CertificateError::Revoked)); + }; + if !expected.eq_ignore_ascii_case(&serial) { + error!( + "Invalid certificate for proxy {}: expected={expected} got={serial}", + self.proxy_id + ); + return Err(RustlsError::InvalidCertificate(CertificateError::Revoked)); + } + Ok(()) + } +} + +impl ServerCertVerifier for CertVerifier { + /// Delegate chain validation to WebPKI, then enforce the proxy-specific pin. + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + self.inner.verify_server_cert( + end_entity, + intermediates, + server_name, + ocsp_response, + now, + )?; + self.verify(end_entity)?; + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } + + fn root_hint_subjects(&self) -> Option<&[DistinguishedName]> { + self.inner.root_hint_subjects() + } +} + +/// Build a compact id->serial map, skipping proxies without a stored cert. +fn collect_certs(items: I) -> HashMap +where + I: IntoIterator)>, +{ + items + .into_iter() + .filter_map(|(id, cert)| cert.map(|cert| (id, cert))) + .collect() +} + +/// Refresh the cached cert serials for all proxies. +pub(crate) async fn refresh_certs(pool: &PgPool, tx: &watch::Sender>>) { + match Proxy::all(pool).await { + Ok(proxies) => { + let certs = collect_certs( + proxies + .into_iter() + .map(|proxy| (proxy.id, proxy.certificate)), + ); + let _ = tx.send(Arc::new(certs)); + } + Err(err) => { + warn!("Failed to refresh revoked certificate list: {err}"); + } + } +} + +/// Build a root store from the configured CA for WebPKI validation. +fn root_store_from_ca(ca_cert_der: &[u8]) -> Result { + let mut roots = RootCertStore::empty(); + roots + .add(CertificateDer::from(ca_cert_der.to_vec())) + .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; + Ok(roots) +} + +/// Create a rustls client config that enforces the pinned proxy certificate serial. +pub(crate) fn client_config( + ca_cert_der: &[u8], + certs_rx: watch::Receiver>>, + proxy_id: Id, +) -> Result { + let provider = Arc::new(crypto::ring::default_provider()); + let roots = root_store_from_ca(ca_cert_der)?; + let verifier_roots = root_store_from_ca(ca_cert_der)?; + let verifier = WebPkiServerVerifier::builder_with_provider( + Arc::new(verifier_roots), + Arc::clone(&provider), + ) + .build() + .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; + let builder = rustls::ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; + let mut config = builder.with_root_certificates(roots).with_no_client_auth(); + config + .dangerous() + .set_certificate_verifier(Arc::new(CertVerifier::new(verifier, certs_rx, proxy_id))); + Ok(config) +} + +#[cfg(test)] +mod tests { + use super::*; + + use defguard_certs::{CertificateAuthority, Csr, DnType, generate_key_pair}; + use rustls::client::danger::HandshakeSignatureValid; + + #[derive(Debug)] + struct NoopVerifier; + + impl ServerCertVerifier for NoopVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + Vec::new() + } + + fn root_hint_subjects(&self) -> Option<&[DistinguishedName]> { + None + } + } + + fn make_cert_and_serial() -> (CertificateDer<'static>, String) { + let ca = CertificateAuthority::new("Defguard CA", "test@example.com", 30).unwrap(); + let key_pair = generate_key_pair().unwrap(); + let csr = Csr::new( + &key_pair, + &["proxy.local".to_string()], + vec![(DnType::CommonName, "proxy.local")], + ) + .unwrap(); + let cert = ca.sign_csr(&csr).unwrap(); + let cert_der = CertificateDer::from(cert.der().to_vec()); + let (_, parsed) = parse_x509_certificate(cert_der.as_ref()).unwrap(); + let serial = parsed.tbs_certificate.raw_serial_as_string(); + (cert_der, serial) + } + + #[test] + fn collect_certs_skips_missing() { + let certs = collect_certs(vec![(1, None), (2, Some("abc".to_string()))]); + assert_eq!(certs.len(), 1); + assert_eq!(certs.get(&2), Some(&"abc".to_string())); + } + + #[test] + fn verify_accepts_expected_serial() { + let (cert_der, serial) = make_cert_and_serial(); + let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, serial.clone())]))); + let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); + let result = verifier.verify(&cert_der); + assert!(result.is_ok()); + } + + #[test] + fn verify_rejects_missing_expected_cert() { + let (cert_der, serial) = make_cert_and_serial(); + let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(2, serial)]))); + let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); + let result = verifier.verify(&cert_der); + assert!(matches!( + result, + Err(RustlsError::InvalidCertificate(CertificateError::Revoked)) + )); + } + + #[test] + fn verify_rejects_mismatched_serial() { + let (cert_der, _serial) = make_cert_and_serial(); + let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, "deadbeef".to_string())]))); + let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); + let result = verifier.verify(&cert_der); + assert!(matches!( + result, + Err(RustlsError::InvalidCertificate(CertificateError::Revoked)) + )); + } + + #[test] + fn verify_accepts_case_insensitive_serial() { + let (cert_der, serial) = make_cert_and_serial(); + let expected_lower = serial.to_ascii_lowercase(); + let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, expected_lower)]))); + let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); + let result = verifier.verify(&cert_der); + assert!(result.is_ok()); + + let expected_upper = serial.to_ascii_uppercase(); + let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, expected_upper)]))); + let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); + let result = verifier.verify(&cert_der); + assert!(result.is_ok()); + } +} diff --git a/crates/defguard_proxy_manager/src/error.rs b/crates/defguard_proxy_manager/src/error.rs new file mode 100644 index 0000000000..cc179ab2ff --- /dev/null +++ b/crates/defguard_proxy_manager/src/error.rs @@ -0,0 +1,32 @@ +use defguard_core::db::models::enrollment::TokenError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ProxyError { + #[error(transparent)] + InvalidUriError(#[from] axum::http::uri::InvalidUri), + #[error("Failed to read CA certificate: {0}")] + CaCertReadError(std::io::Error), + #[error(transparent)] + TonicError(#[from] tonic::transport::Error), + #[error(transparent)] + SemverError(#[from] semver::Error), + #[error(transparent)] + SqlxError(#[from] sqlx::Error), + #[error(transparent)] + TokenError(#[from] TokenError), + #[error(transparent)] + CertificateError(#[from] defguard_certs::CertificateError), + #[error(transparent)] + UrlParseError(#[from] openidconnect::url::ParseError), + #[error("Missing proxy configuration: {0}")] + MissingConfiguration(String), + #[error("URL error: {0}")] + UrlError(String), + #[error(transparent)] + Transport(#[from] tonic::Status), + #[error("Connection timeout: {0}")] + ConnectionTimeout(String), + #[error("TLS config error: {0}")] + TlsConfigError(String), +} diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index ee3fe4b9ef..b6b8a00c44 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -1,127 +1,37 @@ use std::{ collections::HashMap, - str::FromStr, sync::{Arc, RwLock}, time::Duration, }; -use axum_extra::extract::cookie::Key; -use defguard_certs::der_to_pem; -use defguard_common::{ - VERSION, - config::server_config, - db::{ - Id, - models::{Settings, proxy::Proxy}, - }, - types::proxy::ProxyControlMessage, -}; +use defguard_common::{db::models::proxy::Proxy, types::proxy::ProxyControlMessage}; use defguard_core::{ - db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token, TokenError}, - enrollment_management::clear_unused_enrollment_tokens, - enterprise::{ - db::models::openid_provider::OpenIdProvider, - directory_sync::sync_user_groups_if_configured, - grpc::polling::PollingServer, - handlers::openid_login::{ - SELECT_ACCOUNT_SUPPORTED_PROVIDERS, build_state, make_oidc_client, user_from_claims, - }, - is_business_license_active, - ldap::utils::ldap_update_user_state, - }, - events::BidiStreamEvent, - grpc::{ - gateway::events::GatewayEvent, - proxy::client_mfa::{ClientLoginSession, ClientMfaServer}, - }, - version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, + events::BidiStreamEvent, grpc::gateway::events::GatewayEvent, version::IncompatibleComponents, }; use defguard_mail::Mail; -use defguard_proto::proxy::{ - AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, InitialInfo, - core_request, core_response, proxy_client::ProxyClient, -}; -use defguard_version::{ - ComponentInfo, DefguardComponent, client::ClientVersionInterceptor, get_tracing_variables, -}; -use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow, url}; -use reqwest::Url; -use secrecy::ExposeSecret; -use semver::Version; use sqlx::PgPool; -use thiserror::Error; use tokio::{ select, sync::{ Mutex, broadcast::Sender, - mpsc::{self, Receiver, UnboundedSender}, - oneshot, + mpsc::{Receiver, UnboundedSender}, + watch, }, task::JoinSet, - time::sleep, -}; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tonic::{ - Code, Streaming, - transport::{Certificate, ClientTlsConfig, Endpoint}, }; -use crate::{enrollment::EnrollmentServer, password_reset::PasswordResetServer}; +use crate::{certs::refresh_certs, error::ProxyError, proxy_handler::ProxyHandler}; -mod enrollment; -pub(crate) mod password_reset; +mod certs; +mod error; +mod proxy_handler; +mod servers; #[macro_use] extern crate tracing; const TEN_SECS: Duration = Duration::from_secs(10); -static VERSION_ZERO: Version = Version::new(0, 0, 0); - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub(crate) enum Scheme { - #[allow(dead_code)] - Http, - Https, -} - -impl Scheme { - #[must_use] - pub const fn as_str(&self) -> &str { - match self { - Self::Http => "http", - Self::Https => "https", - } - } -} - -#[derive(Error, Debug)] -pub enum ProxyError { - #[error(transparent)] - InvalidUriError(#[from] axum::http::uri::InvalidUri), - #[error("Failed to read CA certificate: {0}")] - CaCertReadError(std::io::Error), - #[error(transparent)] - TonicError(#[from] tonic::transport::Error), - #[error(transparent)] - SemverError(#[from] semver::Error), - #[error(transparent)] - SqlxError(#[from] sqlx::Error), - #[error(transparent)] - TokenError(#[from] TokenError), - #[error(transparent)] - CertificateError(#[from] defguard_certs::CertificateError), - #[error(transparent)] - UrlParseError(#[from] url::ParseError), - #[error("Missing proxy configuration: {0}")] - MissingConfiguration(String), - #[error("URL error: {0}")] - UrlError(String), - #[error(transparent)] - Transport(#[from] tonic::Status), - #[error("Connection timeout: {0}")] - ConnectionTimeout(String), -} /// Coordinates communication between the Core and multiple proxy instances. /// @@ -159,13 +69,21 @@ impl ProxyManager { debug!("ProxyManager starting"); let remote_mfa_responses = Arc::default(); let sessions = Arc::default(); + let (certs_tx, certs_rx) = watch::channel(Arc::new(HashMap::new())); + let refresh_pool = self.pool.clone(); + tokio::spawn(async move { + loop { + refresh_certs(&refresh_pool, &certs_tx).await; + tokio::time::sleep(TEN_SECS).await; + } + }); // Retrieve proxies from DB. let mut shutdown_channels = HashMap::new(); - let mut proxies = Proxy::all(&self.pool) + let proxies = Proxy::all(&self.pool) .await? .iter() .map(|proxy| { - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::(); shutdown_channels.insert(proxy.id, shutdown_tx); ProxyHandler::from_proxy( proxy, @@ -179,29 +97,15 @@ impl ProxyManager { .collect::, _>>()?; debug!("Retrieved {} proxies from the DB", proxies.len()); - // For backwards compatibility add the proxy specified in cli arg as well. - if let Some(ref url) = server_config().proxy_url { - debug!("Adding proxy from cli arg: {url}"); - let url = Url::from_str(url)?; - let proxy = ProxyHandler::new( - self.pool.clone(), - url, - &self.tx, - Arc::clone(&remote_mfa_responses), - Arc::clone(&sessions), - // Currently we can't shutdown this proxy since it was started via CLI arguments (no ID in DB) - // This should be removed when we do a proper import of old proxies - Arc::new(Mutex::new(None)), - None, - ); - proxies.push(proxy); - } - // Connect to all proxies. let mut tasks = JoinSet::>::new(); for proxy in proxies { debug!("Spawning proxy task for proxy {}", proxy.url); - tasks.spawn(proxy.run(self.tx.clone(), self.incompatible_components.clone())); + tasks.spawn(proxy.run( + self.tx.clone(), + self.incompatible_components.clone(), + certs_rx.clone(), + )); } loop { @@ -221,7 +125,7 @@ impl ProxyManager { Some(ProxyControlMessage::StartConnection(id)) => { debug!("Starting proxy with ID: {id}"); if let Ok(Some(proxy_model)) = Proxy::find_by_id(&self.pool, id).await { - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::(); shutdown_channels.insert(id, shutdown_tx); match ProxyHandler::from_proxy( &proxy_model, @@ -233,7 +137,7 @@ impl ProxyManager { ) { Ok(proxy) => { debug!("Spawning proxy task for proxy {}", proxy.url); - tasks.spawn(proxy.run(self.tx.clone(), self.incompatible_components.clone())); + tasks.spawn(proxy.run(self.tx.clone(), self.incompatible_components.clone(), certs_rx.clone())); } Err(err) => error!("Failed to create proxy server: {err}"), } @@ -244,7 +148,15 @@ impl ProxyManager { Some(ProxyControlMessage::ShutdownConnection(id)) => { debug!("Shutting down proxy with ID: {id}"); if let Some(shutdown_tx) = shutdown_channels.remove(&id) { - let _ = shutdown_tx.send(()); + let _ = shutdown_tx.send(false); + } else { + warn!("No shutdown channel found for proxy ID: {id}"); + } + } + Some(ProxyControlMessage::Purge(id)) => { + debug!("Purging proxy with ID: {id}"); + if let Some(shutdown_tx) = shutdown_channels.remove(&id) { + let _ = shutdown_tx.send(true); } else { warn!("No shutdown channel found for proxy ID: {id}"); } @@ -284,778 +196,3 @@ impl ProxyTxSet { } } } - -type ShutdownReceiver = tokio::sync::oneshot::Receiver<()>; - -/// Represents a single Core - Proxy connection. -/// -/// A `Proxy` is responsible for establishing and maintaining a gRPC -/// bidirectional stream to one proxy instance, handling incoming requests -/// from that proxy, and forwarding responses back through the same stream. -/// Each `Proxy` runs independently and is supervised by the -/// `ProxyManager`. -struct ProxyHandler { - pool: PgPool, - /// gRPC servers - services: ProxyServices, - /// Proxy server gRPC URL - url: Url, - shutdown_signal: Arc>>, - proxy_id: Option, -} - -impl ProxyHandler { - pub fn new( - pool: PgPool, - url: Url, - tx: &ProxyTxSet, - remote_mfa_responses: Arc>>>, - sessions: Arc>>, - shutdown_signal: Arc>>, - proxy_id: Option, - ) -> Self { - // Instantiate gRPC servers. - let services = ProxyServices::new(&pool, tx, remote_mfa_responses, sessions); - - Self { - pool, - services, - url, - shutdown_signal, - proxy_id, - } - } - - fn from_proxy( - proxy: &Proxy, - pool: PgPool, - tx: &ProxyTxSet, - remote_mfa_responses: Arc>>>, - sessions: Arc>>, - shutdown_signal: Arc>>, - ) -> Result { - let url = Url::from_str(&format!("http://{}:{}", proxy.address, proxy.port))?; - let proxy_id = proxy.id; - Ok(Self::new( - pool, - url, - tx, - remote_mfa_responses, - sessions, - shutdown_signal, - Some(proxy_id), - )) - } - - async fn mark_connected(&self, version: &Version) -> Result<(), ProxyError> { - let Some(proxy_id) = self.proxy_id else { - warn!( - "Skipping marking connection time for proxy without id: {}", - self.url - ); - return Ok(()); - }; - - if let Some(mut proxy) = Proxy::find_by_id(&self.pool, proxy_id).await? { - proxy - .mark_connected(&self.pool, &version.to_string()) - .await?; - } else { - warn!("Couldn't find proxy by id, URL: {}", self.url); - } - - Ok(()) - } - - async fn mark_disconnected(&self) -> Result<(), ProxyError> { - let Some(proxy_id) = self.proxy_id else { - warn!( - "Skipping marking connection time for proxy without id: {}", - self.url - ); - return Ok(()); - }; - - let Some(mut proxy) = Proxy::find_by_id(&self.pool, proxy_id).await? else { - warn!("Couldn't find proxy by id, URL: {}", self.url); - return Ok(()); - }; - - // Make sure we don't continuously update disconnected time in connection loop - let should_mark = match (proxy.connected_at, proxy.disconnected_at) { - (Some(connected), Some(disconnected)) => disconnected < connected, - (Some(_), None) => true, - _ => false, - }; - - if should_mark { - proxy.mark_disconnected(&self.pool).await?; - } - - Ok(()) - } - - fn endpoint(&self, scheme: Scheme) -> Result { - let mut url = self.url.clone(); - - url.set_scheme(scheme.as_str()).map_err(|()| { - ProxyError::UrlError(format!("Failed to set {scheme:?} scheme on URL {url}")) - })?; - let endpoint = Endpoint::from_shared(url.to_string())?; - let endpoint = endpoint - .http2_keep_alive_interval(TEN_SECS) - .tcp_keepalive(Some(TEN_SECS)) - .keep_alive_while_idle(true); - - let endpoint = if scheme == Scheme::Https { - let settings = Settings::get_current_settings(); - let Some(ca_cert_der) = settings.ca_cert_der else { - return Err(ProxyError::MissingConfiguration( - "Core CA is not setup, can't create a Proxy endpoint.".to_string(), - )); - }; - - let cert_pem = der_to_pem(&ca_cert_der, defguard_certs::PemLabel::Certificate)?; - let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(&cert_pem)); - - endpoint.tls_config(tls)? - } else { - endpoint - }; - - Ok(endpoint) - } - - /// Establishes and maintains a gRPC bidirectional stream to the proxy. - /// - /// The proxy connection is retried on failure, compatibility is checked - /// on each successful connection, and incoming messages are handled - /// until the stream is closed. - pub(crate) async fn run( - mut self, - tx_set: ProxyTxSet, - incompatible_components: Arc>, - ) -> Result<(), ProxyError> { - loop { - let endpoint = self.endpoint(Scheme::Https)?; - - debug!("Connecting to proxy at {}", endpoint.uri()); - let interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?); - let mut client = ProxyClient::with_interceptor(endpoint.connect_lazy(), interceptor); - let (tx, rx) = mpsc::unbounded_channel(); - let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { - Ok(response) => response, - Err(err) => { - match err.code() { - Code::FailedPrecondition => { - error!( - "Failed to connect to proxy @ {}, version check failed, retrying in \ - 10s: {err}", - endpoint.uri() - ); - // TODO push event - } - err => { - error!( - "Failed to connect to proxy @ {}, retrying in 10s: {err}", - endpoint.uri() - ); - } - } - self.mark_disconnected().await?; - sleep(TEN_SECS).await; - continue; - } - }; - let maybe_info = ComponentInfo::from_metadata(response.metadata()); - - // Check proxy version and continue if it's not supported. - let (version, info) = get_tracing_variables(&maybe_info); - let proxy_is_supported = is_proxy_version_supported(Some(&version)); - self.mark_connected(&version).await?; - - let span = tracing::info_span!("proxy_bidi", component = %DefguardComponent::Proxy, - version = version.to_string(), info); - let _guard = span.enter(); - if !proxy_is_supported { - // Store incompatible proxy - let maybe_version = if version == VERSION_ZERO { - None - } else { - Some(version) - }; - let data = IncompatibleProxyData::new(maybe_version); - data.insert(&incompatible_components); - - // Sleep before trying to reconnect - sleep(TEN_SECS).await; - continue; - } - IncompatibleComponents::remove_proxy(&incompatible_components); - - info!("Connected to proxy at {}", endpoint.uri()); - let mut resp_stream = response.into_inner(); - - // Derive proxy cookie key from core secret to avoid transmitting it over gRPC. - let config = server_config(); - let proxy_cookie_key = Key::derive_from(config.secret_key.expose_secret().as_bytes()); - - // Send initial info with private cookies key. - let initial_info = InitialInfo { - private_cookies_key: proxy_cookie_key.master().to_vec(), - }; - let _ = tx.send(CoreResponse { - id: 0, - payload: Some(core_response::Payload::InitialInfo(initial_info)), - }); - - let shutdown_signal = self.shutdown_signal.lock().await.take(); - if let Some(shutdown_signal) = shutdown_signal { - select! { - res = self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) => { - if let Err(err) = res { - error!("Proxy message loop ended with error: {err}, reconnecting in {TEN_SECS:?}",); - } else { - info!("Proxy message loop ended, reconnecting in {TEN_SECS:?}"); - } - self.mark_disconnected().await?; - sleep(TEN_SECS).await; - } - res = shutdown_signal => { - if let Err(err) = res { - error!("An error occurred when trying to wait for a shutdown signal for Proxy: {err}. Reconnecting to: {}", endpoint.uri()); - } else { - info!("Shutdown signal received, stopping proxy connection to {}", endpoint.uri()); - } - self.mark_disconnected().await?; - break; - } - } - } else { - self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) - .await?; - } - } - - Ok(()) - } - - /// Processes incoming requests from the proxy over an active gRPC stream. - /// - /// This loop receives `CoreRequest` messages from the proxy, dispatches - /// them to the appropriate Core-side handlers, and sends corresponding - /// `CoreResponse` messages back through the stream. Certain requests may - /// also register routing state for future responses. - async fn message_loop( - &mut self, - tx: UnboundedSender, - wireguard_tx: Sender, - resp_stream: &mut Streaming, - ) -> Result<(), ProxyError> { - let pool = self.pool.clone(); - 'message: loop { - match resp_stream.message().await { - Ok(None) => { - info!("stream was closed by the sender"); - break 'message; - } - Ok(Some(received)) => { - debug!("Received message from proxy; ID={}", received.id); - let payload = match received.payload { - // rpc CodeMfaSetupStart return (CodeMfaSetupStartResponse) - Some(core_request::Payload::CodeMfaSetupStart(request)) => { - match self - .services - .enrollment - .register_code_mfa_start(request) - .await - { - Ok(response) => Some( - core_response::Payload::CodeMfaSetupStartResponse(response), - ), - Err(err) => { - error!("Register mfa start error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc CodeMfaSetupFinish return (CodeMfaSetupFinishResponse) - Some(core_request::Payload::CodeMfaSetupFinish(request)) => { - match self - .services - .enrollment - .register_code_mfa_finish(request) - .await - { - Ok(response) => Some( - core_response::Payload::CodeMfaSetupFinishResponse(response), - ), - Err(err) => { - error!("Register MFA finish error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientMfaTokenValidation return (ClientMfaTokenValidationResponse) - Some(core_request::Payload::ClientMfaTokenValidation(request)) => { - match self.services.client_mfa.validate_mfa_token(request).await { - Ok(response_payload) => { - Some(core_response::Payload::ClientMfaTokenValidation( - response_payload, - )) - } - Err(err) => { - error!("Client MFA validate token error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc RegisterMobileAuth (RegisterMobileAuthRequest) return (google.protobuf.Empty) - Some(core_request::Payload::RegisterMobileAuth(request)) => { - match self.services.enrollment.register_mobile_auth(request).await { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("Register mobile auth error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc StartEnrollment (EnrollmentStartRequest) returns (EnrollmentStartResponse) - Some(core_request::Payload::EnrollmentStart(request)) => { - match self - .services - .enrollment - .start_enrollment(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::EnrollmentStart(response_payload)) - } - Err(err) => { - error!("start enrollment error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ActivateUser (ActivateUserRequest) returns (google.protobuf.Empty) - Some(core_request::Payload::ActivateUser(request)) => { - match self - .services - .enrollment - .activate_user(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("activate user error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc CreateDevice (NewDevice) returns (DeviceConfigResponse) - Some(core_request::Payload::NewDevice(request)) => { - match self - .services - .enrollment - .create_device(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::DeviceConfig(response_payload)) - } - Err(err) => { - error!("create device error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc GetNetworkInfo (ExistingDevice) returns (DeviceConfigResponse) - Some(core_request::Payload::ExistingDevice(request)) => { - match self - .services - .enrollment - .get_network_info(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::DeviceConfig(response_payload)) - } - Err(err) => { - error!("get network info error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc RequestPasswordReset (PasswordResetInitializeRequest) returns (google.protobuf.Empty) - Some(core_request::Payload::PasswordResetInit(request)) => { - match self - .services - .password_reset - .request_password_reset(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("password reset init error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc StartPasswordReset (PasswordResetStartRequest) returns (PasswordResetStartResponse) - Some(core_request::Payload::PasswordResetStart(request)) => { - match self - .services - .password_reset - .start_password_reset(request, received.device_info) - .await - { - Ok(response_payload) => Some( - core_response::Payload::PasswordResetStart(response_payload), - ), - Err(err) => { - error!("password reset start error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ResetPassword (PasswordResetRequest) returns (google.protobuf.Empty) - Some(core_request::Payload::PasswordReset(request)) => { - match self - .services - .password_reset - .reset_password(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("password reset error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientMfaStart (ClientMfaStartRequest) returns (ClientMfaStartResponse) - Some(core_request::Payload::ClientMfaStart(request)) => { - match self - .services - .client_mfa - .start_client_mfa_login(request) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::ClientMfaStart(response_payload)) - } - Err(err) => { - error!("client MFA start error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientRemoteMfaFinish (ClientRemoteMfaFinishRequest) returns (ClientRemoteMfaFinishResponse) - Some(core_request::Payload::AwaitRemoteMfaFinish(request)) => { - match self - .services - .client_mfa - .await_remote_mfa_login(request, tx.clone(), received.id) - .await - { - Ok(()) => None, - Err(err) => { - error!("Client remote MFA finish error: {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc ClientMfaFinish (ClientMfaFinishRequest) returns (ClientMfaFinishResponse) - Some(core_request::Payload::ClientMfaFinish(request)) => { - match self - .services - .client_mfa - .finish_client_mfa_login(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::ClientMfaFinish(response_payload)) - } - Err(err) => { - match err.code() { - Code::FailedPrecondition => { - // User not yet done with OIDC authentication. Don't log it - // as an error. - debug!("Client MFA finish error: {err}"); - } - _ => { - // Log other errors as errors. - error!("Client MFA finish error: {err}"); - } - } - Some(core_response::Payload::CoreError(err.into())) - } - } - } - Some(core_request::Payload::ClientMfaOidcAuthenticate(request)) => { - match self - .services - .client_mfa - .auth_mfa_session_with_oidc(request, received.device_info) - .await - { - Ok(()) => Some(core_response::Payload::Empty(())), - Err(err) => { - error!("client MFA OIDC authenticate error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - // rpc LocationInfo (LocationInfoRequest) returns (LocationInfoResponse) - Some(core_request::Payload::InstanceInfo(request)) => { - match self - .services - .polling - .info(request, received.device_info) - .await - { - Ok(response_payload) => { - Some(core_response::Payload::InstanceInfo(response_payload)) - } - Err(err) => { - if Code::FailedPrecondition == err.code() { - // Ignore the case when we are not enterprise but the client is - // trying to fetch the instance config, - // to avoid spamming the logs with misleading errors. - - debug!( - "A client tried to fetch the instance config, but we are \ - not enterprise." - ); - Some(core_response::Payload::CoreError(err.into())) - } else { - error!("Instance info error {err}"); - Some(core_response::Payload::CoreError(err.into())) - } - } - } - } - Some(core_request::Payload::AuthInfo(request)) => { - if !is_business_license_active() { - warn!("Enterprise license required"); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::FailedPrecondition as i32, - message: "no valid license".into(), - })) - } else if let Ok(redirect_url) = Url::parse(&request.redirect_url) { - if let Some(provider) = OpenIdProvider::get_current(&pool).await? { - match make_oidc_client(redirect_url, &provider).await { - Ok((_client_id, client)) => { - let mut authorize_url_builder = client - .authorize_url( - CoreAuthenticationFlow::AuthorizationCode, - || build_state(request.state), - Nonce::new_random, - ) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("profile".to_string())); - - if SELECT_ACCOUNT_SUPPORTED_PROVIDERS - .iter() - .all(|p| p.eq_ignore_ascii_case(&provider.name)) - { - authorize_url_builder = authorize_url_builder - .add_prompt( - openidconnect::core::CoreAuthPrompt::SelectAccount, - ); - } - let (url, csrf_token, nonce) = - authorize_url_builder.url(); - - Some(core_response::Payload::AuthInfo( - AuthInfoResponse { - url: url.into(), - csrf_token: csrf_token.secret().to_owned(), - nonce: nonce.secret().to_owned(), - button_display_name: provider.display_name, - }, - )) - } - Err(err) => { - error!( - "Failed to setup external OIDC provider client: {err}" - ); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message: "failed to build OIDC client".into(), - })) - } - } - } else { - error!("Failed to get current OpenID provider"); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::NotFound as i32, - message: "failed to get current OpenID provider".into(), - })) - } - } else { - error!( - "Invalid redirect URL in authentication info request: {}", - request.redirect_url - ); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message: "invalid redirect URL".into(), - })) - } - } - Some(core_request::Payload::AuthCallback(request)) => { - match Url::parse(&request.callback_url) { - Ok(callback_url) => { - let code = AuthorizationCode::new(request.code); - match user_from_claims( - &pool, - Nonce::new(request.nonce), - code, - callback_url, - ) - .await - { - Ok(mut user) => { - clear_unused_enrollment_tokens(&user, &pool).await?; - if let Err(err) = sync_user_groups_if_configured( - &user, - &pool, - &wireguard_tx, - ) - .await - { - error!( - "Failed to sync user groups for user {} with the \ - directory while the user was logging in through an \ - external provider: {err}", - user.username, - ); - } else { - ldap_update_user_state(&mut user, &pool).await; - } - debug!("Cleared unused tokens for {}.", user.username); - debug!( - "Creating a new desktop activation token for user {} \ - as a result of proxy OpenID auth callback.", - user.username - ); - let config = server_config(); - let desktop_configuration = Token::new( - user.id, - Some(user.id), - Some(user.email), - config.enrollment_token_timeout.as_secs(), - Some(ENROLLMENT_TOKEN_TYPE.to_string()), - ); - debug!("Saving a new desktop configuration token..."); - desktop_configuration.save(&pool).await?; - debug!( - "Saved desktop configuration token. Responding to \ - proxy with the token." - ); - let settings = Settings::get_current_settings(); - let public_proxy_url = settings.proxy_public_url()?; - - Some(core_response::Payload::AuthCallback( - AuthCallbackResponse { - url: public_proxy_url.into(), - token: desktop_configuration.id, - }, - )) - } - Err(err) => { - let message = format!("OpenID auth error {err}"); - error!(message); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message, - })) - } - } - } - Err(err) => { - error!( - "Proxy requested an OpenID authentication info for a callback \ - URL ({}) that couldn't be parsed. Details: {err}", - request.callback_url - ); - Some(core_response::Payload::CoreError(CoreError { - status_code: Code::Internal as i32, - message: "invalid callback URL".into(), - })) - } - } - } - // Reply without payload. - None => None, - }; - - if let Some(payload) = payload { - let req = CoreResponse { - id: received.id, - payload: Some(payload), - }; - let _ = tx.send(req); - } - } - Err(err) => { - error!("Disconnected from proxy at {}: {err}", self.url); - debug!("waiting 10s to re-establish the connection"); - self.mark_disconnected().await?; - sleep(TEN_SECS).await; - break 'message; - } - } - } - - Ok(()) - } -} - -/// Groups Core-side service handlers used to process requests originating -/// from a proxy instance. -/// -/// Each `ProxyServices` instance is owned by a single `Proxy` and provides -/// the concrete handlers for enrollment, authentication, and polling-related -/// requests received over the gRPC bidirectional stream. -struct ProxyServices { - enrollment: EnrollmentServer, - password_reset: PasswordResetServer, - client_mfa: ClientMfaServer, - polling: PollingServer, -} - -impl ProxyServices { - pub fn new( - pool: &PgPool, - tx: &ProxyTxSet, - remote_mfa_responses: Arc>>>, - sessions: Arc>>, - ) -> Self { - let enrollment = EnrollmentServer::new( - pool.clone(), - tx.wireguard.clone(), - tx.mail.clone(), - tx.bidi_events.clone(), - ); - let password_reset = - PasswordResetServer::new(pool.clone(), tx.mail.clone(), tx.bidi_events.clone()); - let client_mfa = ClientMfaServer::new( - pool.clone(), - tx.mail.clone(), - tx.wireguard.clone(), - tx.bidi_events.clone(), - remote_mfa_responses, - sessions, - ); - let polling = PollingServer::new(pool.clone()); - - Self { - enrollment, - password_reset, - client_mfa, - polling, - } - } -} diff --git a/crates/defguard_proxy_manager/src/proxy_handler.rs b/crates/defguard_proxy_manager/src/proxy_handler.rs new file mode 100644 index 0000000000..8d66e24a2c --- /dev/null +++ b/crates/defguard_proxy_manager/src/proxy_handler.rs @@ -0,0 +1,894 @@ +use std::{ + collections::HashMap, + str::FromStr, + sync::{Arc, RwLock}, +}; + +use axum_extra::extract::cookie::Key; +use defguard_common::{ + VERSION, + config::server_config, + db::{ + Id, + models::{Settings, proxy::Proxy}, + }, +}; +use defguard_core::{ + db::models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, + enrollment_management::clear_unused_enrollment_tokens, + enterprise::{ + db::models::openid_provider::OpenIdProvider, + directory_sync::sync_user_groups_if_configured, + grpc::polling::PollingServer, + handlers::openid_login::{ + SELECT_ACCOUNT_SUPPORTED_PROVIDERS, build_state, make_oidc_client, user_from_claims, + }, + is_business_license_active, + ldap::utils::ldap_update_user_state, + }, + grpc::{ + gateway::events::GatewayEvent, + proxy::client_mfa::{ClientLoginSession, ClientMfaServer}, + }, + version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, +}; +use defguard_proto::proxy::{ + AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, InitialInfo, + core_request, core_response, proxy_client::ProxyClient, +}; +use defguard_version::{ + ComponentInfo, DefguardComponent, client::ClientVersionInterceptor, get_tracing_variables, +}; +use http::Uri; +use hyper_rustls::HttpsConnectorBuilder; +use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; +use reqwest::Url; +use secrecy::ExposeSecret; +use semver::Version; +use sqlx::PgPool; +use tokio::{ + select, + sync::{ + Mutex, + broadcast::Sender, + mpsc::{self, UnboundedSender}, + oneshot, watch, + }, + time::sleep, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{ + Code, Request, Streaming, + service::interceptor::InterceptedService, + transport::{Channel, Endpoint}, +}; + +use crate::{ + ProxyError, ProxyTxSet, TEN_SECS, + certs::client_config, + servers::{EnrollmentServer, PasswordResetServer}, +}; + +static VERSION_ZERO: Version = Version::new(0, 0, 0); + +type ShutdownReceiver = tokio::sync::oneshot::Receiver; + +/// Represents a single Core - Proxy connection. +/// +/// A `ProxyHandler` is responsible for establishing and maintaining a gRPC +/// bidirectional stream to one proxy instance, handling incoming requests +/// from that proxy, and forwarding responses back through the same stream. +/// Each `ProxyHandler` runs independently and is supervised by the +/// `ProxyManager`. +pub(super) struct ProxyHandler { + pool: PgPool, + /// gRPC servers + services: ProxyServices, + /// Proxy server gRPC URL + pub(super) url: Url, + shutdown_signal: Arc>>, + proxy_id: Id, + client: Option>>, +} + +impl ProxyHandler { + pub(super) fn new( + pool: PgPool, + url: Url, + tx: &ProxyTxSet, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, + shutdown_signal: Arc>>, + proxy_id: Id, + ) -> Self { + // Instantiate gRPC servers. + let services = ProxyServices::new(&pool, tx, remote_mfa_responses, sessions); + + Self { + pool, + services, + url, + shutdown_signal, + proxy_id, + client: None, + } + } + + pub(super) fn from_proxy( + proxy: &Proxy, + pool: PgPool, + tx: &ProxyTxSet, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, + shutdown_signal: Arc>>, + ) -> Result { + let url = Url::from_str(&format!("http://{}:{}", proxy.address, proxy.port))?; + let proxy_id = proxy.id; + Ok(Self::new( + pool, + url, + tx, + remote_mfa_responses, + sessions, + shutdown_signal, + proxy_id, + )) + } + + async fn mark_connected(&self, version: &Version) -> Result<(), ProxyError> { + if let Some(mut proxy) = Proxy::find_by_id(&self.pool, self.proxy_id).await? { + proxy + .mark_connected(&self.pool, &version.to_string()) + .await?; + } else { + warn!("Couldn't find proxy by id, URL: {}", self.url); + } + + Ok(()) + } + + async fn mark_disconnected(&self) -> Result<(), ProxyError> { + let Some(mut proxy) = Proxy::find_by_id(&self.pool, self.proxy_id).await? else { + warn!("Couldn't find proxy by id, URL: {}", self.url); + return Ok(()); + }; + + // Make sure we don't continuously update disconnected time in connection loop + let should_mark = match (proxy.connected_at, proxy.disconnected_at) { + (Some(connected), Some(disconnected)) => disconnected < connected, + (Some(_), None) => true, + _ => false, + }; + + if should_mark { + proxy.mark_disconnected(&self.pool).await?; + } + + Ok(()) + } + + fn endpoint(&self) -> Result { + let mut url = self.url.clone(); + + // Using http here because the connector upgrades to TLS internally. + url.set_scheme("http").map_err(|()| { + ProxyError::UrlError(format!("Failed to set http scheme on URL {url}")) + })?; + let endpoint = Endpoint::from_shared(url.to_string())?; + let endpoint = endpoint + .http2_keep_alive_interval(TEN_SECS) + .tcp_keepalive(Some(TEN_SECS)) + .keep_alive_while_idle(true); + + Ok(endpoint) + } + + /// Establishes and maintains a gRPC bidirectional stream to the proxy. + /// + /// The proxy connection is retried on failure, compatibility is checked + /// on each successful connection, and incoming messages are handled + /// until the stream is closed. + pub(super) async fn run( + mut self, + tx_set: ProxyTxSet, + incompatible_components: Arc>, + certs_rx: watch::Receiver>>, + ) -> Result<(), ProxyError> { + loop { + let endpoint = self.endpoint()?; + let settings = Settings::get_current_settings(); + let Some(ca_cert_der) = settings.ca_cert_der else { + return Err(ProxyError::MissingConfiguration( + "Core CA is not setup, can't create a Proxy endpoint.".to_string(), + )); + }; + let tls_config = client_config(&ca_cert_der, certs_rx.clone(), self.proxy_id)?; + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http2() + .build(); + let connector = HttpsSchemeConnector::new(connector); + + debug!("Connecting to proxy at {}", endpoint.uri()); + let interceptor = ClientVersionInterceptor::new(Version::parse(VERSION)?); + let channel = endpoint.connect_with_connector_lazy(connector); + let mut client = ProxyClient::with_interceptor(channel, interceptor); + self.client = Some(client.clone()); + let (tx, rx) = mpsc::unbounded_channel(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + match err.code() { + Code::FailedPrecondition => { + error!( + "Failed to connect to proxy @ {}, version check failed, retrying in \ + 10s: {err}", + endpoint.uri() + ); + // TODO push event + } + err => { + error!( + "Failed to connect to proxy @ {}, retrying in 10s: {err}", + endpoint.uri() + ); + } + } + self.mark_disconnected().await?; + sleep(TEN_SECS).await; + continue; + } + }; + let maybe_info = ComponentInfo::from_metadata(response.metadata()); + + // Check proxy version and continue if it's not supported. + let (version, info) = get_tracing_variables(&maybe_info); + let proxy_is_supported = is_proxy_version_supported(Some(&version)); + self.mark_connected(&version).await?; + + let span = tracing::info_span!("proxy_bidi", component = %DefguardComponent::Proxy, + version = version.to_string(), info); + let _guard = span.enter(); + if !proxy_is_supported { + // Store incompatible proxy + let maybe_version = if version == VERSION_ZERO { + None + } else { + Some(version) + }; + let data = IncompatibleProxyData::new(maybe_version); + data.insert(&incompatible_components); + + // Sleep before trying to reconnect + sleep(TEN_SECS).await; + continue; + } + IncompatibleComponents::remove_proxy(&incompatible_components); + + info!("Connected to proxy at {}", endpoint.uri()); + let mut resp_stream = response.into_inner(); + + // Derive proxy cookie key from core secret to avoid transmitting it over gRPC. + let config = server_config(); + let proxy_cookie_key = Key::derive_from(config.secret_key.expose_secret().as_bytes()); + + // Send initial info with private cookies key. + let initial_info = InitialInfo { + private_cookies_key: proxy_cookie_key.master().to_vec(), + }; + let _ = tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::InitialInfo(initial_info)), + }); + + let shutdown_signal = self.shutdown_signal.lock().await.take(); + if let Some(shutdown_signal) = shutdown_signal { + select! { + res = self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) => { + if let Err(err) = res { + error!("Proxy message loop ended with error: {err}, reconnecting in {TEN_SECS:?}",); + } else { + info!("Proxy message loop ended, reconnecting in {TEN_SECS:?}"); + } + self.mark_disconnected().await?; + sleep(TEN_SECS).await; + } + res = shutdown_signal => { + match res { + Err(err) => { + error!("An error occurred when trying to wait for a shutdown signal for Proxy: {err}. Reconnecting to: {}", endpoint.uri()); + } + Ok(purge) => { + info!("Shutdown signal received, purge: {purge}, stopping proxy connection to {}", endpoint.uri()); + if purge { + debug!("Sending purge request to proxy {}", endpoint.uri()); + if let Some(client) = self.client.as_mut() { + if let Err(err) = client.purge(Request::new(())).await { + error!("Error sending purge request to proxy {}: {err}", endpoint.uri()); + } + } + } + } + } + self.mark_disconnected().await?; + break; + } + } + } else { + self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) + .await?; + } + } + + Ok(()) + } + + /// Processes incoming requests from the proxy over an active gRPC stream. + /// + /// This loop receives `CoreRequest` messages from the proxy, dispatches + /// them to the appropriate Core-side handlers, and sends corresponding + /// `CoreResponse` messages back through the stream. Certain requests may + /// also register routing state for future responses. + async fn message_loop( + &mut self, + tx: UnboundedSender, + wireguard_tx: Sender, + resp_stream: &mut Streaming, + ) -> Result<(), ProxyError> { + let pool = self.pool.clone(); + 'message: loop { + match resp_stream.message().await { + Ok(None) => { + info!("stream was closed by the sender"); + break 'message; + } + Ok(Some(received)) => { + debug!("Received message from proxy; ID={}", received.id); + let payload = match received.payload { + // rpc CodeMfaSetupStart return (CodeMfaSetupStartResponse) + Some(core_request::Payload::CodeMfaSetupStart(request)) => { + match self + .services + .enrollment + .register_code_mfa_start(request) + .await + { + Ok(response) => Some( + core_response::Payload::CodeMfaSetupStartResponse(response), + ), + Err(err) => { + error!("Register mfa start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc CodeMfaSetupFinish return (CodeMfaSetupFinishResponse) + Some(core_request::Payload::CodeMfaSetupFinish(request)) => { + match self + .services + .enrollment + .register_code_mfa_finish(request) + .await + { + Ok(response) => Some( + core_response::Payload::CodeMfaSetupFinishResponse(response), + ), + Err(err) => { + error!("Register MFA finish error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaTokenValidation return (ClientMfaTokenValidationResponse) + Some(core_request::Payload::ClientMfaTokenValidation(request)) => { + match self.services.client_mfa.validate_mfa_token(request).await { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaTokenValidation( + response_payload, + )) + } + Err(err) => { + error!("Client MFA validate token error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc RegisterMobileAuth (RegisterMobileAuthRequest) return (google.protobuf.Empty) + Some(core_request::Payload::RegisterMobileAuth(request)) => { + match self.services.enrollment.register_mobile_auth(request).await { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("Register mobile auth error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc StartEnrollment (EnrollmentStartRequest) returns (EnrollmentStartResponse) + Some(core_request::Payload::EnrollmentStart(request)) => { + match self + .services + .enrollment + .start_enrollment(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::EnrollmentStart(response_payload)) + } + Err(err) => { + error!("start enrollment error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ActivateUser (ActivateUserRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::ActivateUser(request)) => { + match self + .services + .enrollment + .activate_user(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("activate user error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc CreateDevice (NewDevice) returns (DeviceConfigResponse) + Some(core_request::Payload::NewDevice(request)) => { + match self + .services + .enrollment + .create_device(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::DeviceConfig(response_payload)) + } + Err(err) => { + error!("create device error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc GetNetworkInfo (ExistingDevice) returns (DeviceConfigResponse) + Some(core_request::Payload::ExistingDevice(request)) => { + match self + .services + .enrollment + .get_network_info(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::DeviceConfig(response_payload)) + } + Err(err) => { + error!("get network info error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc RequestPasswordReset (PasswordResetInitializeRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::PasswordResetInit(request)) => { + match self + .services + .password_reset + .request_password_reset(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("password reset init error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc StartPasswordReset (PasswordResetStartRequest) returns (PasswordResetStartResponse) + Some(core_request::Payload::PasswordResetStart(request)) => { + match self + .services + .password_reset + .start_password_reset(request, received.device_info) + .await + { + Ok(response_payload) => Some( + core_response::Payload::PasswordResetStart(response_payload), + ), + Err(err) => { + error!("password reset start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ResetPassword (PasswordResetRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::PasswordReset(request)) => { + match self + .services + .password_reset + .reset_password(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("password reset error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaStart (ClientMfaStartRequest) returns (ClientMfaStartResponse) + Some(core_request::Payload::ClientMfaStart(request)) => { + match self + .services + .client_mfa + .start_client_mfa_login(request) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaStart(response_payload)) + } + Err(err) => { + error!("client MFA start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientRemoteMfaFinish (ClientRemoteMfaFinishRequest) returns (ClientRemoteMfaFinishResponse) + Some(core_request::Payload::AwaitRemoteMfaFinish(request)) => { + match self + .services + .client_mfa + .await_remote_mfa_login(request, tx.clone(), received.id) + .await + { + Ok(()) => None, + Err(err) => { + error!("Client remote MFA finish error: {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaFinish (ClientMfaFinishRequest) returns (ClientMfaFinishResponse) + Some(core_request::Payload::ClientMfaFinish(request)) => { + match self + .services + .client_mfa + .finish_client_mfa_login(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaFinish(response_payload)) + } + Err(err) => { + match err.code() { + Code::FailedPrecondition => { + // User not yet done with OIDC authentication. Don't log it + // as an error. + debug!("Client MFA finish error: {err}"); + } + _ => { + // Log other errors as errors. + error!("Client MFA finish error: {err}"); + } + } + Some(core_response::Payload::CoreError(err.into())) + } + } + } + Some(core_request::Payload::ClientMfaOidcAuthenticate(request)) => { + match self + .services + .client_mfa + .auth_mfa_session_with_oidc(request, received.device_info) + .await + { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("client MFA OIDC authenticate error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc LocationInfo (LocationInfoRequest) returns (LocationInfoResponse) + Some(core_request::Payload::InstanceInfo(request)) => { + match self + .services + .polling + .info(request, received.device_info) + .await + { + Ok(response_payload) => { + Some(core_response::Payload::InstanceInfo(response_payload)) + } + Err(err) => { + if Code::FailedPrecondition == err.code() { + // Ignore the case when we are not enterprise but the client is + // trying to fetch the instance config, + // to avoid spamming the logs with misleading errors. + + debug!( + "A client tried to fetch the instance config, but we are \ + not enterprise." + ); + Some(core_response::Payload::CoreError(err.into())) + } else { + error!("Instance info error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + } + Some(core_request::Payload::AuthInfo(request)) => { + if !is_business_license_active() { + warn!("Enterprise license required"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::FailedPrecondition as i32, + message: "no valid license".into(), + })) + } else if let Ok(redirect_url) = Url::parse(&request.redirect_url) { + if let Some(provider) = OpenIdProvider::get_current(&pool).await? { + match make_oidc_client(redirect_url, &provider).await { + Ok((_client_id, client)) => { + let mut authorize_url_builder = client + .authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + || build_state(request.state), + Nonce::new_random, + ) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("profile".to_string())); + + if SELECT_ACCOUNT_SUPPORTED_PROVIDERS + .iter() + .all(|p| p.eq_ignore_ascii_case(&provider.name)) + { + authorize_url_builder = authorize_url_builder + .add_prompt( + openidconnect::core::CoreAuthPrompt::SelectAccount, + ); + } + let (url, csrf_token, nonce) = + authorize_url_builder.url(); + + Some(core_response::Payload::AuthInfo( + AuthInfoResponse { + url: url.into(), + csrf_token: csrf_token.secret().to_owned(), + nonce: nonce.secret().to_owned(), + button_display_name: provider.display_name, + }, + )) + } + Err(err) => { + error!( + "Failed to setup external OIDC provider client: {err}" + ); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "failed to build OIDC client".into(), + })) + } + } + } else { + error!("Failed to get current OpenID provider"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::NotFound as i32, + message: "failed to get current OpenID provider".into(), + })) + } + } else { + error!( + "Invalid redirect URL in authentication info request: {}", + request.redirect_url + ); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "invalid redirect URL".into(), + })) + } + } + Some(core_request::Payload::AuthCallback(request)) => { + match Url::parse(&request.callback_url) { + Ok(callback_url) => { + let code = AuthorizationCode::new(request.code); + match user_from_claims( + &pool, + Nonce::new(request.nonce), + code, + callback_url, + ) + .await + { + Ok(mut user) => { + clear_unused_enrollment_tokens(&user, &pool).await?; + if let Err(err) = sync_user_groups_if_configured( + &user, + &pool, + &wireguard_tx, + ) + .await + { + error!( + "Failed to sync user groups for user {} with the \ + directory while the user was logging in through an \ + external provider: {err}", + user.username, + ); + } else { + ldap_update_user_state(&mut user, &pool).await; + } + debug!("Cleared unused tokens for {}.", user.username); + debug!( + "Creating a new desktop activation token for user {} \ + as a result of proxy OpenID auth callback.", + user.username + ); + let config = server_config(); + let desktop_configuration = Token::new( + user.id, + Some(user.id), + Some(user.email), + config.enrollment_token_timeout.as_secs(), + Some(ENROLLMENT_TOKEN_TYPE.to_string()), + ); + debug!("Saving a new desktop configuration token..."); + desktop_configuration.save(&pool).await?; + debug!( + "Saved desktop configuration token. Responding to \ + proxy with the token." + ); + let settings = Settings::get_current_settings(); + let public_proxy_url = settings.proxy_public_url()?; + + Some(core_response::Payload::AuthCallback( + AuthCallbackResponse { + url: public_proxy_url.into(), + token: desktop_configuration.id, + }, + )) + } + Err(err) => { + let message = format!("OpenID auth error {err}"); + error!(message); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message, + })) + } + } + } + Err(err) => { + error!( + "Proxy requested an OpenID authentication info for a callback \ + URL ({}) that couldn't be parsed. Details: {err}", + request.callback_url + ); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "invalid callback URL".into(), + })) + } + } + } + // Reply without payload. + None => None, + }; + + if let Some(payload) = payload { + let req = CoreResponse { + id: received.id, + payload: Some(payload), + }; + let _ = tx.send(req); + } + } + Err(err) => { + error!("Disconnected from proxy at {}: {err}", self.url); + debug!("waiting 10s to re-establish the connection"); + self.mark_disconnected().await?; + sleep(TEN_SECS).await; + break 'message; + } + } + } + + Ok(()) + } +} + +/// Groups Core-side service handlers used to process requests originating +/// from a proxy instance. +/// +/// Each `ProxyServices` instance is owned by a single `Proxy` and provides +/// the concrete handlers for enrollment, authentication, and polling-related +/// requests received over the gRPC bidirectional stream. +struct ProxyServices { + enrollment: EnrollmentServer, + password_reset: PasswordResetServer, + client_mfa: ClientMfaServer, + polling: PollingServer, +} + +impl ProxyServices { + pub fn new( + pool: &PgPool, + tx: &ProxyTxSet, + remote_mfa_responses: Arc>>>, + sessions: Arc>>, + ) -> Self { + let enrollment = EnrollmentServer::new( + pool.clone(), + tx.wireguard.clone(), + tx.mail.clone(), + tx.bidi_events.clone(), + ); + let password_reset = + PasswordResetServer::new(pool.clone(), tx.mail.clone(), tx.bidi_events.clone()); + let client_mfa = ClientMfaServer::new( + pool.clone(), + tx.mail.clone(), + tx.wireguard.clone(), + tx.bidi_events.clone(), + remote_mfa_responses, + sessions, + ); + let polling = PollingServer::new(pool.clone()); + + Self { + enrollment, + password_reset, + client_mfa, + polling, + } + } +} + +/// Rewrites the request URI scheme to https for the TLS connector. +/// +/// Tonic expects an http URI for its endpoint, but our custom connector performs +/// the TLS handshake and requires https to select the TLS path. +#[derive(Clone, Debug)] +struct HttpsSchemeConnector { + inner: C, +} + +impl HttpsSchemeConnector { + const fn new(inner: C) -> Self { + Self { inner } + } +} + +type BoxError = Box; + +impl tower_service::Service for HttpsSchemeConnector +where + C: tower_service::Service + Clone + Send + 'static, + C::Future: Send, +{ + type Response = C::Response; + type Error = BoxError; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, uri: Uri) -> Self::Future { + let mut parts = uri.into_parts(); + parts.scheme = Some(http::uri::Scheme::HTTPS); + let https_uri = match Uri::from_parts(parts) { + Ok(uri) => uri, + Err(err) => { + return Box::pin(async move { Err(err.into()) }); + } + }; + let mut inner = self.inner.clone(); + Box::pin(async move { inner.call(https_uri).await }) + } +} diff --git a/crates/defguard_proxy_manager/src/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs similarity index 99% rename from crates/defguard_proxy_manager/src/enrollment.rs rename to crates/defguard_proxy_manager/src/servers/enrollment.rs index b6541c7f1c..4aca3800fb 100644 --- a/crates/defguard_proxy_manager/src/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -50,7 +50,7 @@ use tokio::sync::{ }; use tonic::Status; -pub(super) struct EnrollmentServer { +pub(crate) struct EnrollmentServer { pool: PgPool, wireguard_tx: Sender, mail_tx: UnboundedSender, @@ -59,7 +59,7 @@ pub(super) struct EnrollmentServer { impl EnrollmentServer { #[must_use] - pub fn new( + pub(crate) fn new( pool: PgPool, wireguard_tx: Sender, mail_tx: UnboundedSender, @@ -103,7 +103,7 @@ impl EnrollmentServer { } /// Sends given `GatewayEvent` to be handled by gateway GRPC server - pub fn send_wireguard_event(&self, event: GatewayEvent) { + pub(crate) fn send_wireguard_event(&self, event: GatewayEvent) { if let Err(err) = self.wireguard_tx.send(event) { error!("Error sending WireGuard event {err}"); } @@ -124,7 +124,7 @@ impl EnrollmentServer { } #[instrument(skip_all)] - pub async fn start_enrollment( + pub(crate) async fn start_enrollment( &self, request: EnrollmentStartRequest, info: Option, @@ -305,7 +305,7 @@ impl EnrollmentServer { } #[instrument(skip_all)] - pub async fn register_mobile_auth( + pub(crate) async fn register_mobile_auth( &self, request: RegisterMobileAuthRequest, ) -> Result<(), Status> { @@ -358,7 +358,7 @@ impl EnrollmentServer { } #[instrument(skip_all)] - pub async fn activate_user( + pub(crate) async fn activate_user( &self, request: ActivateUserRequest, req_device_info: Option, @@ -491,7 +491,7 @@ impl EnrollmentServer { } #[instrument(skip_all)] - pub async fn create_device( + pub(crate) async fn create_device( &self, request: NewDevice, req_device_info: Option, @@ -885,7 +885,7 @@ impl EnrollmentServer { /// Get all information needed to update instance information for desktop client #[instrument(skip_all)] - pub async fn get_network_info( + pub(crate) async fn get_network_info( &self, request: ExistingDevice, device_info: Option, diff --git a/crates/defguard_proxy_manager/src/servers/mod.rs b/crates/defguard_proxy_manager/src/servers/mod.rs new file mode 100644 index 0000000000..52ce7c164e --- /dev/null +++ b/crates/defguard_proxy_manager/src/servers/mod.rs @@ -0,0 +1,5 @@ +mod enrollment; +mod password_reset; + +pub(crate) use enrollment::EnrollmentServer; +pub(crate) use password_reset::PasswordResetServer; diff --git a/crates/defguard_proxy_manager/src/password_reset.rs b/crates/defguard_proxy_manager/src/servers/password_reset.rs similarity index 98% rename from crates/defguard_proxy_manager/src/password_reset.rs rename to crates/defguard_proxy_manager/src/servers/password_reset.rs index 41adbf6011..a07c87b646 100644 --- a/crates/defguard_proxy_manager/src/password_reset.rs +++ b/crates/defguard_proxy_manager/src/servers/password_reset.rs @@ -22,7 +22,7 @@ use sqlx::PgPool; use tokio::sync::mpsc::{UnboundedSender, error::SendError}; use tonic::Status; -pub(super) struct PasswordResetServer { +pub(crate) struct PasswordResetServer { pool: PgPool, mail_tx: UnboundedSender, bidi_event_tx: UnboundedSender, @@ -30,7 +30,7 @@ pub(super) struct PasswordResetServer { impl PasswordResetServer { #[must_use] - pub fn new( + pub(crate) fn new( pool: PgPool, mail_tx: UnboundedSender, bidi_event_tx: UnboundedSender, @@ -90,7 +90,7 @@ impl PasswordResetServer { } #[instrument(skip_all)] - pub async fn request_password_reset( + pub(crate) async fn request_password_reset( &self, request: PasswordResetInitializeRequest, req_device_info: Option, @@ -186,7 +186,7 @@ impl PasswordResetServer { } #[instrument(skip_all)] - pub async fn start_password_reset( + pub(crate) async fn start_password_reset( &self, request: PasswordResetStartRequest, info: Option, @@ -253,7 +253,7 @@ impl PasswordResetServer { } #[instrument(skip_all)] - pub async fn reset_password( + pub(crate) async fn reset_password( &self, request: PasswordResetRequest, req_device_info: Option, diff --git a/crates/defguard_vpn_stats_purge/Cargo.toml b/crates/defguard_vpn_stats_purge/Cargo.toml index 620de32a49..fc65f69cda 100644 --- a/crates/defguard_vpn_stats_purge/Cargo.toml +++ b/crates/defguard_vpn_stats_purge/Cargo.toml @@ -8,8 +8,6 @@ repository.workspace = true rust-version.workspace = true [dependencies] -defguard_common.workspace = true - chrono.workspace = true humantime.workspace = true sqlx.workspace = true diff --git a/flake.lock b/flake.lock index 2504b24d60..f719372873 100644 --- a/flake.lock +++ b/flake.lock @@ -74,11 +74,11 @@ ] }, "locked": { - "lastModified": 1770606655, - "narHash": "sha256-rpJf+kxvLWv32ivcgu8d+JeJooog3boJCT8J3joJvvM=", + "lastModified": 1770865833, + "narHash": "sha256-oiARqnlvaW6pVGheVi4ye6voqCwhg5hCcGish2ZvQzI=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "11a396520bf911e4ed01e78e11633d3fc63b350e", + "rev": "c8cfbe26238638e2f3a2c0ae7e8d240f5e4ded85", "type": "github" }, "original": { diff --git a/migrations/20260209080417_[2.0.0]_proxy_certificate.down.sql b/migrations/20260209080417_[2.0.0]_proxy_certificate.down.sql new file mode 100644 index 0000000000..d2c880529d --- /dev/null +++ b/migrations/20260209080417_[2.0.0]_proxy_certificate.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE proxy ADD COLUMN has_certificate boolean; +UPDATE proxy SET has_certificate = (certificate IS NOT NULL); +ALTER TABLE proxy DROP COLUMN certificate; diff --git a/migrations/20260209080417_[2.0.0]_proxy_certificate.up.sql b/migrations/20260209080417_[2.0.0]_proxy_certificate.up.sql new file mode 100644 index 0000000000..2d671af774 --- /dev/null +++ b/migrations/20260209080417_[2.0.0]_proxy_certificate.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE proxy + DROP COLUMN has_certificate, + ADD COLUMN certificate text; diff --git a/proto b/proto index fdbe98caa9..8326216b71 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit fdbe98caa9413b626833da210b5b588b287bb146 +Subproject commit 8326216b71edc64acf8fe091bb27d690c8d6885f diff --git a/web/src/pages/EdgeSetupPage/steps/SetupConfirmationStep.tsx b/web/src/pages/EdgeSetupPage/steps/SetupConfirmationStep.tsx index 82120ad3b8..ffedb5fed3 100644 --- a/web/src/pages/EdgeSetupPage/steps/SetupConfirmationStep.tsx +++ b/web/src/pages/EdgeSetupPage/steps/SetupConfirmationStep.tsx @@ -1,3 +1,4 @@ +import { useQueryClient } from '@tanstack/react-query'; import { useNavigate } from '@tanstack/react-router'; import { m } from '../../../paraglide/messages'; import { ActionCard } from '../../../shared/components/ActionCard/ActionCard'; @@ -11,12 +12,14 @@ import { useEdgeWizardStore } from '../useEdgeWizardStore'; export const SetupConfirmationStep = () => { const navigate = useNavigate(); + const queryClient = useQueryClient(); const handleBack = () => { useEdgeWizardStore.getState().reset(); }; const handleFinish = () => { + queryClient.invalidateQueries({ queryKey: ['edge'] }); navigate({ to: '/edges', replace: true }).then(() => { setTimeout(() => { useEdgeWizardStore.getState().reset(); diff --git a/web/src/pages/EdgesPage/EdgesTable.tsx b/web/src/pages/EdgesPage/EdgesTable.tsx index ceae34adde..6ea89fbe93 100644 --- a/web/src/pages/EdgesPage/EdgesTable.tsx +++ b/web/src/pages/EdgesPage/EdgesTable.tsx @@ -1,3 +1,4 @@ +import { useMutation } from '@tanstack/react-query'; import { useNavigate } from '@tanstack/react-router'; import { createColumnHelper, @@ -8,6 +9,7 @@ import { import dayjs from 'dayjs'; import { useMemo, useState } from 'react'; import { m } from '../../paraglide/messages'; +import api from '../../shared/api/api'; import type { EdgeInfo } from '../../shared/api/types'; import { Badge } from '../../shared/defguard-ui/components/Badge/Badge'; import { Button } from '../../shared/defguard-ui/components/Button/Button'; @@ -48,6 +50,12 @@ const displayModifiedBy = (edge: EdgeInfo) => export const EdgesTable = ({ edges }: Props) => { const navigate = useNavigate(); + const { mutate: deleteEdge } = useMutation({ + mutationFn: api.edge.deleteEdge, + meta: { + invalidate: ['edge'], + }, + }); const addButtonProps = useMemo( (): ButtonProps => ({ @@ -192,6 +200,18 @@ export const EdgesTable = ({ edges }: Props) => { }, ], }, + { + items: [ + { + text: m.controls_delete(), + icon: 'delete', + variant: 'danger', + onClick: () => { + deleteEdge(rowData.id); + }, + }, + ], + }, ]; return ( @@ -202,7 +222,7 @@ export const EdgesTable = ({ edges }: Props) => { }, }), ], - [navigate], + [deleteEdge, navigate], ); const table = useReactTable({