diff --git a/.sqlx/query-b43694450d7abe3b93ea88fa7c95c38d3e2deb43d5ca3458724deb3ead69389a.json b/.sqlx/query-161dca354966b0bc33849d2ef1245351bf9bf9650acca042a12ad75a71fdee71.json similarity index 81% rename from .sqlx/query-b43694450d7abe3b93ea88fa7c95c38d3e2deb43d5ca3458724deb3ead69389a.json rename to .sqlx/query-161dca354966b0bc33849d2ef1245351bf9bf9650acca042a12ad75a71fdee71.json index 10051d513f..982c7ea725 100644 --- a/.sqlx/query-b43694450d7abe3b93ea88fa7c95c38d3e2deb43d5ca3458724deb3ead69389a.json +++ b/.sqlx/query-161dca354966b0bc33849d2ef1245351bf9bf9650acca042a12ad75a71fdee71.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\",\"has_certificate\",\"certificate_expiry\",\"version\",\"name\" FROM \"gateway\" WHERE id = $1", + "query": "SELECT id, \"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\",\"certificate\",\"certificate_expiry\",\"version\",\"name\" FROM \"gateway\" WHERE id = $1", "describe": { "columns": [ { @@ -35,8 +35,8 @@ }, { "ordinal": 6, - "name": "has_certificate", - "type_info": "Bool" + "name": "certificate", + "type_info": "Text" }, { "ordinal": 7, @@ -66,11 +66,11 @@ true, true, true, - false, + true, true, true, false ] }, - "hash": "b43694450d7abe3b93ea88fa7c95c38d3e2deb43d5ca3458724deb3ead69389a" + "hash": "161dca354966b0bc33849d2ef1245351bf9bf9650acca042a12ad75a71fdee71" } diff --git a/.sqlx/query-2f614ae8a1c1c62c11ed2e9b11e7004f869008e9dec303033ddbec8b0cee53f5.json b/.sqlx/query-2f614ae8a1c1c62c11ed2e9b11e7004f869008e9dec303033ddbec8b0cee53f5.json deleted file mode 100644 index 9af90ba77d..0000000000 --- a/.sqlx/query-2f614ae8a1c1c62c11ed2e9b11e7004f869008e9dec303033ddbec8b0cee53f5.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT d.wireguard_pubkey pubkey, preshared_key, ARRAY(\n SELECT host(ip)\n FROM unnest(wnd.wireguard_ips) AS ip\n ) \"allowed_ips!: Vec\" FROM wireguard_network_device wnd JOIN device d ON wnd.device_id = d.id JOIN \"user\" u ON d.user_id = u.id WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) AND d.configured = true AND u.is_active = true ORDER BY d.id ASC", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "pubkey", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "preshared_key", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "allowed_ips!: Vec", - "type_info": "TextArray" - } - ], - "parameters": { - "Left": [ - "Int8", - "Bool" - ] - }, - "nullable": [ - false, - true, - null - ] - }, - "hash": "2f614ae8a1c1c62c11ed2e9b11e7004f869008e9dec303033ddbec8b0cee53f5" -} diff --git a/.sqlx/query-ae3e3cef524f2a911808bf72e7c57b7f32e22adefc9b9185a9b3cd80c169a6e2.json b/.sqlx/query-5eee502cace9cd11b8d12f7345660fb8517656b090c5f93e017a1d4ffe552975.json similarity index 82% rename from .sqlx/query-ae3e3cef524f2a911808bf72e7c57b7f32e22adefc9b9185a9b3cd80c169a6e2.json rename to .sqlx/query-5eee502cace9cd11b8d12f7345660fb8517656b090c5f93e017a1d4ffe552975.json index 77722daf50..8e155ee205 100644 --- a/.sqlx/query-ae3e3cef524f2a911808bf72e7c57b7f32e22adefc9b9185a9b3cd80c169a6e2.json +++ b/.sqlx/query-5eee502cace9cd11b8d12f7345660fb8517656b090c5f93e017a1d4ffe552975.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\",\"has_certificate\",\"certificate_expiry\",\"version\",\"name\" FROM \"gateway\"", + "query": "SELECT id, \"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\",\"certificate\",\"certificate_expiry\",\"version\",\"name\" FROM \"gateway\"", "describe": { "columns": [ { @@ -35,8 +35,8 @@ }, { "ordinal": 6, - "name": "has_certificate", - "type_info": "Bool" + "name": "certificate", + "type_info": "Text" }, { "ordinal": 7, @@ -64,11 +64,11 @@ true, true, true, - false, + true, true, true, false ] }, - "hash": "ae3e3cef524f2a911808bf72e7c57b7f32e22adefc9b9185a9b3cd80c169a6e2" + "hash": "5eee502cace9cd11b8d12f7345660fb8517656b090c5f93e017a1d4ffe552975" } diff --git a/.sqlx/query-5af0fbf61295a5a23149c6248ea0b4a7afcbee1b63e34932c143f4697a0bc2cc.json b/.sqlx/query-6bcef8e62bfbb66c4787a95bea3187d9bdb32e1938592cd31ba98aca73d69746.json similarity index 64% rename from .sqlx/query-5af0fbf61295a5a23149c6248ea0b4a7afcbee1b63e34932c143f4697a0bc2cc.json rename to .sqlx/query-6bcef8e62bfbb66c4787a95bea3187d9bdb32e1938592cd31ba98aca73d69746.json index 0ee0514bda..7a00633baf 100644 --- a/.sqlx/query-5af0fbf61295a5a23149c6248ea0b4a7afcbee1b63e34932c143f4697a0bc2cc.json +++ b/.sqlx/query-6bcef8e62bfbb66c4787a95bea3187d9bdb32e1938592cd31ba98aca73d69746.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO \"gateway\" (\"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\",\"has_certificate\",\"certificate_expiry\",\"version\",\"name\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) RETURNING id", + "query": "INSERT INTO \"gateway\" (\"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\",\"certificate\",\"certificate_expiry\",\"version\",\"name\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) RETURNING id", "describe": { "columns": [ { @@ -16,7 +16,7 @@ "Text", "Timestamp", "Timestamp", - "Bool", + "Text", "Timestamp", "Text", "Text" @@ -26,5 +26,5 @@ false ] }, - "hash": "5af0fbf61295a5a23149c6248ea0b4a7afcbee1b63e34932c143f4697a0bc2cc" + "hash": "6bcef8e62bfbb66c4787a95bea3187d9bdb32e1938592cd31ba98aca73d69746" } diff --git a/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json b/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json index a2e62691a2..47f529e92a 100644 --- a/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json +++ b/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json @@ -35,22 +35,22 @@ }, { "ordinal": 6, - "name": "has_certificate", - "type_info": "Bool" + "name": "certificate_expiry", + "type_info": "Timestamp" }, { "ordinal": 7, - "name": "certificate_expiry", - "type_info": "Timestamp" + "name": "version", + "type_info": "Text" }, { "ordinal": 8, - "name": "version", + "name": "name", "type_info": "Text" }, { "ordinal": 9, - "name": "name", + "name": "certificate", "type_info": "Text" } ], @@ -66,10 +66,10 @@ true, true, true, - false, true, true, - false + false, + true ] }, "hash": "d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c" diff --git a/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json b/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json index db1d8414a5..ce7ff19423 100644 --- a/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json +++ b/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json @@ -35,22 +35,22 @@ }, { "ordinal": 6, - "name": "has_certificate", - "type_info": "Bool" + "name": "certificate_expiry", + "type_info": "Timestamp" }, { "ordinal": 7, - "name": "certificate_expiry", - "type_info": "Timestamp" + "name": "version", + "type_info": "Text" }, { "ordinal": 8, - "name": "version", + "name": "name", "type_info": "Text" }, { "ordinal": 9, - "name": "name", + "name": "certificate", "type_info": "Text" } ], @@ -66,10 +66,10 @@ true, true, true, - false, true, true, - false + false, + true ] }, "hash": "e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a" diff --git a/.sqlx/query-ed3266f5f0d7b1613ad8745c9be953a7d9ef0becedf668c1d2225a1673003c77.json b/.sqlx/query-f653c2bf5fc813e1358004e2dfb77ffa5343609a16229c4a86726cc5d5148402.json similarity index 68% rename from .sqlx/query-ed3266f5f0d7b1613ad8745c9be953a7d9ef0becedf668c1d2225a1673003c77.json rename to .sqlx/query-f653c2bf5fc813e1358004e2dfb77ffa5343609a16229c4a86726cc5d5148402.json index 48849d4d38..d126e73ffc 100644 --- a/.sqlx/query-ed3266f5f0d7b1613ad8745c9be953a7d9ef0becedf668c1d2225a1673003c77.json +++ b/.sqlx/query-f653c2bf5fc813e1358004e2dfb77ffa5343609a16229c4a86726cc5d5148402.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE \"gateway\" SET \"network_id\" = $2,\"url\" = $3,\"hostname\" = $4,\"connected_at\" = $5,\"disconnected_at\" = $6,\"has_certificate\" = $7,\"certificate_expiry\" = $8,\"version\" = $9,\"name\" = $10 WHERE id = $1", + "query": "UPDATE \"gateway\" SET \"network_id\" = $2,\"url\" = $3,\"hostname\" = $4,\"connected_at\" = $5,\"disconnected_at\" = $6,\"certificate\" = $7,\"certificate_expiry\" = $8,\"version\" = $9,\"name\" = $10 WHERE id = $1", "describe": { "columns": [], "parameters": { @@ -11,7 +11,7 @@ "Text", "Timestamp", "Timestamp", - "Bool", + "Text", "Timestamp", "Text", "Text" @@ -19,5 +19,5 @@ }, "nullable": [] }, - "hash": "ed3266f5f0d7b1613ad8745c9be953a7d9ef0becedf668c1d2225a1673003c77" + "hash": "f653c2bf5fc813e1358004e2dfb77ffa5343609a16229c4a86726cc5d5148402" } diff --git a/Cargo.lock b/Cargo.lock index 536531e9fc..5fa79b802c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1296,6 +1296,7 @@ dependencies = [ "defguard_core", "defguard_event_logger", "defguard_event_router", + "defguard_gateway_manager", "defguard_proxy_manager", "defguard_session_manager", "defguard_setup", @@ -1383,7 +1384,6 @@ dependencies = [ "defguard_web_ui", "futures", "humantime", - "hyper-util", "ipnetwork", "jsonwebkey", "jsonwebtoken", @@ -1466,6 +1466,32 @@ dependencies = [ "tracing", ] +[[package]] +name = "defguard_gateway_manager" +version = "0.0.0" +dependencies = [ + "anyhow", + "chrono", + "defguard_certs", + "defguard_common", + "defguard_core", + "defguard_grpc_tls", + "defguard_proto", + "defguard_version", + "hyper-rustls", + "hyper-util", + "reqwest", + "semver", + "serde_json", + "sqlx", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tonic", + "tower", + "tracing", +] + [[package]] name = "defguard_generator" version = "0.0.0" @@ -1481,6 +1507,20 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "defguard_grpc_tls" +version = "0.0.0" +dependencies = [ + "defguard_common", + "http", + "rustls", + "thiserror 2.0.18", + "tokio", + "tower-service", + "tracing", + "x509-parser 0.18.1", +] + [[package]] name = "defguard_mail" version = "0.0.0" @@ -1525,14 +1565,13 @@ dependencies = [ "defguard_certs", "defguard_common", "defguard_core", + "defguard_grpc_tls", "defguard_mail", "defguard_proto", "defguard_version", - "http", "hyper-rustls", "openidconnect", "reqwest", - "rustls", "secrecy", "semver", "sqlx", @@ -1540,9 +1579,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", - "tower-service", "tracing", - "x509-parser 0.18.1", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4da76766c6..7910a0505e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ edition = "2024" license-file = "LICENSE.md" homepage = "https://defguard.net/" repository = "https://github.com/DefGuard/defguard" -rust-version = "1.85.1" +rust-version = "1.87.0" [workspace] members = ["crates/*", "tools/*"] @@ -16,6 +16,7 @@ defguard_common = { path = "./crates/defguard_common", version = "2.0.0" } defguard_core = { path = "./crates/defguard_core", version = "0.0.0" } defguard_event_logger = { path = "./crates/defguard_event_logger", version = "0.0.0" } defguard_event_router = { path = "./crates/defguard_event_router", version = "0.0.0" } +defguard_gateway_manager = { path = "./crates/defguard_gateway_manager", version = "0.0.0" } defguard_mail = { path = "./crates/defguard_mail", version = "0.0.0" } defguard_proto = { path = "./crates/defguard_proto", version = "0.0.0" } defguard_proxy_manager = { path = "./crates/defguard_proxy_manager", version = "0.0.0" } @@ -24,6 +25,7 @@ defguard_version = { path = "./crates/defguard_version", version = "0.0.0" } defguard_vpn_stats_purge = { path = "./crates/defguard_vpn_stats_purge", version = "0.0.0" } defguard_web_ui = { path = "./crates/defguard_web_ui", version = "0.0.0" } defguard_certs = { path = "./crates/defguard_certs", version = "0.0.0" } +defguard_grpc_tls = { path = "./crates/defguard_grpc_tls", version = "0.0.0" } defguard_setup = { path = "./crates/defguard_setup", version = "0.0.0" } model_derive = { path = "./crates/model_derive", version = "0.0.0" } @@ -49,6 +51,7 @@ claims = "0.8" clap = { version = "4.5", features = ["derive", "env"] } futures = "0.3" http = "1.4" +hyper-rustls = { version = "0.27", features = ["http2"] } humantime = "2.1" # match version used by sqlx ipnetwork = "0.20" @@ -61,6 +64,7 @@ md4 = "0.10" openidconnect = { version = "4.0", default-features = false, features = [ "reqwest", ] } +os_info = "3.12" parse_link_header = "0.4" paste = "1.0" pgp = { version = "0.19", default-features = false } @@ -72,6 +76,7 @@ rcgen = { version = "0.14", features = ["x509-parser", "pem"] } reqwest = { version = "0.12", features = ["json"] } rsa = "0.9" rust-ini = "0.21" +rustls = { version = "0.23", features = ["ring"] } rustls-pki-types = "1.14" semver = { version = "1.0", features = ["serde"] } secrecy = { version = "0.10", features = ["serde"] } @@ -115,7 +120,9 @@ tonic-health = "0.14" tonic-prost = "0.14" tonic-prost-build = "0.14" totp-lite = { version = "2.0" } +tower = "0.5" tower-http = { version = "0.6", features = ["fs", "trace", "set-header"] } +tower-service = "0.3" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } trait-variant = "0.1" diff --git a/crates/defguard/Cargo.toml b/crates/defguard/Cargo.toml index 1378b80074..e6fef7e6b7 100644 --- a/crates/defguard/Cargo.toml +++ b/crates/defguard/Cargo.toml @@ -13,6 +13,7 @@ defguard_common = { workspace = true } defguard_core = { workspace = true } defguard_event_router = { workspace = true } defguard_event_logger = { workspace = true } +defguard_gateway_manager = { workspace = true } defguard_proxy_manager = { workspace = true } defguard_session_manager = { workspace = true } defguard_version = { workspace = true } diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 0e65a8f648..b38f497619 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -23,17 +23,14 @@ use defguard_core::{ limits::update_counts, }, events::{ApiEvent, BidiStreamEvent}, - grpc::{ - WorkerState, - gateway::{events::GatewayEvent, run_grpc_gateway_stream}, - run_grpc_server, - }, + grpc::{GatewayEvent, WorkerState, run_grpc_server}, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, version::IncompatibleComponents, }; use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; use defguard_event_router::{RouterReceiverSet, run_event_router}; +use defguard_gateway_manager::GatewayManager; use defguard_proxy_manager::{ProxyManager, ProxyTxSet}; use defguard_session_manager::{events::SessionManagerEvent, run_session_manager}; use defguard_setup::setup::run_setup_web_server; @@ -183,10 +180,12 @@ async fn main() -> Result<(), anyhow::Error> { proxy_control_rx, ); + let mut gateway_manager = GatewayManager::default(); + // run services tokio::select! { res = proxy_manager.run() => error!("ProxyManager returned early: {res:?}"), - res = run_grpc_gateway_stream( + res = gateway_manager.run( pool.clone(), gateway_tx.clone(), peer_stats_tx, diff --git a/crates/defguard_common/src/db/models/gateway.rs b/crates/defguard_common/src/db/models/gateway.rs index 613d5f2b47..82e1f53655 100644 --- a/crates/defguard_common/src/db/models/gateway.rs +++ b/crates/defguard_common/src/db/models/gateway.rs @@ -15,7 +15,7 @@ pub struct Gateway { pub hostname: Option, pub connected_at: Option, pub disconnected_at: Option, - pub has_certificate: bool, + pub certificate: Option, pub certificate_expiry: Option, pub version: Option, pub name: String, @@ -43,7 +43,7 @@ impl Gateway { hostname: None, connected_at: None, disconnected_at: None, - has_certificate: false, + certificate: None, certificate_expiry: None, version: None, name: name.into(), diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 193eb80d18..15ec63bf76 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -66,6 +66,7 @@ tokio-util = { workspace = true } tonic = { workspace = true } tonic-health = { workspace = true } totp-lite = { workspace = true } +tower = { workspace = true } tower-http = { workspace = true } tracing = { workspace = true } trait-variant = { workspace = true } @@ -79,13 +80,11 @@ webauthn-rs-proto = { workspace = true } x25519-dalek = { workspace = true } ammonia = "4.1" regex = "1.10" -tower = "0.5" uaparser = "0.6" async-stream = "0.3" [dev-dependencies] claims.workspace = true -hyper-util = "0.1" matches.workspace = true reqwest = { version = "0.12", features = [ "cookies", diff --git a/crates/defguard_core/src/appstate.rs b/crates/defguard_core/src/appstate.rs index 6b1ea4b8d2..2db868f0af 100644 --- a/crates/defguard_core/src/appstate.rs +++ b/crates/defguard_core/src/appstate.rs @@ -23,7 +23,7 @@ use crate::{ db::{AppEvent, WebHook}, error::WebError, events::ApiEvent, - grpc::gateway::{events::GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, version::IncompatibleComponents, }; diff --git a/crates/defguard_core/src/enterprise/db/models/acl.rs b/crates/defguard_core/src/enterprise/db/models/acl.rs index ccf1f7edd9..c876dff947 100644 --- a/crates/defguard_core/src/enterprise/db/models/acl.rs +++ b/crates/defguard_core/src/enterprise/db/models/acl.rs @@ -32,7 +32,7 @@ use crate::{ ApiAclRule, EditAclRule, alias::EditAclAlias, destination::EditAclDestination, }, }, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, }; #[derive(Debug, Error)] diff --git a/crates/defguard_core/src/enterprise/directory_sync/mod.rs b/crates/defguard_core/src/enterprise/directory_sync/mod.rs index ea73b0b6d1..1bb93a4098 100644 --- a/crates/defguard_core/src/enterprise/directory_sync/mod.rs +++ b/crates/defguard_core/src/enterprise/directory_sync/mod.rs @@ -29,7 +29,7 @@ use crate::{ utils::{ldap_add_users_to_groups, ldap_delete_users, ldap_remove_users_from_groups}, }, }, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, handlers::user::check_username, user_management::{delete_user_and_cleanup_devices, disable_user, sync_allowed_user_devices}, }; diff --git a/crates/defguard_core/src/enterprise/snat/handlers.rs b/crates/defguard_core/src/enterprise/snat/handlers.rs index 790a02e8cf..f0552c3938 100644 --- a/crates/defguard_core/src/enterprise/snat/handlers.rs +++ b/crates/defguard_core/src/enterprise/snat/handlers.rs @@ -21,7 +21,7 @@ use crate::{ }, error::WebError, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, handlers::{ApiResponse, ApiResult}, }; diff --git a/crates/defguard_core/src/grpc/auth.rs b/crates/defguard_core/src/grpc/auth.rs index 6f61986355..b68ce63556 100644 --- a/crates/defguard_core/src/grpc/auth.rs +++ b/crates/defguard_core/src/grpc/auth.rs @@ -11,7 +11,7 @@ use tonic::{Request, Response, Status}; use crate::auth::failed_login::{FailedLoginMap, check_failed_logins, log_failed_login_attempt}; -pub struct AuthServer { +pub(super) struct AuthServer { pool: PgPool, failed_logins: Arc>, } diff --git a/crates/defguard_core/src/grpc/gateway/events.rs b/crates/defguard_core/src/grpc/gateway/events.rs deleted file mode 100644 index 68596f21b8..0000000000 --- a/crates/defguard_core/src/grpc/gateway/events.rs +++ /dev/null @@ -1,30 +0,0 @@ -use defguard_common::db::{ - Id, - models::{ - Device, WireguardNetwork, - device::{DeviceInfo, WireguardNetworkDevice}, - }, -}; -use defguard_proto::{enterprise::firewall::FirewallConfig, gateway::Peer}; - -type LocationId = Id; - -// TODO: move this to common crate -#[derive(Clone, Debug)] -pub enum GatewayEvent { - NetworkCreated(LocationId, WireguardNetwork), - NetworkModified( - LocationId, - WireguardNetwork, - Vec, - Option, - ), - NetworkDeleted(LocationId, String), - DeviceCreated(DeviceInfo), - DeviceModified(DeviceInfo), - DeviceDeleted(DeviceInfo), - FirewallConfigChanged(LocationId, FirewallConfig), - FirewallDisabled(LocationId), - MfaSessionAuthorized(LocationId, Device, WireguardNetworkDevice), - MfaSessionDisconnected(LocationId, Device), -} diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs deleted file mode 100644 index d4974b8a1d..0000000000 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ /dev/null @@ -1,379 +0,0 @@ -use std::{ - str::FromStr, - sync::atomic::{AtomicU64, Ordering}, -}; - -use defguard_certs::der_to_pem; -use defguard_common::{ - VERSION, - db::{ - Id, - models::{Settings, WireguardNetwork, gateway::Gateway}, - }, - messages::peer_stats_update::PeerStatsUpdate, -}; -use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; -use defguard_version::client::ClientVersionInterceptor; -use reqwest::Url; -use semver::Version; -use sqlx::PgPool; -use tokio::{ - sync::{ - broadcast::Sender, - mpsc::{self, UnboundedSender}, - }, - time::sleep, -}; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; - -use crate::{ - enterprise::firewall::try_get_location_firewall_config, - grpc::{ - TEN_SECS, - gateway::{GatewayError, events::GatewayEvent, try_protos_into_stats_message}, - }, - handlers::mail::send_gateway_disconnected_email, - location_management::allowed_peers::get_location_allowed_peers, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum Scheme { - #[allow(dead_code)] - Http, - Https, -} - -impl Scheme { - #[must_use] - pub const fn as_str(&self) -> &str { - match self { - Self::Http => "http", - Self::Https => "https", - } - } -} - -/// One instance per connected Gateway. -pub(crate) struct GatewayHandler { - // Gateway server endpoint URL. - url: Url, - gateway: Gateway, - message_id: AtomicU64, - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, -} - -impl GatewayHandler { - pub(crate) fn new( - gateway: Gateway, - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, - ) -> Result { - let url = Url::from_str(&gateway.url).map_err(|err| { - GatewayError::EndpointError(format!( - "Failed to parse Gateway URL {}: {}", - &gateway.url, err - )) - })?; - - Ok(Self { - url, - gateway, - message_id: AtomicU64::new(0), - pool, - events_tx, - peer_stats_tx, - }) - } - - fn endpoint(&self, scheme: Scheme) -> Result { - let mut url = self.url.clone(); - - if let Err(()) = url.set_scheme(scheme.as_str()) { - return Err(GatewayError::EndpointError(format!( - "Failed to set scheme {} for Gateway URL {:?}", - scheme.as_str(), - self.url - ))); - } - - let endpoint = Endpoint::from_shared(url.to_string()) - .map_err(|err| { - GatewayError::EndpointError(format!( - "Failed to create endpoint for Gateway URL {url:?}: {err}", - )) - })? - .http2_keep_alive_interval(TEN_SECS) - .tcp_keepalive(Some(TEN_SECS)) - .keep_alive_while_idle(true); - - if scheme == Scheme::Https { - let settings = Settings::get_current_settings(); - let Some(ca_cert_der) = settings.ca_cert_der else { - return Err(GatewayError::EndpointError( - "Core CA is not setup, can't create a Gateway endpoint.".to_string(), - )); - }; - - let cert_pem = der_to_pem(&ca_cert_der, defguard_certs::PemLabel::Certificate) - .map_err(|err| { - GatewayError::EndpointError(format!( - "Failed to convert CA certificate DER to PEM for Gateway URL {url:?}: {err}", - )) - })?; - let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(&cert_pem)); - - Ok(endpoint.tls_config(tls).map_err(|err| { - GatewayError::EndpointError(format!( - "Failed to set TLS config for Gateway URL {url:?}: {err}", - )) - })?) - } else { - Ok(endpoint) - } - } - - /// Send network and VPN configuration to Gateway. - async fn send_configuration( - &self, - tx: &UnboundedSender, - ) -> Result, GatewayError> { - debug!("Sending configuration to Gateway"); - let network_id = self.gateway.network_id; - - let mut conn = self.pool.acquire().await?; - - let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) - .await? - .ok_or_else(|| { - GatewayError::NotFound(format!("Network with id {network_id} not found")) - })?; - - debug!( - "Sending configuration to {}, network {network}", - self.gateway - ); - if let Err(err) = network.touch_connected(&mut *conn).await { - error!( - "Failed to update connection time for network {network_id} in the database, \ - status: {err}" - ); - } - - let peers = get_location_allowed_peers(&network, &self.pool).await?; - - let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn).await?; - let payload = Some(core_response::Payload::Config(super::gen_config( - &network, - peers, - maybe_firewall_config, - ))); - let id = self.message_id.fetch_add(1, Ordering::Relaxed); - let req = CoreResponse { id, payload }; - match tx.send(req) { - Ok(()) => { - info!("Configuration sent to {}, network {network}", self.gateway); - Ok(network) - } - Err(err) => { - error!("Failed to send configuration sent to {}", self.gateway); - Err(GatewayError::MessageChannelError(format!( - "Configuration not sent to {}, error {err}", - self.gateway - ))) - } - } - } - - /// Send gateway disconnected notification. - /// Sends notification only if last notification time is bigger than specified in config. - async fn send_disconnect_notification(&self) { - debug!("Sending gateway disconnect email notification"); - let hostname = self.gateway.hostname.clone(); - let pool = self.pool.clone(); - let url = self.gateway.url.clone(); - - let Ok(Some(network)) = - WireguardNetwork::find_by_id(&self.pool, self.gateway.network_id).await - else { - error!( - "Failed to fetch network ID {} from database", - self.gateway.network_id - ); - return; - }; - - // Send email only if disconnection time is before the connection time. - let send_email = if let (Some(connected_at), Some(disconnected_at)) = - (self.gateway.connected_at, self.gateway.disconnected_at) - { - disconnected_at <= connected_at - } else { - true - }; - if send_email { - // FIXME: Try to get rid of spawn and use something like block_on - // To return result instead of logging - tokio::spawn(async move { - if let Err(err) = - send_gateway_disconnected_email(hostname, network.name, &url, &pool).await - { - error!("Failed to send gateway disconnect notification: {err}"); - } else { - info!("Email notification sent about gateway being disconnected"); - } - }); - } else { - info!( - "{} disconnected. Email notification not sent.", - self.gateway - ); - } - } - - /// Connect to Gateway and handle its messages through gRPC. - pub(crate) async fn handle_connection(&mut self) -> Result<(), GatewayError> { - let endpoint = self.endpoint(Scheme::Https)?; - let uri = endpoint.uri().to_string(); - loop { - #[cfg(not(test))] - let channel = endpoint.connect_lazy(); - #[cfg(test)] - let channel = endpoint.connect_with_connector_lazy(tower::service_fn( - |_: tonic::transport::Uri| async { - Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(super::TONIC_SOCKET).await?, - )) - }, - )); - - debug!("Connecting to Gateway {uri}"); - let interceptor = ClientVersionInterceptor::new( - Version::parse(VERSION).expect("failed to parse self version"), - ); - let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); - let (tx, rx) = mpsc::unbounded_channel(); - let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { - Ok(response) => response, - Err(err) => { - error!("Failed to connect to Gateway {uri}, retrying: {err}"); - sleep(TEN_SECS).await; - continue; - } - }; - info!("Connected to Defguard Gateway {uri}"); - - let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); - let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); - - if let Some(mut gateway) = Gateway::find_by_id(&self.pool, self.gateway.id).await? { - gateway.version = Some(version.to_string()); - gateway.save(&self.pool).await?; - } - - let mut resp_stream = response.into_inner(); - let mut config_sent = false; - - 'message: loop { - match resp_stream.message().await { - Ok(None) => { - info!("Stream was closed by the sender."); - break 'message; - } - Ok(Some(received)) => { - info!("Received message from Gateway."); - debug!("Message from Gateway {uri}"); - - match received.payload { - Some(core_request::Payload::ConfigRequest(config_request)) => { - if config_sent { - warn!( - "Ignoring repeated configuration request from {}", - self.gateway - ); - continue; - } - - // Send network configuration to Gateway. - match self.send_configuration(&tx).await { - Ok(network) => { - info!("Sent configuration to {}", self.gateway); - config_sent = true; - let _ = self - .gateway - .touch_connected(&self.pool, config_request.hostname) - .await; - let mut updates_handler = super::GatewayUpdatesHandler::new( - self.gateway.network_id, - network, - self.gateway - .hostname - .clone() - .unwrap_or_default() - .clone(), - self.events_tx.subscribe(), - tx.clone(), - ); - tokio::spawn(async move { - updates_handler.run().await; - }); - } - Err(err) => { - error!( - "Failed to send configuration to {}: {err}", - self.gateway - ); - } - } - } - Some(core_request::Payload::PeerStats(peer_stats)) => { - if !config_sent { - warn!( - "Ignoring peer statistics from {} because it hasn't \ - authorized itself", - self.gateway - ); - continue; - } - - // convert stats to DB storage format - match try_protos_into_stats_message( - peer_stats.clone(), - self.gateway.network_id, - self.gateway.id, - ) { - None => { - warn!( - "Failed to parse peer stats update. Skipping sending \ - message to session manager." - ); - } - Some(message) => { - if let Err(err) = self.peer_stats_tx.send(message) { - error!( - "Failed to send peers stats update to session manager: {err}" - ); - } - } - } - } - None => (), - } - } - Err(err) => { - error!("Disconnected from Gateway at {uri}, error: {err}"); - // Important: call this funtion before setting disconnection time. - self.send_disconnect_notification().await; - let _ = self.gateway.touch_disconnected(&self.pool).await; - debug!("Waiting 10s to re-establish the connection"); - sleep(TEN_SECS).await; - break 'message; - } - } - } - } - } -} diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index b582e18b55..ef8a277d61 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -5,18 +5,6 @@ use std::{ time::{Duration, Instant}, }; -use defguard_common::{ - auth::claims::ClaimsType, - db::{Id, models::Settings}, - types::UrlParseError, -}; -use reqwest::Url; -use serde::Serialize; -use sqlx::PgPool; -use tokio::sync::mpsc::UnboundedSender; -use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; - -use self::{auth::AuthServer, interceptor::JwtInterceptor, worker::WorkerServer}; use crate::{ auth::failed_login::FailedLoginMap, db::AppEvent, @@ -25,15 +13,31 @@ use crate::{ enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, openid_provider::OpenIdProvider, }, - is_business_license_active, + is_business_license_active, is_enterprise_license_active, }, - server_config, + grpc::{auth::AuthServer, interceptor::JwtInterceptor, worker::WorkerServer}, +}; +use defguard_common::{ + auth::claims::ClaimsType, + config::server_config, + db::{ + Id, + models::{ + Device, Settings, WireguardNetwork, + device::{DeviceInfo, WireguardNetworkDevice}, + wireguard::ServiceLocationMode, + }, + }, + types::UrlParseError, }; +use reqwest::Url; +use serde::Serialize; +use sqlx::PgPool; +use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; mod auth; pub mod client_version; -pub mod gateway; -mod interceptor; +pub mod interceptor; pub mod proxy; pub mod utils; pub mod worker; @@ -47,16 +51,16 @@ pub mod proto { } use defguard_proto::{ - auth::auth_service_server::AuthServiceServer, - worker::worker_service_server::WorkerServiceServer, + auth::auth_service_server::AuthServiceServer, enterprise::firewall::FirewallConfig, + gateway::Peer, worker::worker_service_server::WorkerServiceServer, }; +use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; // gRPC header for passing auth token from clients pub static AUTHORIZATION_HEADER: &str = "authorization"; // gRPC header for passing hostname from clients pub static HOSTNAME_HEADER: &str = "hostname"; - const TEN_SECS: Duration = Duration::from_secs(10); /// Runs gRPC server with core services. @@ -91,7 +95,7 @@ pub async fn run_grpc_server( Ok(()) } -pub async fn build_grpc_service_router( +pub(crate) async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, @@ -212,3 +216,47 @@ impl From for defguard_proto::proxy::InstanceInfo { } } } + +// TODO: move this to common crate +#[derive(Clone, Debug)] +pub enum GatewayEvent { + NetworkCreated(Id, WireguardNetwork), + NetworkModified(Id, WireguardNetwork, Vec, Option), + NetworkDeleted(Id, String), + DeviceCreated(DeviceInfo), + DeviceModified(DeviceInfo), + DeviceDeleted(DeviceInfo), + FirewallConfigChanged(Id, FirewallConfig), + FirewallDisabled(Id), + MfaSessionAuthorized(Id, Device, WireguardNetworkDevice), + MfaSessionDisconnected(Id, Device), +} + +/// Sends given `GatewayEvent` to be handled by gateway GRPC server +/// +/// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead +pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { + debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); + if let Err(err) = wg_tx.send(event) { + error!("Error sending WireGuard event {err}"); + } +} + +/// Sends multiple events to be handled by gateway gRPC server. +/// +/// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead +pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { + debug!("Sending {} WireGuard events", events.len()); + for event in events { + send_wireguard_event(event, wg_tx); + } +} + +/// If this location is marked as a service location, checks if all requirements are met for it to +/// function: +/// - Enterprise is enabled +#[must_use] +pub fn should_prevent_service_location_usage(location: &WireguardNetwork) -> bool { + location.service_location_mode != ServiceLocationMode::Disabled + && !is_enterprise_license_active() +} diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 9fdcb75151..a2aab9a605 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -40,7 +40,7 @@ use tonic::{Code, Status}; use crate::{ enterprise::{db::models::openid_provider::OpenIdProvider, is_business_license_active}, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent}, - grpc::{gateway::events::GatewayEvent, utils::parse_client_ip_agent}, + grpc::{GatewayEvent, utils::parse_client_ip_agent}, }; const CLIENT_SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes diff --git a/crates/defguard_core/src/grpc/utils.rs b/crates/defguard_core/src/grpc/utils.rs index 64e1e2e619..a9ac22b5fc 100644 --- a/crates/defguard_core/src/grpc/utils.rs +++ b/crates/defguard_core/src/grpc/utils.rs @@ -7,7 +7,6 @@ use defguard_common::{ models::{ Device, DeviceType, Settings, User, WireguardNetwork, device::WireguardNetworkDevice, - polling_token::PollingToken, wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, @@ -24,48 +23,9 @@ use crate::{ enterprise::db::models::{ enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider, }, - grpc::{client_version::ClientFeature, gateway::should_prevent_service_location_usage}, + grpc::{client_version::ClientFeature, should_prevent_service_location_usage}, }; -// Create a new token for configuration polling. -pub async fn new_polling_token(pool: &PgPool, device: &Device) -> Result { - debug!( - "Making a new polling token for device {}", - device.wireguard_pubkey - ); - let mut transaction = pool.begin().await.map_err(|err| { - error!("Failed to start transaction while making a new polling token: {err}"); - Status::internal(format!("unexpected error: {err}")) - })?; - - // 1. Delete existing polling token for the device, if it exists - // 2. Create a new polling token for the device - PollingToken::delete_for_device_id(&mut *transaction, device.id) - .await - .map_err(|err| { - error!("Failed to delete polling token: {err}"); - Status::internal(format!("unexpected error: {err}")) - })?; - let new_token = PollingToken::new(device.id) - .save(&mut *transaction) - .await - .map_err(|err| { - error!("Failed to save new polling token: {err}"); - Status::internal(format!("unexpected error: {err}")) - })?; - - transaction.commit().await.map_err(|err| { - error!("Failed to commit transaction while making a new polling token: {err}"); - Status::internal(format!("unexpected error: {err}")) - })?; - info!( - "New polling token created for device {}", - device.wireguard_pubkey - ); - - Ok(new_token.token) -} - pub async fn build_device_config_response( pool: &PgPool, device: Device, diff --git a/crates/defguard_core/src/handlers/component_setup.rs b/crates/defguard_core/src/handlers/component_setup.rs index 44dbdaf05e..fd8fa9d03f 100644 --- a/crates/defguard_core/src/handlers/component_setup.rs +++ b/crates/defguard_core/src/handlers/component_setup.rs @@ -925,6 +925,7 @@ pub async fn setup_gateway_tls_stream( let defguard_certs::CertificateInfo { not_after: expiry, + serial, .. } = match parse_certificate_info(cert.der()) { Ok(dt) => { @@ -944,7 +945,7 @@ pub async fn setup_gateway_tls_stream( request.common_name, ); - gateway.has_certificate = true; + gateway.certificate = Some(serial); gateway.certificate_expiry = Some(expiry); if let Err(err) = gateway.save(&pool).await { diff --git a/crates/defguard_core/src/handlers/mail.rs b/crates/defguard_core/src/handlers/mail.rs index 26cf6a3c54..8e431fbd59 100644 --- a/crates/defguard_core/src/handlers/mail.rs +++ b/crates/defguard_core/src/handlers/mail.rs @@ -127,7 +127,7 @@ pub async fn send_support_data( "network_id": g.network_id, "version": g.version.as_deref().unwrap_or("unknown"), "url": g.url, - "has_certificate": g.has_certificate, + "certificate": g.certificate, "hostname": g.hostname, "connected_at": g.connected_at, })).collect::>(), diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index 2a29be7f09..1120c730fc 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -31,7 +31,7 @@ use crate::{ enrollment_management::start_desktop_configuration, enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, server_config, }; diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index df54c04d4c..ca21dfe9b2 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -44,7 +44,7 @@ use crate::{ limits::{get_counts, update_counts}, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, location_management::{ allowed_peers::get_location_allowed_peers, handle_imported_devices, handle_mapped_devices, sync_location_allowed_devices, diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index ee100c6825..bb237404d6 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -1,6 +1,4 @@ #![allow(clippy::too_many_arguments)] -// FIXME: actually refactor errors instead -#![allow(clippy::result_large_err)] use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, sync::{Arc, LazyLock, Mutex, RwLock}, @@ -108,7 +106,7 @@ use crate::{ create_snat_binding, delete_snat_binding, list_snat_bindings, modify_snat_binding, }, }, - grpc::{WorkerState, gateway::events::GatewayEvent}, + grpc::{GatewayEvent, WorkerState}, handlers::{ app_info::get_app_info, auth::{ diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 8017ca1c83..44bd0b14d9 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -2,7 +2,7 @@ use defguard_common::db::{Id, models::WireguardNetwork}; use defguard_proto::gateway::Peer; use sqlx::{Error as SqlxError, PgExecutor, query}; -use crate::grpc::gateway::should_prevent_service_location_usage; +use crate::grpc::should_prevent_service_location_usage; /// Get a list of all allowed peers for a given location /// @@ -62,7 +62,7 @@ where } else { None }, - keepalive_interval: Some(location.keepalive_interval as u32), + keepalive_interval: Some(location.keepalive_interval.cast_unsigned()), }) .collect(); diff --git a/crates/defguard_core/src/location_management/mod.rs b/crates/defguard_core/src/location_management/mod.rs index 8ff7a45642..400be93f6c 100644 --- a/crates/defguard_core/src/location_management/mod.rs +++ b/crates/defguard_core/src/location_management/mod.rs @@ -19,7 +19,7 @@ use tokio::sync::broadcast::Sender; use crate::{ enterprise::firewall::{FirewallError, try_get_location_firewall_config}, - grpc::gateway::{events::GatewayEvent, send_multiple_wireguard_events}, + grpc::{GatewayEvent, send_multiple_wireguard_events}, wg_config::ImportedDevice, }; @@ -410,7 +410,6 @@ mod test { use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use super::*; - use crate::grpc::gateway::events::GatewayEvent; #[sqlx::test] async fn test_sync_allowed_devices_for_user(_: PgPoolOptions, options: PgConnectOptions) { diff --git a/crates/defguard_core/src/user_management.rs b/crates/defguard_core/src/user_management.rs index 449e5d5b25..1fc5d7cea9 100644 --- a/crates/defguard_core/src/user_management.rs +++ b/crates/defguard_core/src/user_management.rs @@ -10,7 +10,7 @@ use tokio::sync::broadcast::Sender; use crate::{ enterprise::{firewall::try_get_location_firewall_config, limits::update_counts}, error::WebError, - grpc::gateway::{events::GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, + grpc::{GatewayEvent, send_multiple_wireguard_events, send_wireguard_event}, location_management::sync_allowed_devices_for_user, }; diff --git a/crates/defguard_core/src/utility_thread.rs b/crates/defguard_core/src/utility_thread.rs index c4d1f416d2..cf071fe5da 100644 --- a/crates/defguard_core/src/utility_thread.rs +++ b/crates/defguard_core/src/utility_thread.rs @@ -20,7 +20,7 @@ use crate::{ ldap::{do_ldap_sync, sync::get_ldap_sync_interval}, limits::do_count_update, }, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, location_management::allowed_peers::get_location_allowed_peers, updates::do_new_version_check, }; diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 2abc0db1cf..6d9ad02340 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -20,7 +20,7 @@ use defguard_core::{ db::AppEvent, enterprise::license::{License, LicenseTier, set_cached_license}, events::ApiEvent, - grpc::{WorkerState, gateway::events::GatewayEvent}, + grpc::{GatewayEvent, WorkerState}, handlers::{Auth, user::UserDetails}, }; use reqwest::{StatusCode, header::HeaderName}; diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index a80082623e..396a9ea755 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -20,7 +20,7 @@ use defguard_core::{ handlers::openid_providers::AddProviderData, license::{get_cached_license, set_cached_license}, }, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, handlers::{Auth, GroupInfo, wireguard::WireguardNetworkData}, }; use ipnetwork::IpNetwork; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs index 86f5edaa47..9fb3b58fe7 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_allowed_groups.rs @@ -9,7 +9,7 @@ use defguard_common::{ }, }; use defguard_core::{ - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, handlers::{Auth, wireguard::ImportedNetworkData}, location_management::allowed_peers::get_location_allowed_peers, }; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs index 211e7cd33d..d91ae93c88 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_devices.rs @@ -5,7 +5,7 @@ use defguard_common::db::{ models::{Device, WireguardNetwork}, }; use defguard_core::{ - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, handlers::{Auth, network_devices::AddNetworkDevice}, }; use ipnetwork::IpNetwork; diff --git a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs index 958f62e5a2..b609239668 100644 --- a/crates/defguard_core/tests/integration/api/wireguard_network_import.rs +++ b/crates/defguard_core/tests/integration/api/wireguard_network_import.rs @@ -9,7 +9,7 @@ use defguard_common::db::models::{ }, }; use defguard_core::{ - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, handlers::{Auth, wireguard::ImportedNetworkData}, }; use matches::assert_matches; diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index 0d8a90b972..b46e6d2579 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use defguard_core::{ events::{ApiEvent, BidiStreamEvent}, - grpc::gateway::events::GatewayEvent, + grpc::GatewayEvent, }; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; use defguard_session_manager::events::SessionManagerEvent; diff --git a/crates/defguard_gateway_manager/Cargo.toml b/crates/defguard_gateway_manager/Cargo.toml new file mode 100644 index 0000000000..9afb53828e --- /dev/null +++ b/crates/defguard_gateway_manager/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "defguard_gateway_manager" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_certs.workspace = true +defguard_common.workspace = true +defguard_core.workspace = true +defguard_grpc_tls.workspace = true +defguard_proto.workspace = true +defguard_version.workspace = true + +anyhow.workspace = true +chrono.workspace = true +hyper-rustls.workspace = true +reqwest.workspace = true +semver.workspace = true +serde_json.workspace = true +sqlx.workspace = true +thiserror.workspace = true +tokio.workspace = true +tokio-stream.workspace = true +tonic.workspace = true +tower.workspace = true +tracing.workspace = true + +[dev-dependencies] +hyper-util = "0.1" diff --git a/crates/defguard_gateway_manager/src/certs.rs b/crates/defguard_gateway_manager/src/certs.rs new file mode 100644 index 0000000000..a1daf0b53a --- /dev/null +++ b/crates/defguard_gateway_manager/src/certs.rs @@ -0,0 +1,33 @@ +//! Cached certificate serials for gateways. + +use std::{collections::HashMap, sync::Arc}; + +use defguard_common::db::{Id, models::gateway::Gateway}; +use sqlx::PgPool; +use tokio::sync::watch; + +fn collect_certs(items: I) -> HashMap +where + I: IntoIterator)>, +{ + items + .into_iter() + .filter_map(|(id, cert)| cert.map(|cert| (id, cert))) + .collect() +} + +pub(super) async fn refresh_certs(pool: &PgPool, tx: &watch::Sender>>) { + match Gateway::all(pool).await { + Ok(gateways) => { + let certs = collect_certs( + gateways + .into_iter() + .map(|gateway| (gateway.id, gateway.certificate)), + ); + let _ = tx.send(Arc::new(certs)); + } + Err(err) => { + warn!("Failed to refresh gateway certificate list: {err}"); + } + } +} diff --git a/crates/defguard_gateway_manager/src/error.rs b/crates/defguard_gateway_manager/src/error.rs new file mode 100644 index 0000000000..7fde13348e --- /dev/null +++ b/crates/defguard_gateway_manager/src/error.rs @@ -0,0 +1,32 @@ +use defguard_core::{enterprise::firewall::FirewallError, events::GrpcEvent}; +use thiserror::Error; +use tokio::sync::mpsc::error::SendError; +use tonic::{Code, Status}; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Error)] +pub(crate) enum GatewayError { + #[error("gRPC event channel error: {0}")] + GrpcEventChannelError(#[from] SendError), + #[error("Endpoint error: {0}")] + EndpointError(String), + #[error("gRPC communication error: {0}")] + GrpcCommunicationError(#[from] tonic::Status), + #[error(transparent)] + CertificateError(#[from] defguard_certs::CertificateError), + #[error(transparent)] + SqlxError(#[from] sqlx::Error), + #[error("Not found: {0}")] + NotFound(String), + // mpsc channel send/receive error + #[error("Message channel error: {0}")] + MessageChannelError(String), + #[error(transparent)] + FirewallError(#[from] FirewallError), +} + +impl From for Status { + fn from(value: GatewayError) -> Self { + Self::new(Code::Internal, value.to_string()) + } +} diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_gateway_manager/src/handler.rs similarity index 53% rename from crates/defguard_core/src/grpc/gateway/mod.rs rename to crates/defguard_gateway_manager/src/handler.rs index d50436a7cc..3dc908d0ee 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_gateway_manager/src/handler.rs @@ -1,311 +1,385 @@ -use std::{collections::HashMap, net::IpAddr, time::Duration}; +use std::{ + collections::HashMap, + net::IpAddr, + str::FromStr, + sync::{ + Arc, Mutex, + atomic::{AtomicU64, Ordering}, + }, +}; use chrono::DateTime; +#[cfg(not(test))] +use defguard_common::db::models::Settings; use defguard_common::{ + VERSION, db::{ - ChangeNotification, Id, TriggerOperation, - models::{ - WireguardNetwork, - gateway::Gateway, - wireguard::{DEFAULT_WIREGUARD_MTU, ServiceLocationMode}, - }, + Id, + models::{WireguardNetwork, gateway::Gateway, wireguard::DEFAULT_WIREGUARD_MTU}, }, messages::peer_stats_update::PeerStatsUpdate, }; +#[cfg(not(test))] +use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, + gateway::{ + Configuration, CoreResponse, Peer, PeerStats, Update, core_request, core_response, + gateway_client, update, + }, }; -use sqlx::{PgExecutor, PgPool, postgres::PgListener, query}; -use thiserror::Error; +use defguard_version::client::ClientVersionInterceptor; +#[cfg(not(test))] +use hyper_rustls::HttpsConnectorBuilder; +use reqwest::Url; +use semver::Version; +use sqlx::PgPool; use tokio::{ sync::{ - broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{UnboundedSender, error::SendError}, + broadcast::{self, Sender}, + mpsc::{self, UnboundedSender}, + watch, }, - task::{AbortHandle, JoinSet}, + time::sleep, }; -use tonic::{Code, Status}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{Code, Status, transport::Endpoint}; -use crate::{ - enterprise::{firewall::FirewallError, is_enterprise_license_active}, - events::GrpcEvent, - grpc::gateway::{events::GatewayEvent, handler::GatewayHandler}, +use defguard_core::{ + enterprise::firewall::try_get_location_firewall_config, grpc::GatewayEvent, + handlers::mail::send_gateway_disconnected_email, + location_management::allowed_peers::get_location_allowed_peers, }; -pub mod events; -pub(crate) mod handler; -// #[cfg(test)] -// mod tests; - -#[cfg(test)] -pub(super) static TONIC_SOCKET: &str = "tonic.sock"; - -/// Sends given `GatewayEvent` to be handled by gateway GRPC server -/// -/// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead -pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { - debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); - if let Err(err) = wg_tx.send(event) { - error!("Error sending WireGuard event {err}"); - } +use crate::{Client, TEN_SECS, error::GatewayError}; + +/// One instance per connected Gateway. +pub(super) struct GatewayHandler { + // Gateway server endpoint URL. + url: Url, + gateway: Gateway, + message_id: AtomicU64, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: watch::Receiver>>, } -/// Sends multiple events to be handled by gateway gRPC server. -/// -/// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead -pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { - debug!("Sending {} WireGuard events", events.len()); - for event in events { - send_wireguard_event(event, wg_tx); +impl GatewayHandler { + pub fn new( + gateway: Gateway, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + certs_rx: watch::Receiver>>, + ) -> Result { + let url = Url::from_str(&gateway.url).map_err(|err| { + GatewayError::EndpointError(format!( + "Failed to parse Gateway URL {}: {err}", + &gateway.url + )) + })?; + + Ok(Self { + url, + gateway, + message_id: AtomicU64::new(0), + pool, + events_tx, + peer_stats_tx, + certs_rx, + }) } -} -/// Helper used to convert peer stats coming from gRPC client -/// into an internal representation -fn try_protos_into_stats_message( - proto_stats: PeerStats, - location_id: Id, - gateway_id: Id, -) -> Option { - // try to parse endpoint - let endpoint = proto_stats.endpoint.parse().ok()?; + fn endpoint(&self) -> Result { + let mut url = self.url.clone(); - let latest_handshake = DateTime::from_timestamp(proto_stats.latest_handshake as i64, 0) - .unwrap_or_default() - .naive_utc(); + if let Err(()) = url.set_scheme("http") { + return Err(GatewayError::EndpointError(format!( + "Failed to set http scheme for Gateway URL {:?}", + self.url + ))); + } - Some(PeerStatsUpdate::new( - location_id, - gateway_id, - proto_stats.public_key, - endpoint, - proto_stats.upload, - proto_stats.download, - latest_handshake, - )) -} + let endpoint = Endpoint::from_shared(url.to_string()) + .map_err(|err| { + GatewayError::EndpointError(format!( + "Failed to create endpoint for Gateway URL {url:?}: {err}", + )) + })? + .http2_keep_alive_interval(TEN_SECS) + .tcp_keepalive(Some(TEN_SECS)) + .keep_alive_while_idle(true); + + Ok(endpoint) + } -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Error)] -pub enum GatewayError { - #[error("Failed to acquire lock on VPN client state map")] - ClientStateMutexError, - #[error("gRPC event channel error: {0}")] - GrpcEventChannelError(#[from] SendError), - #[error("Endpoint error: {0}")] - EndpointError(String), - #[error("gRPC communication error: {0}")] - GrpcCommunicationError(#[from] tonic::Status), - #[error(transparent)] - CertificateError(#[from] defguard_certs::CertificateError), - #[error("Configuration error: {0}")] - ConfigurationError(String), - #[error("Conversion error: {0}")] - ConversionError(String), - #[error(transparent)] - SqlxError(#[from] sqlx::Error), - #[error("Not found: {0}")] - NotFound(String), - // mpsc channel send/receive error - #[error("Message channel error: {0}")] - MessageChannelError(String), - #[error(transparent)] - FirewallError(#[from] FirewallError), -} + /// Send network and VPN configuration to Gateway. + async fn send_configuration( + &self, + tx: &UnboundedSender, + ) -> Result, GatewayError> { + debug!("Sending configuration to Gateway"); + let network_id = self.gateway.network_id; -impl From for Status { - fn from(value: GatewayError) -> Self { - Self::new(Code::Internal, value.to_string()) - } -} + let mut conn = self.pool.acquire().await?; -/// If this location is marked as a service location, checks if all requirements are met for it to -/// function: -/// - Enterprise is enabled -#[must_use] -pub fn should_prevent_service_location_usage(location: &WireguardNetwork) -> bool { - location.service_location_mode != ServiceLocationMode::Disabled - && !is_enterprise_license_active() -} + let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) + .await? + .ok_or_else(|| { + GatewayError::NotFound(format!("Network with id {network_id} not found")) + })?; -/// Get a list of all allowed peers -/// -/// Each device is marked as allowed or not allowed in a given network, -/// which enables enforcing peer disconnect in MFA-protected networks. -/// -/// If the location is a service location, only returns peers if enterprise features are enabled. -/// -/// XXX: should be implemented in defguard_core::db::models::wireguard::WireguardNetwork. -pub async fn get_peers<'e, E>( - location: &WireguardNetwork, - executor: E, -) -> Result, sqlx::Error> -where - E: PgExecutor<'e>, -{ - debug!("Fetching all peers for network {}", location.id); - - if should_prevent_service_location_usage(location) { - warn!( - "Tried to use service location {} with disabled enterprise features. No clients \ - will be allowed to connect.", - location.name + debug!( + "Sending configuration to {}, network {network}", + self.gateway ); - return Ok(Vec::new()); - } - - // TODO: possible to not use ARRAY-unnest here? - let rows = query!( - "SELECT d.wireguard_pubkey pubkey, preshared_key, \ - ARRAY( - SELECT host(ip) - FROM unnest(wnd.wireguard_ips) AS ip - ) \"allowed_ips!: Vec\" \ - FROM wireguard_network_device wnd \ - JOIN device d ON wnd.device_id = d.id \ - JOIN \"user\" u ON d.user_id = u.id \ - WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ - AND d.configured = true \ - AND u.is_active = true \ - ORDER BY d.id ASC", - location.id, - location.mfa_enabled() - ) - .fetch_all(executor) - .await?; - - // keepalive has to be added manually because Postgres - // doesn't support unsigned integers - let result = rows - .into_iter() - .map(|row| Peer { - pubkey: row.pubkey, - allowed_ips: row.allowed_ips, - // Don't send preshared key if MFA is not enabled, it can't be used and may - // cause issues with clients connecting if they expect no preshared key - // e.g. when you disable MFA on a location - preshared_key: if location.mfa_enabled() { - row.preshared_key - } else { - None - }, - keepalive_interval: Some(location.keepalive_interval as u32), - }) - .collect(); + if let Err(err) = network.touch_connected(&mut *conn).await { + error!( + "Failed to update connection time for network {network_id} in the database, \ + status: {err}" + ); + } - Ok(result) -} + let peers = get_location_allowed_peers(&network, &self.pool).await?; + + let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn).await?; + let payload = Some(core_response::Payload::Config(gen_config( + &network, + peers, + maybe_firewall_config, + ))); + let id = self.message_id.fetch_add(1, Ordering::Relaxed); + let req = CoreResponse { id, payload }; + match tx.send(req) { + Ok(()) => { + info!("Configuration sent to {}, network {network}", self.gateway); + Ok(network) + } + Err(err) => { + error!("Failed to send configuration sent to {}", self.gateway); + Err(GatewayError::MessageChannelError(format!( + "Configuration not sent to {}, error {err}", + self.gateway + ))) + } + } + } -fn gen_config( - network: &WireguardNetwork, - peers: Vec, - maybe_firewall_config: Option, -) -> Configuration { - Configuration { - name: network.name.clone(), - port: network.port as u32, - prvkey: network.prvkey.clone(), - addresses: network.address.iter().map(ToString::to_string).collect(), - peers, - firewall_config: maybe_firewall_config, - mtu: network.mtu as u32, - fwmark: network.fwmark as u32, + /// Send gateway disconnected notification. + /// Sends notification only if last notification time is bigger than specified in config. + async fn send_disconnect_notification(&self) { + debug!("Sending gateway disconnect email notification"); + let hostname = self.gateway.hostname.clone(); + let pool = self.pool.clone(); + let url = self.gateway.url.clone(); + + let Ok(Some(network)) = + WireguardNetwork::find_by_id(&self.pool, self.gateway.network_id).await + else { + error!( + "Failed to fetch network ID {} from database", + self.gateway.network_id + ); + return; + }; + + // Send email only if disconnection time is before the connection time. + let send_email = if let (Some(connected_at), Some(disconnected_at)) = + (self.gateway.connected_at, self.gateway.disconnected_at) + { + disconnected_at <= connected_at + } else { + true + }; + if send_email { + // FIXME: Try to get rid of spawn and use something like block_on + // To return result instead of logging + tokio::spawn(async move { + if let Err(err) = + send_gateway_disconnected_email(hostname, network.name, &url, &pool).await + { + error!("Failed to send gateway disconnect notification: {err}"); + } else { + info!("Email notification sent about gateway being disconnected"); + } + }); + } else { + info!( + "{} disconnected. Email notification not sent.", + self.gateway + ); + } } -} -const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; -const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); + /// Connect to Gateway and handle its messages through gRPC. + pub(super) async fn handle_connection( + &mut self, + clients: Arc>>, + ) -> Result<(), GatewayError> { + #[cfg(test)] + let _ = &self.certs_rx; + let endpoint = self.endpoint()?; + let uri = endpoint.uri().to_string(); + loop { + #[cfg(not(test))] + let channel = { + let settings = Settings::get_current_settings(); + let Some(ca_cert_der) = settings.ca_cert_der else { + return Err(GatewayError::EndpointError( + "Core CA is not setup, can't create a Gateway endpoint.".to_string(), + )); + }; + let tls_config = + tls_certs::client_config(&ca_cert_der, self.certs_rx.clone(), self.gateway.id) + .map_err(|err| GatewayError::EndpointError(err.to_string()))?; + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http2() + .build(); + let connector = HttpsSchemeConnector::new(connector); + endpoint.connect_with_connector_lazy(connector) + }; + #[cfg(test)] + let channel = endpoint.connect_with_connector_lazy(tower::service_fn( + |_: tonic::transport::Uri| async { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(super::TONIC_SOCKET).await?, + )) + }, + )); + + debug!("Connecting to Gateway {uri}"); + let interceptor = ClientVersionInterceptor::new( + Version::parse(VERSION).expect("failed to parse self version"), + ); + let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); + clients + .lock() + .expect("GatewayHandler failed to lock clients") + .insert(self.gateway.id, client.clone()); + let (tx, rx) = mpsc::unbounded_channel(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + error!("Failed to connect to Gateway {uri}, retrying: {err}"); + sleep(TEN_SECS).await; + continue; + } + }; + info!("Connected to Defguard Gateway {uri}"); -/// Bi-directional gRPC stream for communication with Defguard Gateway. -pub async fn run_grpc_gateway_stream( - pool: PgPool, - events_tx: Sender, - peer_stats_tx: UnboundedSender, -) -> Result<(), anyhow::Error> { - let mut abort_handles = HashMap::new(); + let maybe_info = defguard_version::ComponentInfo::from_metadata(response.metadata()); + let (version, _info) = defguard_version::get_tracing_variables(&maybe_info); - let mut tasks = JoinSet::new(); - // Helper closure to launch `GatewayHandler`. - let mut launch_gateway_handler = |gateway: Gateway| -> Result { - let mut gateway_handler = GatewayHandler::new( - gateway, - pool.clone(), - events_tx.clone(), - peer_stats_tx.clone(), - )?; - let abort_handle = tasks.spawn(async move { - loop { - if let Err(err) = gateway_handler.handle_connection().await { - error!("Gateway connection error: {err}, retrying in 5 seconds..."); - tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await; - } + if let Some(mut gateway) = Gateway::find_by_id(&self.pool, self.gateway.id).await? { + gateway.version = Some(version.to_string()); + gateway.save(&self.pool).await?; } - }); - Ok(abort_handle) - }; - - for gateway in Gateway::all(&pool).await? { - let id = gateway.id; - let abort_handle = launch_gateway_handler(gateway)?; - abort_handles.insert(id, abort_handle); - } - // Observe gateway URL changes. - let mut listener = PgListener::connect_with(&pool).await?; - listener.listen(GATEWAY_TABLE_TRIGGER).await?; - while let Ok(notification) = listener.recv().await { - let payload = notification.payload(); - match serde_json::from_str::>>(payload) { - Ok(gateway_notification) => match gateway_notification.operation { - TriggerOperation::Insert => { - if let Some(new) = gateway_notification.new { - let id = new.id; - let abort_handle = launch_gateway_handler(new)?; - abort_handles.insert(id, abort_handle); + let mut resp_stream = response.into_inner(); + let mut config_sent = false; + + 'message: loop { + match resp_stream.message().await { + Ok(None) => { + info!("Stream was closed by the sender."); + break 'message; } - } - TriggerOperation::Update => { - if let (Some(old), Some(new)) = - (gateway_notification.old, gateway_notification.new) - { - if old.url == new.url { - debug!( - "Gateway URL didn't change. Keeping the current gateway handler" - ); - } else if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!("Aborting connection to {old}, it has changed in the database"); - abort_handle.abort(); - let id = new.id; - let abort_handle = launch_gateway_handler(new)?; - abort_handles.insert(id, abort_handle); - } else { - warn!("Cannot find {old} on the list of connected gateways"); + Ok(Some(received)) => { + info!("Received message from Gateway."); + debug!("Message from Gateway {uri}"); + + match received.payload { + Some(core_request::Payload::ConfigRequest(config_request)) => { + if config_sent { + warn!( + "Ignoring repeated configuration request from {}", + self.gateway + ); + continue; + } + + // Send network configuration to Gateway. + match self.send_configuration(&tx).await { + Ok(network) => { + info!("Sent configuration to {}", self.gateway); + config_sent = true; + let _ = self + .gateway + .touch_connected(&self.pool, config_request.hostname) + .await; + let mut updates_handler = GatewayUpdatesHandler::new( + self.gateway.network_id, + network, + self.gateway + .hostname + .clone() + .unwrap_or_default() + .clone(), + self.events_tx.subscribe(), + tx.clone(), + ); + tokio::spawn(async move { + updates_handler.run().await; + }); + } + Err(err) => { + error!( + "Failed to send configuration to {}: {err}", + self.gateway + ); + } + } + } + Some(core_request::Payload::PeerStats(peer_stats)) => { + if !config_sent { + warn!( + "Ignoring peer statistics from {} because it hasn't \ + authorized itself", + self.gateway + ); + continue; + } + + // convert stats to DB storage format + match try_protos_into_stats_message( + peer_stats.clone(), + self.gateway.network_id, + self.gateway.id, + ) { + None => { + warn!( + "Failed to parse peer stats update. Skipping sending \ + message to session manager." + ); + } + Some(message) => { + if let Err(err) = self.peer_stats_tx.send(message) { + error!( + "Failed to send peers stats update to session manager: {err}" + ); + } + } + } + } + None => (), } } - } - TriggerOperation::Delete => { - if let Some(old) = gateway_notification.old { - if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!( - "Aborting connection to {old}, it has disappeard from the database" - ); - abort_handle.abort(); - } else { - warn!("Cannot find {old} on the list of connected gateways"); - } + Err(err) => { + error!("Disconnected from Gateway at {uri}, error: {err}"); + // Important: call this funtion before setting disconnection time. + self.send_disconnect_notification().await; + let _ = self.gateway.touch_disconnected(&self.pool).await; + debug!("Waiting 10s to re-establish the connection"); + sleep(TEN_SECS).await; + break 'message; } } - }, - Err(err) => error!("Failed to de-serialize database notification object: {err}"), + } } } - - while let Some(Ok(_result)) = tasks.join_next().await { - debug!("Gateway gRPC task has ended"); - } - - Ok(()) } /// Helper struct for handling gateway events. @@ -313,16 +387,17 @@ struct GatewayUpdatesHandler { network_id: Id, network: WireguardNetwork, gateway_hostname: String, - events_rx: BroadcastReceiver, + events_rx: broadcast::Receiver, tx: UnboundedSender, } impl GatewayUpdatesHandler { - pub fn new( + #[must_use] + fn new( network_id: Id, network: WireguardNetwork, gateway_hostname: String, - events_rx: BroadcastReceiver, + events_rx: broadcast::Receiver, tx: UnboundedSender, ) -> Self { Self { @@ -338,7 +413,7 @@ impl GatewayUpdatesHandler { /// /// Main gRPC server uses a shared channel for broadcasting all gateway events /// so the handler must determine if an event is relevant for the network being serviced - pub async fn run(&mut self) { + async fn run(&mut self) { info!( "Starting update stream to gateway: {}, network {}", self.gateway_hostname, self.network @@ -404,7 +479,7 @@ impl GatewayUpdatesHandler { .collect(), preshared_key: network_info.preshared_key.clone(), keepalive_interval: Some( - self.network.keepalive_interval as u32, + self.network.keepalive_interval.cast_unsigned(), ), }, 0, @@ -439,7 +514,7 @@ impl GatewayUpdatesHandler { .collect(), preshared_key: network_info.preshared_key.clone(), keepalive_interval: Some( - self.network.keepalive_interval as u32, + self.network.keepalive_interval.cast_unsigned(), ), }, 1, @@ -509,7 +584,9 @@ impl GatewayUpdatesHandler { .map(IpAddr::to_string) .collect(), preshared_key: network_device.preshared_key.clone(), - keepalive_interval: Some(self.network.keepalive_interval as u32), + keepalive_interval: Some( + self.network.keepalive_interval.cast_unsigned(), + ), }, 0, ) @@ -545,10 +622,10 @@ impl GatewayUpdatesHandler { name: network.name.clone(), prvkey: network.prvkey.clone(), addresses: network.address.iter().map(ToString::to_string).collect(), - port: network.port as u32, + port: network.port.cast_unsigned(), peers, firewall_config, - mtu: network.mtu as u32, + mtu: network.mtu.cast_unsigned(), fwmark: network.fwmark as u32, })), })), @@ -582,7 +659,7 @@ impl GatewayUpdatesHandler { port: 0, peers: Vec::new(), firewall_config: None, - mtu: DEFAULT_WIREGUARD_MTU as u32, + mtu: DEFAULT_WIREGUARD_MTU.cast_unsigned(), fwmark: 0, })), })), @@ -696,3 +773,45 @@ impl GatewayUpdatesHandler { Ok(()) } } + +/// Helper used to convert peer stats coming from gRPC client +/// into an internal representation +fn try_protos_into_stats_message( + proto_stats: PeerStats, + location_id: Id, + gateway_id: Id, +) -> Option { + // try to parse endpoint + let endpoint = proto_stats.endpoint.parse().ok()?; + + let latest_handshake = DateTime::from_timestamp(proto_stats.latest_handshake as i64, 0) + .unwrap_or_default() + .naive_utc(); + + Some(PeerStatsUpdate::new( + location_id, + gateway_id, + proto_stats.public_key, + endpoint, + proto_stats.upload, + proto_stats.download, + latest_handshake, + )) +} + +fn gen_config( + network: &WireguardNetwork, + peers: Vec, + maybe_firewall_config: Option, +) -> Configuration { + Configuration { + name: network.name.clone(), + port: network.port.cast_unsigned(), + prvkey: network.prvkey.clone(), + addresses: network.address.iter().map(ToString::to_string).collect(), + peers, + firewall_config: maybe_firewall_config, + mtu: network.mtu.cast_unsigned(), + fwmark: network.fwmark as u32, + } +} diff --git a/crates/defguard_gateway_manager/src/lib.rs b/crates/defguard_gateway_manager/src/lib.rs new file mode 100644 index 0000000000..b100560eb1 --- /dev/null +++ b/crates/defguard_gateway_manager/src/lib.rs @@ -0,0 +1,182 @@ +// FIXME: actually refactor errors instead +#![allow(clippy::result_large_err)] +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, + time::Duration, +}; + +use defguard_common::{ + db::{ChangeNotification, Id, TriggerOperation, models::gateway::Gateway}, + messages::peer_stats_update::PeerStatsUpdate, +}; +use defguard_core::grpc::GatewayEvent; +use defguard_proto::gateway::gateway_client::GatewayClient; +use defguard_version::client::ClientVersionInterceptor; +use sqlx::{PgPool, postgres::PgListener}; +use tokio::{ + sync::{broadcast::Sender, mpsc::UnboundedSender}, + task::{AbortHandle, JoinSet}, +}; +use tonic::{Request, service::interceptor::InterceptedService, transport::Channel}; + +use crate::handler::GatewayHandler; + +#[macro_use] +extern crate tracing; + +mod certs; +mod error; +mod handler; +// #[cfg(test)] +// mod tests; + +#[cfg(test)] +static TONIC_SOCKET: &str = "tonic.sock"; +const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; +const GATEWAY_RECONNECT_DELAY: Duration = Duration::from_secs(5); +const TEN_SECS: Duration = Duration::from_secs(10); + +type Client = GatewayClient>; + +#[derive(Default)] +pub struct GatewayManager { + clients: Arc>>, +} + +impl GatewayManager { + /// Bi-directional gRPC stream for communication with Defguard Gateway. + pub async fn run( + &mut self, + pool: PgPool, + events_tx: Sender, + peer_stats_tx: UnboundedSender, + ) -> Result<(), anyhow::Error> { + let (certs_tx, certs_rx) = tokio::sync::watch::channel(Arc::new(HashMap::new())); + certs::refresh_certs(&pool, &certs_tx).await; + let refresh_pool = pool.clone(); + tokio::spawn(async move { + loop { + certs::refresh_certs(&refresh_pool, &certs_tx).await; + tokio::time::sleep(TEN_SECS).await; + } + }); + let mut abort_handles = HashMap::new(); + + let mut tasks = JoinSet::new(); + // Helper closure to launch `GatewayHandler`. + // TODO: Store arguments in GatewayManager and rewrite this to method + let mut launch_gateway_handler = |gateway: Gateway, + clients: Arc>>| + -> Result { + let mut gateway_handler = GatewayHandler::new( + gateway, + pool.clone(), + events_tx.clone(), + peer_stats_tx.clone(), + certs_rx.clone(), + )?; + let abort_handle = tasks.spawn(async move { + loop { + if let Err(err) = gateway_handler + .handle_connection(Arc::clone(&clients)) + .await + { + error!("Gateway connection error: {err}, retrying in 5 seconds..."); + tokio::time::sleep(GATEWAY_RECONNECT_DELAY).await; + } + } + }); + Ok(abort_handle) + }; + for gateway in Gateway::all(&pool).await? { + let id = gateway.id; + let abort_handle = launch_gateway_handler(gateway, Arc::clone(&self.clients))?; + abort_handles.insert(id, abort_handle); + } + + // Observe gateway URL changes. + let mut listener = PgListener::connect_with(&pool).await?; + listener.listen(GATEWAY_TABLE_TRIGGER).await?; + while let Ok(notification) = listener.recv().await { + let payload = notification.payload(); + match serde_json::from_str::>>(payload) { + Ok(gateway_notification) => match gateway_notification.operation { + TriggerOperation::Insert => { + if let Some(new) = gateway_notification.new { + let id = new.id; + let abort_handle = + launch_gateway_handler(new, Arc::clone(&self.clients))?; + abort_handles.insert(id, abort_handle); + } + } + TriggerOperation::Update => { + if let (Some(old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + { + if old.url == new.url { + debug!( + "Gateway URL didn't change. Keeping the current gateway handler" + ); + } else if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to {old}, it has changed in the database" + ); + abort_handle.abort(); + let id = new.id; + let abort_handle = + launch_gateway_handler(new, Arc::clone(&self.clients))?; + abort_handles.insert(id, abort_handle); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + TriggerOperation::Delete => { + let Some(old) = gateway_notification.old else { + continue; + }; + + // Send purge request to the gateway. + let maybe_client = { + self.clients + .lock() + .expect("Failed to lock GatewayManager::clients") + .remove(&old.id) + }; + + if let Some(mut client) = maybe_client { + debug!("Sending purge request to gateway {old}"); + if let Err(err) = client.purge(Request::new(())).await { + error!("Error sending purge request to gateway {old}: {err}"); + } else { + info!("Sent purge request to gateway {old}"); + } + } else { + warn!( + "Cannot find gRPC client for gateway {old}, won't send purge request" + ); + } + + // Kill the `GatewayHandler` and the connection. + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to gateway {old}, it has disappeard from the database" + ); + abort_handle.abort(); + } else { + warn!("Cannot find abort handle for gateway {old}"); + } + } + }, + Err(err) => error!("Failed to de-serialize database notification object: {err}"), + } + } + + while let Some(Ok(_result)) = tasks.join_next().await { + debug!("Gateway gRPC task has ended"); + } + + Ok(()) + } +} diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_gateway_manager/src/tests.rs similarity index 100% rename from crates/defguard_core/src/grpc/gateway/tests.rs rename to crates/defguard_gateway_manager/src/tests.rs diff --git a/crates/defguard_grpc_tls/Cargo.toml b/crates/defguard_grpc_tls/Cargo.toml new file mode 100644 index 0000000000..3d79cb5622 --- /dev/null +++ b/crates/defguard_grpc_tls/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "defguard_grpc_tls" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common.workspace = true +http = "1.1" +rustls = { version = "0.23", features = ["ring"] } +thiserror.workspace = true +tokio.workspace = true +tower-service = "0.3" +x509-parser = "0.18" +tracing.workspace = true diff --git a/crates/defguard_grpc_tls/src/certs.rs b/crates/defguard_grpc_tls/src/certs.rs new file mode 100644 index 0000000000..e9f9f44aa5 --- /dev/null +++ b/crates/defguard_grpc_tls/src/certs.rs @@ -0,0 +1,169 @@ +//! Custom TLS verification for proxy and gateway connections. +//! +//! Motivation: +//! - tonic/rustls does not fetch or enforce CRL distribution points, so revocation +//! has to be enforced by the application. +//! - We pin each component to its expected certificate serial and reject mismatches +//! at the TLS layer, before any gRPC requests are processed. +//! - A lightweight in-memory cache (refreshed periodically) avoids database access +//! during the handshake and keeps verification synchronous. + +use std::{collections::HashMap, sync::Arc}; + +use defguard_common::db::Id; +use rustls::{ + CertificateError, DistinguishedName, Error as RustlsError, RootCertStore, SignatureScheme, + client::{ + WebPkiServerVerifier, + danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + }, + crypto, + pki_types::{CertificateDer, ServerName, UnixTime}, +}; +use thiserror::Error; +use tokio::sync::watch; +use tracing::error; +use x509_parser::parse_x509_certificate; + +/// Errors that can occur while building a TLS config with a pinned verifier. +#[derive(Debug, Error)] +pub enum CertConfigError { + #[error("TLS config error: {0}")] + TlsConfig(String), +} + +/// Wraps WebPKI verification to enforce component-specific certificate serials. +#[derive(Debug)] +struct CertVerifier { + inner: Arc, + certs_rx: watch::Receiver>>, + component_id: Id, +} + +impl CertVerifier { + fn new( + inner: Arc, + certs_rx: watch::Receiver>>, + component_id: Id, + ) -> Self { + Self { + inner, + certs_rx, + component_id, + } + } + + /// Validate the peer certificate serial against the expected component serial. + fn verify(&self, end_entity: &CertificateDer<'_>) -> Result<(), RustlsError> { + let (_, cert) = parse_x509_certificate(end_entity.as_ref()) + .map_err(|_| RustlsError::InvalidCertificate(CertificateError::BadEncoding))?; + let serial = cert.tbs_certificate.raw_serial_as_string(); + let certs = self.certs_rx.borrow(); + let Some(expected) = certs.get(&self.component_id) else { + error!( + "Missing expected certificate for component id={}, serial={}", + self.component_id, serial + ); + return Err(RustlsError::InvalidCertificate( + CertificateError::ApplicationVerificationFailure, + )); + }; + if !expected.eq_ignore_ascii_case(&serial) { + error!( + "Invalid certificate for component id={}: expected={} got={}.", + self.component_id, expected, serial + ); + return Err(RustlsError::InvalidCertificate( + CertificateError::ApplicationVerificationFailure, + )); + } + Ok(()) + } +} + +impl ServerCertVerifier for CertVerifier { + /// Delegate chain validation to WebPKI, then enforce the component-specific pin. + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + self.inner.verify_server_cert( + end_entity, + intermediates, + server_name, + ocsp_response, + now, + )?; + self.verify(end_entity)?; + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } + + fn root_hint_subjects(&self) -> Option<&[DistinguishedName]> { + self.inner.root_hint_subjects() + } +} + +/// Build a root store from the configured CA for WebPKI validation. +fn root_store_from_ca(ca_cert_der: &[u8]) -> Result { + let mut roots = RootCertStore::empty(); + roots + .add(CertificateDer::from(ca_cert_der.to_vec())) + .map_err(|err| CertConfigError::TlsConfig(err.to_string()))?; + Ok(roots) +} + +/// Create a rustls client config that enforces the pinned component certificate serial. +pub fn client_config( + ca_cert_der: &[u8], + certs_rx: watch::Receiver>>, + component_id: Id, +) -> Result { + let provider = Arc::new(crypto::ring::default_provider()); + let roots = root_store_from_ca(ca_cert_der)?; + let verifier_roots = root_store_from_ca(ca_cert_der)?; + let verifier = WebPkiServerVerifier::builder_with_provider( + Arc::new(verifier_roots), + Arc::clone(&provider), + ) + .build() + .map_err(|err| CertConfigError::TlsConfig(err.to_string()))?; + let builder = rustls::ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .map_err(|err| CertConfigError::TlsConfig(err.to_string()))?; + let mut config = builder.with_root_certificates(roots).with_no_client_auth(); + let verifier: Arc = verifier; + config + .dangerous() + .set_certificate_verifier(Arc::new(CertVerifier::new( + verifier, + certs_rx, + component_id, + ))); + Ok(config) +} diff --git a/crates/defguard_grpc_tls/src/connector.rs b/crates/defguard_grpc_tls/src/connector.rs new file mode 100644 index 0000000000..16f438ffe6 --- /dev/null +++ b/crates/defguard_grpc_tls/src/connector.rs @@ -0,0 +1,50 @@ +use http::Uri; + +/// Rewrites the request URI scheme to https for the TLS connector. +/// +/// Tonic expects an http URI for its endpoint, but a custom connector performs +/// the TLS handshake and requires https to select the TLS path. +#[derive(Clone, Debug)] +pub struct HttpsSchemeConnector { + inner: C, +} + +impl HttpsSchemeConnector { + pub const fn new(inner: C) -> Self { + Self { inner } + } +} + +type BoxError = Box; + +impl tower_service::Service for HttpsSchemeConnector +where + C: tower_service::Service + Clone + Send + 'static, + C::Future: Send, +{ + type Response = C::Response; + type Error = BoxError; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, uri: Uri) -> Self::Future { + let mut parts = uri.into_parts(); + parts.scheme = Some(http::uri::Scheme::HTTPS); + let https_uri = match Uri::from_parts(parts) { + Ok(uri) => uri, + Err(err) => { + return Box::pin(async move { Err(err.into()) }); + } + }; + let mut inner = self.inner.clone(); + Box::pin(async move { inner.call(https_uri).await }) + } +} diff --git a/crates/defguard_grpc_tls/src/lib.rs b/crates/defguard_grpc_tls/src/lib.rs new file mode 100644 index 0000000000..b7a37f7f97 --- /dev/null +++ b/crates/defguard_grpc_tls/src/lib.rs @@ -0,0 +1,2 @@ +pub mod certs; +pub mod connector; diff --git a/crates/defguard_proxy_manager/Cargo.toml b/crates/defguard_proxy_manager/Cargo.toml index da3e28562f..2d27616711 100644 --- a/crates/defguard_proxy_manager/Cargo.toml +++ b/crates/defguard_proxy_manager/Cargo.toml @@ -15,12 +15,13 @@ defguard_mail.workspace = true defguard_proto.workspace = true defguard_version.workspace = true defguard_certs.workspace = true +defguard_grpc_tls.workspace = true axum.workspace = true axum-extra.workspace = true semver.workspace = true secrecy.workspace = true -http.workspace = true +hyper-rustls.workspace = true openidconnect.workspace = true reqwest.workspace = true sqlx.workspace = true @@ -29,8 +30,3 @@ tokio.workspace = true tokio-stream.workspace = true tonic.workspace = true tracing.workspace = true -x509-parser.workspace = true - -hyper-rustls = { version = "0.27", features = ["http2"] } -rustls = { version = "0.23", features = ["ring"] } -tower-service = "0.3" diff --git a/crates/defguard_proxy_manager/src/certs.rs b/crates/defguard_proxy_manager/src/certs.rs index a1010dc4cf..ac16d1e0ba 100644 --- a/crates/defguard_proxy_manager/src/certs.rs +++ b/crates/defguard_proxy_manager/src/certs.rs @@ -1,120 +1,10 @@ -//! Custom TLS verification for proxy connections. -//! -//! Motivation: -//! - tonic/rustls does not fetch or enforce CRL distribution points, so revocation -//! has to be enforced by the application. -//! - We pin each proxy to its expected certificate serial and reject mismatches at -//! the TLS layer, before any gRPC requests are processed. -//! - A lightweight in-memory cache (refreshed periodically) avoids database access -//! during the handshake and keeps verification synchronous. +//! Cached certificate serials for proxies. use std::{collections::HashMap, sync::Arc}; use defguard_common::db::{Id, models::proxy::Proxy}; -use rustls::{ - CertificateError, DistinguishedName, Error as RustlsError, RootCertStore, SignatureScheme, - client::{ - WebPkiServerVerifier, - danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, - }, - crypto, - pki_types::{CertificateDer, ServerName, UnixTime}, -}; use sqlx::PgPool; use tokio::sync::watch; -use x509_parser::parse_x509_certificate; - -use crate::error::ProxyError; - -/// Wraps WebPKI verification to enforce proxy-specific certificate serials. -#[derive(Debug)] -struct CertVerifier { - inner: Arc, - certs_rx: watch::Receiver>>, - proxy_id: Id, -} - -impl CertVerifier { - fn new( - inner: Arc, - certs_rx: watch::Receiver>>, - proxy_id: Id, - ) -> Self { - Self { - inner, - certs_rx, - proxy_id, - } - } - - /// Validate the peer certificate serial against the expected proxy serial. - fn verify(&self, end_entity: &CertificateDer<'_>) -> Result<(), RustlsError> { - let (_, cert) = parse_x509_certificate(end_entity.as_ref()) - .map_err(|_| RustlsError::InvalidCertificate(CertificateError::BadEncoding))?; - let serial = cert.tbs_certificate.raw_serial_as_string(); - let certs = self.certs_rx.borrow(); - let Some(expected) = certs.get(&self.proxy_id) else { - error!("Missing expected certificate for proxy: {}", self.proxy_id); - return Err(RustlsError::InvalidCertificate(CertificateError::Revoked)); - }; - if !expected.eq_ignore_ascii_case(&serial) { - error!( - "Invalid certificate for proxy {}: expected={expected} got={serial}", - self.proxy_id - ); - return Err(RustlsError::InvalidCertificate(CertificateError::Revoked)); - } - Ok(()) - } -} - -impl ServerCertVerifier for CertVerifier { - /// Delegate chain validation to WebPKI, then enforce the proxy-specific pin. - fn verify_server_cert( - &self, - end_entity: &CertificateDer<'_>, - intermediates: &[CertificateDer<'_>], - server_name: &ServerName<'_>, - ocsp_response: &[u8], - now: UnixTime, - ) -> Result { - self.inner.verify_server_cert( - end_entity, - intermediates, - server_name, - ocsp_response, - now, - )?; - self.verify(end_entity)?; - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - self.inner.verify_tls12_signature(message, cert, dss) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &rustls::DigitallySignedStruct, - ) -> Result { - self.inner.verify_tls13_signature(message, cert, dss) - } - - fn supported_verify_schemes(&self) -> Vec { - self.inner.supported_verify_schemes() - } - - fn root_hint_subjects(&self) -> Option<&[DistinguishedName]> { - self.inner.root_hint_subjects() - } -} /// Build a compact id->serial map, skipping proxies without a stored cert. fn collect_certs(items: I) -> HashMap @@ -143,159 +33,3 @@ pub(crate) async fn refresh_certs(pool: &PgPool, tx: &watch::Sender Result { - let mut roots = RootCertStore::empty(); - roots - .add(CertificateDer::from(ca_cert_der.to_vec())) - .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; - Ok(roots) -} - -/// Create a rustls client config that enforces the pinned proxy certificate serial. -pub(crate) fn client_config( - ca_cert_der: &[u8], - certs_rx: watch::Receiver>>, - proxy_id: Id, -) -> Result { - let provider = Arc::new(crypto::ring::default_provider()); - let roots = root_store_from_ca(ca_cert_der)?; - let verifier_roots = root_store_from_ca(ca_cert_der)?; - let verifier = WebPkiServerVerifier::builder_with_provider( - Arc::new(verifier_roots), - Arc::clone(&provider), - ) - .build() - .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; - let builder = rustls::ClientConfig::builder_with_provider(provider) - .with_safe_default_protocol_versions() - .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; - let mut config = builder.with_root_certificates(roots).with_no_client_auth(); - config - .dangerous() - .set_certificate_verifier(Arc::new(CertVerifier::new(verifier, certs_rx, proxy_id))); - Ok(config) -} - -#[cfg(test)] -mod tests { - use super::*; - - use defguard_certs::{CertificateAuthority, Csr, DnType, generate_key_pair}; - use rustls::client::danger::HandshakeSignatureValid; - - #[derive(Debug)] - struct NoopVerifier; - - impl ServerCertVerifier for NoopVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - Vec::new() - } - - fn root_hint_subjects(&self) -> Option<&[DistinguishedName]> { - None - } - } - - fn make_cert_and_serial() -> (CertificateDer<'static>, String) { - let ca = CertificateAuthority::new("Defguard CA", "test@example.com", 30).unwrap(); - let key_pair = generate_key_pair().unwrap(); - let csr = Csr::new( - &key_pair, - &["proxy.local".to_string()], - vec![(DnType::CommonName, "proxy.local")], - ) - .unwrap(); - let cert = ca.sign_csr(&csr).unwrap(); - let cert_der = CertificateDer::from(cert.der().to_vec()); - let (_, parsed) = parse_x509_certificate(cert_der.as_ref()).unwrap(); - let serial = parsed.tbs_certificate.raw_serial_as_string(); - (cert_der, serial) - } - - #[test] - fn collect_certs_skips_missing() { - let certs = collect_certs(vec![(1, None), (2, Some("abc".to_string()))]); - assert_eq!(certs.len(), 1); - assert_eq!(certs.get(&2), Some(&"abc".to_string())); - } - - #[test] - fn verify_accepts_expected_serial() { - let (cert_der, serial) = make_cert_and_serial(); - let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, serial.clone())]))); - let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); - let result = verifier.verify(&cert_der); - assert!(result.is_ok()); - } - - #[test] - fn verify_rejects_missing_expected_cert() { - let (cert_der, serial) = make_cert_and_serial(); - let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(2, serial)]))); - let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); - let result = verifier.verify(&cert_der); - assert!(matches!( - result, - Err(RustlsError::InvalidCertificate(CertificateError::Revoked)) - )); - } - - #[test] - fn verify_rejects_mismatched_serial() { - let (cert_der, _serial) = make_cert_and_serial(); - let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, "deadbeef".to_string())]))); - let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); - let result = verifier.verify(&cert_der); - assert!(matches!( - result, - Err(RustlsError::InvalidCertificate(CertificateError::Revoked)) - )); - } - - #[test] - fn verify_accepts_case_insensitive_serial() { - let (cert_der, serial) = make_cert_and_serial(); - let expected_lower = serial.to_ascii_lowercase(); - let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, expected_lower)]))); - let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); - let result = verifier.verify(&cert_der); - assert!(result.is_ok()); - - let expected_upper = serial.to_ascii_uppercase(); - let (_tx, rx) = watch::channel(Arc::new(HashMap::from([(1, expected_upper)]))); - let verifier = CertVerifier::new(Arc::new(NoopVerifier), rx, 1); - let result = verifier.verify(&cert_der); - assert!(result.is_ok()); - } -} diff --git a/crates/defguard_proxy_manager/src/proxy_handler.rs b/crates/defguard_proxy_manager/src/handler.rs similarity index 91% rename from crates/defguard_proxy_manager/src/proxy_handler.rs rename to crates/defguard_proxy_manager/src/handler.rs index 169f067d9e..5ee8fea4ad 100644 --- a/crates/defguard_proxy_manager/src/proxy_handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -27,7 +27,7 @@ use defguard_core::{ ldap::utils::ldap_update_user_state, }, grpc::{ - gateway::events::GatewayEvent, + GatewayEvent, proxy::client_mfa::{ClientLoginSession, ClientMfaServer}, }, version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, @@ -39,7 +39,6 @@ use defguard_proto::proxy::{ use defguard_version::{ ComponentInfo, DefguardComponent, client::ClientVersionInterceptor, get_tracing_variables, }; -use http::Uri; use hyper_rustls::HttpsConnectorBuilder; use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; @@ -65,9 +64,9 @@ use tonic::{ use crate::{ ProxyError, ProxyTxSet, TEN_SECS, - certs::client_config, servers::{EnrollmentServer, PasswordResetServer}, }; +use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; static VERSION_ZERO: Version = Version::new(0, 0, 0); @@ -86,7 +85,7 @@ pub(super) struct ProxyHandler { services: ProxyServices, /// Proxy server gRPC URL pub(super) url: Url, - shutdown_signal: Arc>>, + shutdown_signal: Arc>, proxy_id: Id, client: Option>>, } @@ -98,7 +97,7 @@ impl ProxyHandler { tx: &ProxyTxSet, remote_mfa_responses: Arc>>>, sessions: Arc>>, - shutdown_signal: Arc>>, + shutdown_signal: Arc>, proxy_id: Id, ) -> Self { // Instantiate gRPC servers. @@ -120,7 +119,7 @@ impl ProxyHandler { tx: &ProxyTxSet, remote_mfa_responses: Arc>>>, sessions: Arc>>, - shutdown_signal: Arc>>, + shutdown_signal: Arc>, ) -> Result { let url = Url::from_str(&format!("http://{}:{}", proxy.address, proxy.port))?; let proxy_id = proxy.id; @@ -202,7 +201,9 @@ impl ProxyHandler { "Core CA is not setup, can't create a Proxy endpoint.".to_string(), )); }; - let tls_config = client_config(&ca_cert_der, certs_rx.clone(), self.proxy_id)?; + let tls_config = + tls_certs::client_config(&ca_cert_der, certs_rx.clone(), self.proxy_id) + .map_err(|err| ProxyError::TlsConfigError(err.to_string()))?; let connector = HttpsConnectorBuilder::new() .with_tls_config(tls_config) .https_only() @@ -282,42 +283,39 @@ impl ProxyHandler { payload: Some(core_response::Payload::InitialInfo(initial_info)), }); - let shutdown_signal = self.shutdown_signal.lock().await.take(); - if let Some(shutdown_signal) = shutdown_signal { - select! { - res = self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) => { - if let Err(err) = res { - error!("Proxy message loop ended with error: {err}, reconnecting in {TEN_SECS:?}",); - } else { - info!("Proxy message loop ended, reconnecting in {TEN_SECS:?}"); - } - self.mark_disconnected().await?; - sleep(TEN_SECS).await; + let shutdown_signal = Arc::clone(&self.shutdown_signal); + select! { + res = self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) => { + if let Err(err) = res { + error!("Proxy message loop ended with error: {err}, reconnecting in {TEN_SECS:?}",); + } else { + info!("Proxy message loop ended, reconnecting in {TEN_SECS:?}"); } - res = shutdown_signal => { - match res { - Err(err) => { - error!("An error occurred when trying to wait for a shutdown signal for Proxy: {err}. Reconnecting to: {}", endpoint.uri()); - } - Ok(purge) => { - info!("Shutdown signal received, purge: {purge}, stopping proxy connection to {}", endpoint.uri()); - if purge { + self.mark_disconnected().await?; + sleep(TEN_SECS).await; + } + res = &mut *shutdown_signal.lock().await => { + match res { + Err(err) => { + error!("An error occurred when trying to wait for a shutdown signal for Proxy: {err}. Reconnecting to: {}", endpoint.uri()); + } + Ok(purge) => { + info!("Shutdown signal received, purge: {purge}, stopping proxy connection to {}", endpoint.uri()); + if purge { + if let Some(client) = self.client.as_mut() { debug!("Sending purge request to proxy {}", endpoint.uri()); - if let Some(client) = self.client.as_mut() { - if let Err(err) = client.purge(Request::new(())).await { - error!("Error sending purge request to proxy {}: {err}", endpoint.uri()); - } + if let Err(err) = client.purge(Request::new(())).await { + error!("Error sending purge request to proxy {}: {err}", endpoint.uri()); + } else { + info!("Sent purge request to proxy {}", endpoint.uri()); } } } } - self.mark_disconnected().await?; - break; } + self.mark_disconnected().await?; + break; } - } else { - self.message_loop(tx, tx_set.wireguard.clone(), &mut resp_stream) - .await?; } } @@ -837,52 +835,3 @@ impl ProxyServices { } } } - -/// Rewrites the request URI scheme to https for the TLS connector. -/// -/// Tonic expects an http URI for its endpoint, but our custom connector performs -/// the TLS handshake and requires https to select the TLS path. -#[derive(Clone, Debug)] -struct HttpsSchemeConnector { - inner: C, -} - -impl HttpsSchemeConnector { - const fn new(inner: C) -> Self { - Self { inner } - } -} - -type BoxError = Box; - -impl tower_service::Service for HttpsSchemeConnector -where - C: tower_service::Service + Clone + Send + 'static, - C::Future: Send, -{ - type Response = C::Response; - type Error = BoxError; - type Future = std::pin::Pin< - Box> + Send>, - >; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_ready(cx).map_err(Into::into) - } - - fn call(&mut self, uri: Uri) -> Self::Future { - let mut parts = uri.into_parts(); - parts.scheme = Some(http::uri::Scheme::HTTPS); - let https_uri = match Uri::from_parts(parts) { - Ok(uri) => uri, - Err(err) => { - return Box::pin(async move { Err(err.into()) }); - } - }; - let mut inner = self.inner.clone(); - Box::pin(async move { inner.call(https_uri).await }) - } -} diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index ffb5e7ca55..351f279001 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -5,9 +5,7 @@ use std::{ }; use defguard_common::{db::models::proxy::Proxy, types::proxy::ProxyControlMessage}; -use defguard_core::{ - events::BidiStreamEvent, grpc::gateway::events::GatewayEvent, version::IncompatibleComponents, -}; +use defguard_core::{events::BidiStreamEvent, grpc::GatewayEvent, version::IncompatibleComponents}; use sqlx::PgPool; use tokio::{ @@ -21,11 +19,11 @@ use tokio::{ task::JoinSet, }; -use crate::{certs::refresh_certs, error::ProxyError, proxy_handler::ProxyHandler}; +use crate::{certs::refresh_certs, error::ProxyError, handler::ProxyHandler}; mod certs; mod error; -mod proxy_handler; +mod handler; mod servers; #[macro_use] @@ -37,7 +35,6 @@ const TEN_SECS: Duration = Duration::from_secs(10); /// /// Responsibilities include: /// - instantiating and supervising proxy connections, -/// - routing responses to the appropriate proxy based on correlation state, /// - providing shared infrastructure (database access, outbound channels), pub struct ProxyManager { pool: PgPool, @@ -64,12 +61,13 @@ impl ProxyManager { /// Spawns and supervises asynchronous tasks for all configured proxies. /// /// Each proxy runs in its own task and shares Core-side infrastructure - /// such as routing state and compatibility tracking. pub async fn run(mut self) -> Result<(), ProxyError> { debug!("ProxyManager starting"); let remote_mfa_responses = Arc::default(); let sessions = Arc::default(); let (certs_tx, certs_rx) = watch::channel(Arc::new(HashMap::new())); + // Prime the cache to avoid race with connection loop. + refresh_certs(&self.pool, &certs_tx).await; let refresh_pool = self.pool.clone(); tokio::spawn(async move { loop { @@ -91,7 +89,7 @@ impl ProxyManager { &self.tx, Arc::clone(&remote_mfa_responses), Arc::clone(&sessions), - Arc::new(Mutex::new(Some(shutdown_rx))), + Arc::new(Mutex::new(shutdown_rx)), ) }) .collect::, _>>()?; @@ -133,7 +131,7 @@ impl ProxyManager { &self.tx, Arc::clone(&remote_mfa_responses), Arc::clone(&sessions), - Arc::new(Mutex::new(Some(shutdown_rx))), + Arc::new(Mutex::new(shutdown_rx)), ) { Ok(proxy) => { debug!("Spawning proxy task for proxy {}", proxy.url); diff --git a/crates/defguard_proxy_manager/src/servers/enrollment.rs b/crates/defguard_proxy_manager/src/servers/enrollment.rs index 4fea5f0401..bd1bbef18e 100644 --- a/crates/defguard_proxy_manager/src/servers/enrollment.rs +++ b/crates/defguard_proxy_manager/src/servers/enrollment.rs @@ -22,10 +22,9 @@ use defguard_core::{ }, events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, EnrollmentEvent}, grpc::{ - InstanceInfo, + GatewayEvent, InstanceInfo, client_version::ClientFeature, - gateway::events::GatewayEvent, - utils::{build_device_config_response, new_polling_token, parse_client_ip_agent}, + utils::{build_device_config_response, parse_client_ip_agent}, }, handlers::{ mail::{send_email_mfa_activation_email, send_mfa_configured_email}, @@ -1055,6 +1054,45 @@ async fn initial_info_from_user( }) } +// Create a new token for configuration polling. +pub async fn new_polling_token(pool: &PgPool, device: &Device) -> Result { + debug!( + "Making a new polling token for device {}", + device.wireguard_pubkey + ); + let mut transaction = pool.begin().await.map_err(|err| { + error!("Failed to start transaction while making a new polling token: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + + // 1. Delete existing polling token for the device, if it exists + // 2. Create a new polling token for the device + PollingToken::delete_for_device_id(&mut *transaction, device.id) + .await + .map_err(|err| { + error!("Failed to delete polling token: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + let new_token = PollingToken::new(device.id) + .save(&mut *transaction) + .await + .map_err(|err| { + error!("Failed to save new polling token: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + + transaction.commit().await.map_err(|err| { + error!("Failed to commit transaction while making a new polling token: {err}"); + Status::internal(format!("unexpected error: {err}")) + })?; + info!( + "New polling token created for device {}", + device.wireguard_pubkey + ); + + Ok(new_token.token) +} + #[cfg(test)] mod test { use defguard_common::{ diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 8e4c6bca7f..5242bf2a08 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,5 +1,5 @@ use defguard_common::db::Id; -use defguard_core::grpc::gateway::events::GatewayEvent; +use defguard_core::grpc::GatewayEvent; use thiserror::Error; use tokio::sync::{broadcast::error::SendError as BroadcastSendError, mpsc::error::SendError}; diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 58135192a2..cfe8505ee9 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -12,7 +12,7 @@ use defguard_common::{ }, messages::peer_stats_update::PeerStatsUpdate, }; -use defguard_core::grpc::gateway::events::GatewayEvent; +use defguard_core::grpc::GatewayEvent; use sqlx::{PgConnection, PgPool}; use tokio::{ sync::{ diff --git a/crates/defguard_version/Cargo.toml b/crates/defguard_version/Cargo.toml index f05ace0e92..b5e5c00067 100644 --- a/crates/defguard_version/Cargo.toml +++ b/crates/defguard_version/Cargo.toml @@ -10,11 +10,11 @@ rust-version.workspace = true [dependencies] axum.workspace = true http.workspace = true -os_info = "3.12" +os_info.workspace = true semver.workspace = true serde.workspace = true thiserror.workspace = true tonic.workspace = true -tower = "0.5" +tower.workspace = true tracing.workspace = true tracing-subscriber.workspace = true diff --git a/deny.toml b/deny.toml index ae7303ded6..558e412364 100644 --- a/deny.toml +++ b/deny.toml @@ -143,6 +143,14 @@ exceptions = [ "AGPL-3.0-only", "AGPL-3.0-or-later", ], crate = "defguard_event_logger" }, + { allow = [ + "AGPL-3.0-only", + "AGPL-3.0-or-later", + ], crate = "defguard_gateway_manager" }, + { allow = [ + "AGPL-3.0-only", + "AGPL-3.0-or-later", + ], crate = "defguard_grpc_tls" }, { allow = [ "AGPL-3.0-only", "AGPL-3.0-or-later", diff --git a/migrations/20260213090000_[2.0.0]_gateway_certificate_serial.down.sql b/migrations/20260213090000_[2.0.0]_gateway_certificate_serial.down.sql new file mode 100644 index 0000000000..66425c97dc --- /dev/null +++ b/migrations/20260213090000_[2.0.0]_gateway_certificate_serial.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE gateway + DROP COLUMN certificate, + ADD COLUMN has_certificate boolean NOT NULL DEFAULT false; diff --git a/migrations/20260213090000_[2.0.0]_gateway_certificate_serial.up.sql b/migrations/20260213090000_[2.0.0]_gateway_certificate_serial.up.sql new file mode 100644 index 0000000000..b728d7a16c --- /dev/null +++ b/migrations/20260213090000_[2.0.0]_gateway_certificate_serial.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE gateway + DROP COLUMN has_certificate, + ADD COLUMN certificate TEXT; diff --git a/proto b/proto index 8326216b71..faebcc5449 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 8326216b71edc64acf8fe091bb27d690c8d6885f +Subproject commit faebcc5449ae803e15cf5faf838c0c508401caf1 diff --git a/tools/defguard_generator/Cargo.toml b/tools/defguard_generator/Cargo.toml index 6c37f70ba9..06255adc2b 100644 --- a/tools/defguard_generator/Cargo.toml +++ b/tools/defguard_generator/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" license-file = "../../LICENSE.md" homepage = "https://defguard.net/" repository = "https://github.com/DefGuard/defguard" -rust-version = "1.85.1" +rust-version = "1.87.0" [dependencies] defguard_common = { workspace = true }