From 9d89d85b065681c85b31cf1262608c86caa36021 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Thu, 27 Nov 2025 13:12:18 +0100 Subject: [PATCH 01/17] Reverse gRPC communication --- Cargo.lock | 106 +- Cargo.toml | 3 +- crates/defguard_core/src/db/models/gateway.rs | 105 ++ crates/defguard_core/src/db/models/mod.rs | 1 + .../src/db/models/polling_token.rs | 11 +- .../defguard_core/src/db/models/wireguard.rs | 111 ++- .../src/enterprise/firewall/mod.rs | 2 +- .../defguard_core/src/grpc/gateway/handler.rs | 368 +++++++ crates/defguard_core/src/grpc/gateway/mod.rs | 902 ++++++++++-------- .../defguard_core/src/grpc/gateway/tests.rs | 87 ++ crates/defguard_core/src/grpc/mod.rs | 51 +- .../20251125072923_network_gateways.down.sql | 3 + .../20251125072923_network_gateways.up.sql | 20 + proto | 2 +- 14 files changed, 1264 insertions(+), 508 deletions(-) create mode 100644 crates/defguard_core/src/db/models/gateway.rs create mode 100644 crates/defguard_core/src/grpc/gateway/handler.rs create mode 100644 crates/defguard_core/src/grpc/gateway/tests.rs create mode 100644 migrations/20251125072923_network_gateways.down.sql create mode 100644 migrations/20251125072923_network_gateways.up.sql diff --git a/Cargo.lock b/Cargo.lock index a0722d94d6..78a69433db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.46" +version = "1.2.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97463e1064cb1b1c1384ad0a0b9c8abd0988e2a91f52606c80ef14aadb63e36" +checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" dependencies = [ "find-msvc-tools", "jobserver", @@ -669,9 +669,9 @@ checksum = "bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18" [[package]] name = "clap" -version = "4.5.52" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa8120877db0e5c011242f96806ce3c94e0737ab8108532a76a3300a01db2ab8" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", "clap_derive", @@ -679,9 +679,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.52" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02576b399397b659c26064fbc92a75fede9d18ffd5f80ca1cd74ddab167016e1" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ "anstream", "anstyle", @@ -835,9 +835,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" dependencies = [ "crc-catalog", ] @@ -2001,7 +2001,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.12.0", + "indexmap 2.12.1", "slab", "tokio", "tokio-util", @@ -2054,9 +2054,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "hashlink" @@ -2154,12 +2154,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -2484,12 +2483,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "serde", "serde_core", ] @@ -3594,9 +3593,9 @@ checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pest" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "989e7521a040efde50c3ab6bbadafbe15ab6dc042686926be59ac35d74607df4" +checksum = "cbcfd20a6d4eeba40179f05735784ad32bdaef05ce8e8af05f180d45bb3e7e22" dependencies = [ "memchr", "ucd-trie", @@ -3604,9 +3603,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187da9a3030dbafabbbfb20cb323b976dc7b7ce91fcd84f2f74d6e31d378e2de" +checksum = "51f72981ade67b1ca6adc26ec221be9f463f2b5839c7508998daa17c23d94d7f" dependencies = [ "pest", "pest_generator", @@ -3614,9 +3613,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b401d98f5757ebe97a26085998d6c0eecec4995cad6ab7fc30ffdf4b052843" +checksum = "dee9efd8cdb50d719a80088b76f81aec7c41ed6d522ee750178f83883d271625" dependencies = [ "pest", "pest_meta", @@ -3627,9 +3626,9 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72f27a2cfee9f9039c4d86faa5af122a0ac3851441a34865b8a043b46be0065a" +checksum = "bf1d70880e76bdc13ba52eafa6239ce793d85c8e43896507e43dd8984ff05b82" dependencies = [ "pest", "sha2", @@ -3642,7 +3641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.12.0", + "indexmap 2.12.1", ] [[package]] @@ -4351,13 +4350,12 @@ dependencies = [ [[package]] name = "rust-ini" -version = "0.21.1" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e310ef0e1b6eeb79169a1171daf9abcb87a2e17c03bee2c4bb100b55c75409f" +checksum = "796e8d2b6696392a43bea58116b667fb4c29727dc5abd27d6acf338bb4f688c7" dependencies = [ "cfg-if", "ordered-multimap", - "trim-in-place", ] [[package]] @@ -4653,7 +4651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2f2d7ff8a2140333718bb329f5c40fc5f0865b84c426183ce14c97d2ab8154f" dependencies = [ "form_urlencoded", - "indexmap 2.12.0", + "indexmap 2.12.1", "itoa", "ryu", "serde_core", @@ -4725,7 +4723,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.12.0", + "indexmap 2.12.1", "schemars 0.9.0", "schemars 1.1.0", "serde_core", @@ -4752,7 +4750,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "itoa", "ryu", "serde", @@ -5002,7 +5000,7 @@ dependencies = [ "futures-util", "hashbrown 0.15.5", "hashlink", - "indexmap 2.12.0", + "indexmap 2.12.1", "ipnetwork", "log", "memchr", @@ -5320,9 +5318,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.110" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a99801b5bd34ede4cf3fc688c5919368fea4e4814a4664359503e6015b280aea" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -5625,7 +5623,7 @@ version = "0.23.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "toml_datetime", "toml_parser", "winnow", @@ -5744,7 +5742,7 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "indexmap 2.12.0", + "indexmap 2.12.1", "pin-project-lite", "slab", "sync_wrapper", @@ -5757,9 +5755,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" dependencies = [ "bitflags 2.10.0", "bytes", @@ -5809,9 +5807,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -5820,9 +5818,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -5868,12 +5866,6 @@ dependencies = [ "syn", ] -[[package]] -name = "trim-in-place" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343e926fc669bc8cde4fa3129ab681c63671bae288b1f1081ceee6d9d37904fc" - [[package]] name = "try-lock" version = "0.2.5" @@ -6024,7 +6016,7 @@ version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "serde", "serde_json", "utoipa-gen", @@ -6718,9 +6710,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" dependencies = [ "memchr", ] @@ -6809,18 +6801,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.27" +version = "0.8.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +checksum = "4ea879c944afe8a2b25fef16bb4ba234f47c694565e97383b36f3a878219065c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.27" +version = "0.8.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +checksum = "cf955aa904d6040f70dc8e9384444cb1030aed272ba3cb09bbc4ab9e7c1f34f5" dependencies = [ "proc-macro2", "quote", @@ -6910,7 +6902,7 @@ dependencies = [ "arbitrary", "crc32fast", "flate2", - "indexmap 2.12.0", + "indexmap 2.12.1", "memchr", "zopfli", ] diff --git a/Cargo.toml b/Cargo.toml index 673c673ccc..2cb6ad1bc7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,8 +61,7 @@ pulldown-cmark = "0.13" rand = "0.8" reqwest = { version = "0.12", features = ["json"] } rsa = "0.9" -# 0.21.2 causes config parsing errors -rust-ini = "=0.21.1" +rust-ini = "0.21" semver = { version = "1.0", features = ["serde"] } secrecy = { version = "0.10", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } diff --git a/crates/defguard_core/src/db/models/gateway.rs b/crates/defguard_core/src/db/models/gateway.rs new file mode 100644 index 0000000000..5d6221f6b4 --- /dev/null +++ b/crates/defguard_core/src/db/models/gateway.rs @@ -0,0 +1,105 @@ +use std::fmt; + +use chrono::{NaiveDateTime, Utc}; +use model_derive::Model; +use sqlx::{PgExecutor, query, query_as}; + +use defguard_common::db::{Id, NoId}; + +#[derive(Clone, Debug, Deserialize, Model, PartialEq, Serialize)] +pub(crate) struct Gateway { + pub id: I, + pub network_id: Id, + pub url: String, + pub hostname: Option, + pub connected_at: Option, + pub disconnected_at: Option, +} + +impl Gateway { + #[must_use] + pub(crate) fn new>(network_id: Id, url: S) -> Self { + Self { + id: NoId, + network_id, + url: url.into(), + hostname: None, + connected_at: None, + disconnected_at: None, + } + } +} + +impl Gateway { + pub(crate) async fn find_by_network_id<'e, E>( + executor: E, + network_id: Id, + ) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT * FROM gateway WHERE network_id = $1 ORDER BY id", + network_id + ) + .fetch_all(executor) + .await + } + + /// Update `hostname` and set `connected_at` to the current time and save it to the database. + pub(crate) async fn touch_connected<'e, E>( + &mut self, + executor: E, + hostname: String, + ) -> Result<(), sqlx::Error> + where + E: PgExecutor<'e>, + { + self.hostname = Some(hostname); + self.connected_at = Some(Utc::now().naive_utc()); + query!( + "UPDATE gateway SET hostname = $2, connected_at = $3 WHERE id = $1", + self.id, + self.hostname, + self.connected_at + ) + .execute(executor) + .await?; + + Ok(()) + } + + /// Set `disconnected_at` to the current time and save it to the database. + pub(crate) async fn touch_disconnected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + where + E: PgExecutor<'e>, + { + self.disconnected_at = Some(Utc::now().naive_utc()); + query!( + "UPDATE gateway SET disconnected_at = $2 WHERE id = $1", + self.id, + self.disconnected_at + ) + .execute(executor) + .await?; + + Ok(()) + } + + pub(crate) fn is_connected(&self) -> bool { + if let (Some(connected_at), Some(disconnected_at)) = + (self.connected_at, self.disconnected_at) + { + disconnected_at <= connected_at + } else { + self.connected_at.is_some() + } + } +} + +impl fmt::Display for Gateway { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Gateway(ID {}; URL {})", self.id, self.url) + } +} diff --git a/crates/defguard_core/src/db/models/mod.rs b/crates/defguard_core/src/db/models/mod.rs index df2faac41b..0f9061f458 100644 --- a/crates/defguard_core/src/db/models/mod.rs +++ b/crates/defguard_core/src/db/models/mod.rs @@ -1,6 +1,7 @@ pub mod activity_log; pub mod device; pub mod enrollment; +pub mod gateway; pub mod group; pub mod oauth2authorizedapp; pub mod oauth2client; diff --git a/crates/defguard_core/src/db/models/polling_token.rs b/crates/defguard_core/src/db/models/polling_token.rs index b4d911936d..6c23535873 100644 --- a/crates/defguard_core/src/db/models/polling_token.rs +++ b/crates/defguard_core/src/db/models/polling_token.rs @@ -4,7 +4,7 @@ use defguard_common::{ random::gen_alphanumeric, }; use model_derive::Model; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as}; +use sqlx::{PgExecutor, query_as}; // Token used for polling requests. #[derive(Clone, Debug, Model)] @@ -28,18 +28,21 @@ impl PollingToken { } impl PollingToken { - pub async fn find(pool: &PgPool, token: &str) -> Result, SqlxError> { + pub async fn find<'e, E>(executor: E, token: &str) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { query_as!( Self, "SELECT id, token, device_id, created_at \ FROM pollingtoken WHERE token = $1", token ) - .fetch_optional(pool) + .fetch_optional(executor) .await } - pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), SqlxError> + pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), sqlx::Error> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 33c26e4989..2e559dcbf6 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -23,8 +23,8 @@ use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; use model_derive::Model; use rand::rngs::OsRng; use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, - postgres::types::PgInterval, query_as, query_scalar, + FromRow, PgConnection, PgExecutor, PgPool, Type, + postgres::types::PgInterval, query, query_as, query_scalar, }; use thiserror::Error; use tokio::sync::broadcast::Sender; @@ -934,13 +934,15 @@ impl WireguardNetwork { &self, conn: &PgPool, device_id: Id, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { // Find a first handshake gap longer than WIREGUARD_MAX_HANDSHAKE. // We assume that this gap indicates a time when the device was not connected. // So, the handshake after this gap is the moment the last connection was established. - // If no such gap is found, the device may be connected from the beginning, return the first handshake in this case. + // If no such gap is found, the device may be connected from the beginning, return the first + // handshake in this case. let connected_at = query_scalar!( - "WITH stats AS (SELECT * FROM wireguard_peer_stats_view WHERE device_id = $1 AND network = $2) \ + "WITH stats AS \ + (SELECT * FROM wireguard_peer_stats_view WHERE device_id = $1 AND network = $2) \ SELECT \ COALESCE( \ ( \ @@ -964,6 +966,85 @@ impl WireguardNetwork { Ok(connected_at) } + /// Get a list of all allowed peers + /// + /// Each device is marked as allowed or not allowed in a given network, + /// which enables enforcing peer disconnect in MFA-protected networks. + /// + /// If the location is a service location, only returns peers if enterprise features are enabled. + pub async fn get_peers<'e, E>(&self, executor: E) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { + debug!("Fetching all peers for network {}", self.id); + + if self.should_prevent_service_location_usage() { + warn!( + "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", + self.name + ); + return Ok(Vec::new()); + } + + let rows = query!( + "SELECT d.wireguard_pubkey pubkey, preshared_key, \ + -- TODO possible to not use ARRAY-unnest here? + ARRAY( + SELECT host(ip) + FROM unnest(wnd.wireguard_ips) AS ip + ) \"allowed_ips!: Vec\" \ + FROM wireguard_network_device wnd \ + JOIN device d ON wnd.device_id = d.id \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ + AND d.configured = true \ + AND u.is_active = true \ + ORDER BY d.id ASC", + self.id, + self.mfa_enabled() + ) + .fetch_all(executor) + .await?; + + // keepalive has to be added manually because Postgres + // doesn't support unsigned integers + let result = rows + .into_iter() + .map(|row| Peer { + pubkey: row.pubkey, + allowed_ips: row.allowed_ips, + // Don't send preshared key if MFA is not enabled, it can't be used and may + // cause issues with clients connecting if they expect no preshared key + // e.g. when you disable MFA on a location + preshared_key: if self.mfa_enabled() { + row.preshared_key + } else { + None + }, + keepalive_interval: Some(self.keepalive_interval as u32), + }) + .collect(); + + Ok(result) + } + + /// Update `connected_at` to the current time and save it to the database. + pub(crate) async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + where + E: PgExecutor<'e>, + { + self.connected_at = Some(Utc::now().naive_utc()); + query!( + "UPDATE wireguard_network SET connected_at = $2 WHERE name = $1", + self.name, + self.connected_at + ) + .execute(executor) + .await?; + + Ok(()) + } + /// Retrieves stats for specified devices pub(crate) async fn device_stats( &self, @@ -971,7 +1052,7 @@ impl WireguardNetwork { devices: &[Device], from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { if devices.is_empty() { return Ok(Vec::new()); } @@ -1036,7 +1117,7 @@ impl WireguardNetwork { from: &NaiveDateTime, aggregation: &DateTimeAggregation, device_type: DeviceType, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { let oldest_handshake = (Utc::now() - WIREGUARD_MAX_HANDSHAKE).naive_utc(); // Retrieve connected devices from database let devices = query_as!( @@ -1062,7 +1143,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { let mut user_map: HashMap> = HashMap::new(); // Retrieve data series for all active devices and assign them to users let device_stats = self @@ -1076,7 +1157,7 @@ impl WireguardNetwork { for u in user_map { let user = User::find_by_id(conn, u.0) .await? - .ok_or(SqlxError::RowNotFound)?; + .ok_or(sqlx::Error::RowNotFound)?; stats.push(WireguardUserStatsRow { user: UserInfo::from_user(conn, &user).await?, devices: u.1.clone(), @@ -1091,7 +1172,7 @@ impl WireguardNetwork { &self, conn: &PgPool, from: &NaiveDateTime, - ) -> Result { + ) -> Result { let activity_stats = query_as!( WireguardNetworkActivityStats, "SELECT \ @@ -1115,7 +1196,7 @@ impl WireguardNetwork { async fn current_activity( &self, conn: &PgPool, - ) -> Result { + ) -> Result { let from = (Utc::now() - WIREGUARD_MAX_HANDSHAKE).naive_utc(); let activity_stats = query_as!( WireguardNetworkActivityStats, @@ -1143,7 +1224,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { let stats = query_as!( WireguardStatsRow, "SELECT \ @@ -1171,7 +1252,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result { + ) -> Result { let total_activity = self.total_activity(conn, from).await?; let current_activity = self.current_activity(conn).await?; let transfer_series = self.transfer_series(conn, from, aggregation).await?; @@ -1192,7 +1273,7 @@ impl WireguardNetwork { &self, executor: E, device_type: DeviceType, - ) -> Result>, SqlxError> + ) -> Result>, sqlx::Error> where E: PgExecutor<'e>, { @@ -1432,7 +1513,7 @@ pub(crate) async fn networks_stats( conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, -) -> Result { +) -> Result { let total_activity = query_as!( WireguardNetworkActivityStats, "SELECT \ diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index 5e2b7e8d97..44ed70b0a8 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -896,7 +896,7 @@ impl WireguardNetwork { Ok(rules_info) } - /// Prepares firewall configuration for a gateway based on location config and ACLs + /// Prepares firewall configuration for Gateway based on location config and ACLs. /// Returns `None` if firewall management is disabled for a given location. pub async fn try_get_firewall_config( &self, diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs new file mode 100644 index 0000000000..66b9d07aa0 --- /dev/null +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -0,0 +1,368 @@ +use std::{ + str::FromStr, + sync::atomic::{AtomicU64, Ordering}, +}; + +use defguard_common::{auth::claims::Claims, db::Id}; +use defguard_mail::Mail; +use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; +use sqlx::PgPool; +use tokio::{ + sync::mpsc::{self, Sender, UnboundedSender}, + time::sleep, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{ + Code, Status, + transport::{ClientTlsConfig, Endpoint}, +}; + +use crate::{ + ClaimsType, + db::{ + Device, GatewayEvent, WireguardNetwork, + models::{gateway::Gateway, wireguard_peer_stats::WireguardPeerStats}, + }, + grpc::TEN_SECS, + handlers::mail::send_gateway_disconnected_email, +}; + +/// One instance per connected gateway. +pub(super) struct GatewayHandler { + endpoint: Endpoint, + gateway: Gateway, + message_id: AtomicU64, + pool: PgPool, + events_tx: Sender, + mail_tx: UnboundedSender, +} + +impl GatewayHandler { + pub(super) fn new( + gateway: Gateway, + tls_config: Option, + pool: PgPool, + events_tx: Sender, + mail_tx: UnboundedSender, + ) -> Result { + let endpoint = Endpoint::from_shared(gateway.url.to_string())? + .http2_keep_alive_interval(TEN_SECS) + .tcp_keepalive(Some(TEN_SECS)) + .keep_alive_while_idle(true); + let endpoint = if let Some(tls) = tls_config { + endpoint.tls_config(tls)? + } else { + endpoint + }; + + Ok(Self { + endpoint, + gateway, + message_id: AtomicU64::new(0), + pool, + events_tx, + mail_tx, + }) + } + + /// Send network and VPN configuration to Gateway. + async fn send_configuration(&self, tx: &UnboundedSender) -> Result<(), Status> { + debug!("Sending configuration to Gateway"); + let network_id = self.gateway.network_id; + // let hostname = Self::get_gateway_hostname(request.metadata())?; + + let mut conn = self.pool.acquire().await.map_err(|err| { + error!("Failed to acquire DB connection: {err}"); + Status::new( + Code::Internal, + "Failed to acquire database connection".to_string(), + ) + })?; + + let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) + .await + .map_err(|err| { + error!("Network {network_id} not found"); + Status::new(Code::Internal, format!("Failed to retrieve network: {err}")) + })? + .ok_or_else(|| { + Status::new( + Code::Internal, + format!("Network with id {network_id} not found"), + ) + })?; + + debug!( + "Sending configuration to {}, network {network}", + self.gateway + ); + if let Err(err) = network.touch_connected(&mut *conn).await { + error!( + "Failed to update connection time for network {network_id} in the database, \ + status: {err}" + ); + } + + let peers = network.get_peers(&self.pool).await.map_err(|error| { + error!("Failed to fetch peers from the database for network {network_id}: {error}",); + Status::new( + Code::Internal, + format!("Failed to retrieve peers from the database for network: {network_id}"), + ) + })?; + + let maybe_firewall_config = + network + .try_get_firewall_config(&mut *conn) + .await + .map_err(|err| { + error!("Failed to generate firewall config for network {network_id}: {err}"); + Status::new( + Code::Internal, + format!("Failed to generate firewall config for network: {network_id}"), + ) + })?; + let payload = Some(core_response::Payload::Config(super::gen_config( + &network, + peers, + maybe_firewall_config, + ))); + let id = self.message_id.fetch_add(1, Ordering::Relaxed); + let req = CoreResponse { id, payload }; + match tx.send(req) { + Ok(()) => { + info!("Configuration sent to {}, network {network}", self.gateway); + Ok(()) + } + Err(err) => { + error!("Failed to send configuration sent to {}", self.gateway); + Err(Status::new( + Code::Internal, + format!("Configuration not sent to {}, error {err}", self.gateway), + )) + } + } + } + + /// Send gateway disconnected notification. + /// Sends notification only if last notification time is bigger than specified in config. + async fn send_disconnect_notification(&self) { + debug!("Sending gateway disconnect email notification"); + let hostname = self.gateway.hostname.clone(); + let mail_tx = self.mail_tx.clone(); + let pool = self.pool.clone(); + let url = self.gateway.url.clone(); + + let Ok(Some(network)) = + WireguardNetwork::find_by_id(&self.pool, self.gateway.network_id).await + else { + error!( + "Failed to fetch network ID {} from database", + self.gateway.network_id + ); + return; + }; + + // Send email only if disconnection time is before the connection time. + let send_email = if let (Some(connected_at), Some(disconnected_at)) = + (self.gateway.connected_at, self.gateway.disconnected_at) + { + disconnected_at <= connected_at + } else { + true + }; + if send_email { + // FIXME: Try to get rid of spawn and use something like block_on + // To return result instead of logging + tokio::spawn(async move { + if let Err(err) = + send_gateway_disconnected_email(hostname, network.name, &url, &mail_tx, &pool) + .await + { + error!("Failed to send gateway disconnect notification: {err}"); + } else { + info!("Email notification sent about gateway being disconnected"); + } + }); + } else { + info!( + "{} disconnected. Email notification not sent.", + self.gateway + ); + }; + } + + /// Connect to Gateway and handle its messages through gRPC. + pub(super) async fn handle_connection(&mut self) -> ! { + let uri = self.endpoint.uri(); + loop { + #[cfg(not(test))] + let channel = self.endpoint.connect_lazy(); + #[cfg(test)] + let channel = self.endpoint.connect_with_connector_lazy(tower::service_fn( + |_: tonic::transport::Uri| async { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(super::tests::TONIC_SOCKET).await?, + )) + }, + )); + + debug!("Connecting to Gateway {uri}"); + let mut client = gateway_client::GatewayClient::new(channel); + let (tx, rx) = mpsc::unbounded_channel(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + error!("Failed to connect to gateway {uri}, retrying: {err}"); + sleep(TEN_SECS).await; + continue; + } + }; + + info!("Connected to Defguard Gateway {uri}"); + let mut resp_stream = response.into_inner(); + let mut config_sent = false; + + 'message: loop { + match resp_stream.message().await { + Ok(None) => { + info!("stream was closed by the sender"); + break 'message; + } + Ok(Some(received)) => { + info!("Received message from gateway."); + debug!("Message from Gateway {uri}"); + match received.payload { + Some(core_request::Payload::ConfigRequest(config_request)) => { + if config_sent { + warn!( + "Ignoring repeated configuration request from {}", + self.gateway + ); + continue; + } + // Validate authorization token. + if let Ok(claims) = Claims::from_jwt( + ClaimsType::Gateway, + &config_request.auth_token, + ) { + if let Ok(client_id) = Id::from_str(&claims.client_id) { + if client_id == self.gateway.network_id { + debug!( + "Authorization token is correct for {}", + self.gateway + ); + } else { + warn!( + "Authorization token received from {uri} has \ + `client_id` for a different network" + ); + continue; + } + } else { + warn!( + "Authorization token received from {uri} has incorrect \ + `client_id`" + ); + continue; + } + } else { + warn!("Invalid authorization token received from {uri}"); + continue; + } + + // Send network configuration to Gateway. + match self.send_configuration(&tx).await { + Ok(()) => { + info!("Sent configuration to {}", self.gateway); + config_sent = true; + let _ = self + .gateway + .touch_connected(&self.pool, config_request.hostname) + .await; + } + Err(err) => { + error!( + "Failed to send configuration to {}: {err}", + self.gateway + ); + } + } + + // Start observing configuration changes. + let Ok(Some(network)) = WireguardNetwork::find_by_id( + &self.pool, + self.gateway.network_id, + ) + .await + else { + error!( + "Failed to fetch network ID {} from the database", + self.gateway.network_id + ); + continue; + }; + // tokio::spawn(super::handle_events( + // network, + // tx.clone(), + // self.events_tx.subscribe(), + // )); + } + Some(core_request::Payload::PeerStats(peer_stats)) => { + if !config_sent { + warn!( + "Ignoring peer statistics from {} because it didn't \ + authorize itself", + self.gateway + ); + continue; + } + + // let public_key = peer_stats.public_key.clone(); + // let mut stats = WireguardPeerStats::from_peer_stats( + // peer_stats, + // self.gateway.network_id, + + // ); + // // Get device by public key and fill in stats.device_id + // match Device::find_by_pubkey(&self.pool, &public_key).await { + // Ok(Some(device)) => { + // stats.device_id = device.id; + // match stats.save(&self.pool).await { + // Ok(_) => { + // info!("Saved WireGuard peer stats to database.") + // } + // Err(err) => error!( + // "Failed to save WireGuard peer stats to database: \ + // {err}" + // ), + // } + // } + // Ok(None) => { + // error!("Device with public key {public_key} not found"); + // } + // Err(err) => { + // error!( + // "Failed to retrieve device with public key \ + // {public_key}: {err}", + // ); + // } + // }; + } + None => (), + }; + } + Err(err) => { + error!("Disconnected from gateway at {uri}, error: {err}"); + // Important: call this funtion before setting disconnection time. + self.send_disconnect_notification().await; + let _ = self.gateway.touch_disconnected(&self.pool).await; + debug!("Waiting 10s to re-establish the connection"); + sleep(TEN_SECS).await; + break 'message; + } + } + } + } + } +} diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ff119fc0fc..b94db32f13 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,7 +1,7 @@ use std::{ net::{IpAddr, SocketAddr}, pin::Pin, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll}, }; @@ -12,13 +12,13 @@ use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{ - Configuration, ConfigurationRequest, Peer, PeerStats, StatsUpdate, Update, - gateway_service_server, stats_update, update, + Configuration, ConfigurationRequest, CoreResponse, Peer, PeerStats, Update, UpdateType, + core_response, update, }, }; use defguard_version::version_info_from_metadata; use semver::Version; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query}; +use sqlx::PgPool; use thiserror::Error; use tokio::{ sync::{ @@ -41,8 +41,11 @@ use crate::{ }; pub mod client_state; +pub(crate) mod handler; pub mod map; pub(crate) mod state; +#[cfg(test)] +mod tests; const PEER_DISCONNECT_INTERVAL: u64 = 60; @@ -90,70 +93,6 @@ pub struct GatewayServer { grpc_event_tx: UnboundedSender, } -impl WireguardNetwork { - /// Get a list of all allowed peers - /// - /// Each device is marked as allowed or not allowed in a given network, - /// which enables enforcing peer disconnect in MFA-protected networks. - /// - /// If the location is a service location, only returns peers if enterprise features are enabled. - pub async fn get_peers<'e, E>(&self, executor: E) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - debug!("Fetching all peers for network {}", self.id); - - if self.should_prevent_service_location_usage() { - warn!( - "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", - self.name - ); - return Ok(Vec::new()); - } - - let rows = query!( - "SELECT d.wireguard_pubkey pubkey, preshared_key, \ - -- TODO possible to not use ARRAY-unnest here? - ARRAY( - SELECT host(ip) - FROM unnest(wnd.wireguard_ips) AS ip - ) \"allowed_ips!: Vec\" \ - FROM wireguard_network_device wnd \ - JOIN device d ON wnd.device_id = d.id \ - JOIN \"user\" u ON d.user_id = u.id \ - WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ - AND d.configured = true \ - AND u.is_active = true \ - ORDER BY d.id ASC", - self.id, - self.mfa_enabled() - ) - .fetch_all(executor) - .await?; - - // keepalive has to be added manually because Postgres - // doesn't support unsigned integers - let result = rows - .into_iter() - .map(|row| Peer { - pubkey: row.pubkey, - allowed_ips: row.allowed_ips, - // Don't send preshared key if MFA is not enabled, it can't be used and may - // cause issues with clients connecting if they expect no preshared key - // e.g. when you disable MFA on a location - preshared_key: if self.mfa_enabled() { - row.preshared_key - } else { - None - }, - keepalive_interval: Some(self.keepalive_interval as u32), - }) - .collect(); - - Ok(result) - } -} - /// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. struct GatewayMetadata { network_id: Id, @@ -224,9 +163,7 @@ impl GatewayServer { } } - pub fn get_client_state_guard( - &self, - ) -> Result, GatewayServerError> { + pub fn get_client_state_guard(&self) -> Result, GatewayServerError> { let client_state = self .client_state .lock() @@ -354,6 +291,167 @@ impl WireguardPeerStats { } } +/* + +/// Process received Gateway events +/// +/// Main gRPC server uses a shared channel for broadcasting all Gateway events, +/// so the handler must determine if an event is relevant for the network being serviced. +async fn handle_events( + mut current_network: WireguardNetwork, + tx: UnboundedSender, + mut events_rx: Receiver, +) { + info!("Starting update stream network {current_network}"); + while let Some(event) = events_rx.recv().await { + debug!("Received networking state update event: {event:?}"); + let (update_type, update) = match event { + GatewayEvent::NetworkCreated(network, _fixme) => { + if network.id != current_network.id { + continue; + } + ( + UpdateType::Create, + update::Update::Network(Configuration { + name: network.name.clone(), + prvkey: network.prvkey.clone(), + addresses: network.address.to_string(), + port: network.port as u32, + peers: Vec::new(), + }), + ) + } + GatewayEvent::NetworkModified(network, peers, _fixme) => { + if network.id != current_network.id { + continue; + } + // update stored network data + current_network = network.clone(); + ( + UpdateType::Modify, + update::Update::Network(Configuration { + name: network.name, + prvkey: network.prvkey, + addresses: network.address.to_string(), + port: network.port as u32, + peers, + }), + ) + } + GatewayEvent::NetworkDeleted(network_id, network_name) => { + if network_id != current_network.id { + continue; + } + ( + UpdateType::Delete, + update::Update::Network(Configuration { + name: network_name.to_string(), + prvkey: String::new(), + addresses: Vec::new(), + port: 0, + peers: Vec::new(), + firewall_config: None, + }), + ) + } + GatewayEvent::DeviceCreated(device) => { + // check if a peer has to be added in the current network + match device + .network_info + .iter() + .find(|info| info.network_id == current_network.id) + { + Some(network_info) => { + if current_network.mfa_enabled && !network_info.is_authorized { + debug!( + "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", + device.device.name, current_network.name + ); + continue; + }; + let peer = Peer { + pubkey: device.device.wireguard_pubkey, + allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + preshared_key: network_info.preshared_key.clone(), + keepalive_interval: Some(current_network.keepalive_interval as u32), + }; + (UpdateType::Create, update::Update::Peer(peer)) + } + None => continue, + } + } + GatewayEvent::DeviceModified(device) => { + // check if a peer has to be updated in the current network + match device + .network_info + .iter() + .find(|info| info.network_id == current_network.id) + { + Some(network_info) => { + if current_network.mfa_enabled && !network_info.is_authorized { + debug!( + "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", + device.device.name, current_network.name + ); + continue; + }; + let peer = Peer { + pubkey: device.device.wireguard_pubkey, + allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + preshared_key: network_info.preshared_key.clone(), + keepalive_interval: Some(current_network.keepalive_interval as u32), + }; + (UpdateType::Modify, update::Update::Peer(peer)) + } + None => continue, + } + } + GatewayEvent::DeviceDeleted(device) => { + // check if a peer has to be updated in the current network + match device + .network_info + .iter() + .find(|info| info.network_id == current_network.id) + { + Some(_) => ( + UpdateType::Delete, + update::Update::Peer(Peer { + pubkey: device.device.wireguard_pubkey, + allowed_ips: Vec::new(), + preshared_key: None, + keepalive_interval: None, + }), + ), + None => continue, + } + } + GatewayEvent::FirewallConfigChanged(_fixme, _) => (), + GatewayEvent::FirewallDisabled(_id) => (), + }; + + let req = CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { + update_type: update_type as i32, + update: Some(update), + })), + }; + if let Err(err) = tx.send(req) { + error!( + "Failed to send network update, network {current_network}, update type: {}, error: \ + {err}", + update_type.as_str_name() + ); + break; + } + debug!( + "Network update sent for network {current_network}, update type: {}", + update_type.as_str_name() + ); + } +} +*/ + /// Helper struct for handling gateway events struct GatewayUpdatesHandler { network_id: Id, @@ -751,334 +849,334 @@ impl Drop for GatewayUpdatesStream { } } -#[tonic::async_trait] -impl gateway_service_server::GatewayService for GatewayServer { - type UpdatesStream = GatewayUpdatesStream; - - /// Retrieve stats from gateway and save it to database - async fn stats( - &self, - request: Request>, - ) -> Result, Status> { - let GatewayMetadata { - network_id, - hostname, - .. - } = Self::extract_metadata(request.metadata())?; - let mut stream = request.into_inner(); - let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - loop { - // Wait for a message or update client map at least once a mninute, if no messages are - // received. - let stats_update = tokio::select! { - message = stream.message() => { - match message? { - Some(update) => update, - None => break, // Stream ended - } - } - _ = disconnect_timer.tick() => { - debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ - Updating disconnected VPN clients"); - // fetch location to get current peer disconnect threshold - let location = self.fetch_location_from_db(network_id).await?; - - // perform client state operations in a dedicated block to drop mutex guard - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.get_client_state_guard()?; - - // disconnect inactive clients - client_map.disconnect_inactive_vpn_clients_for_location(&location - )? - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - })?; - }; - continue; - } - }; - - debug!("Received stats message: {stats_update:?}"); - let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { - debug!("Received stats message is empty, skipping."); - continue; - }; - let public_key = peer_stats.public_key.clone(); - - // fetch device from DB - // TODO: fetch only when device has changed and use client state otherwise - let device = match self.fetch_device_from_db(&public_key).await? { - Some(device) => device, - None => { - warn!( - "Received stats update for a device which does not exist: {public_key}, skipping." - ); - continue; - } - }; - - // copy device ID for easier reference later - let device_id = device.id; - - // fetch user and location from DB for activity log - // TODO: cache usernames since they don't change - let user = self.fetch_user_from_db(device.user_id, &public_key).await?; - let location = self.fetch_location_from_db(network_id).await?; - - // convert stats to DB storage format - let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); - - // only perform client state update if stats include an endpoint IP - // otherwise a peer was added to the gateway interface - // but has not connected yet - if let Some(endpoint) = &stats.endpoint { - // parse client endpoint IP - let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { - error!("Failed to parse VPN client endpoint: {err}"); - Status::new( - Code::Internal, - format!("Failed to parse VPN client endpoint: {err}"), - ) - })?; - - // perform client state operations in a dedicated block to drop mutex guard - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.get_client_state_guard()?; - - // update connected clients map - match client_map.get_vpn_client(network_id, &public_key) { - Some(client_state) => { - // update connected client state - client_state.update_client_state( - device, - socket_addr, - stats.latest_handshake, - stats.upload, - stats.download, - ); - } - None => { - // don't mark inactive peers as connected - if (Utc::now().naive_utc() - stats.latest_handshake) - < TimeDelta::seconds(location.peer_disconnect_threshold.into()) - { - // mark new VPN client as connected - client_map.connect_vpn_client( - network_id, - &hostname, - &public_key, - &device, - &user, - socket_addr, - &stats, - )?; - - // emit connection event - let context = GrpcRequestContext::new( - user.id, - user.username.clone(), - socket_addr.ip(), - device.id, - device.name.clone(), - location.clone(), - ); - self.emit_event(GrpcEvent::ClientConnected { - context, - location: location.clone(), - device: device.clone(), - })?; - } - } - } - - // disconnect inactive clients - client_map.disconnect_inactive_vpn_clients_for_location(&location)? - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - })?; - } - } - - // Save stats to db - let stats = match stats.save(&self.pool).await { - Ok(stats) => stats, - Err(err) => { - error!("Saving WireGuard peer stats to db failed: {err}"); - return Err(Status::new( - Code::Internal, - format!("Saving WireGuard peer stats to db failed: {err}"), - )); - } - }; - info!("Saved WireGuard peer stats to db."); - debug!("WireGuard peer stats: {stats:?}"); - } - - Ok(Response::new(())) - } - - async fn config( - &self, - request: Request, - ) -> Result, Status> { - debug!("Sending configuration to gateway client."); - let GatewayMetadata { - network_id, - hostname, - version, - .. - // info, - } = Self::extract_metadata(request.metadata())?; - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - - let mut conn = self.pool.acquire().await.map_err(|e| { - error!("Failed to acquire DB connection: {e}"); - Status::new( - Code::Internal, - "Failed to acquire DB connection".to_string(), - ) - })?; - - let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) - .await - .map_err(|e| { - error!("Network {network_id} not found"); - Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) - })? - .ok_or_else(|| { - Status::new( - Code::Internal, - format!("Network with id {network_id} not found"), - ) - })?; - - debug!("Sending configuration to gateway client, network {network}."); - - // store connected gateway in memory - { - let mut state = self.gateway_state.lock().unwrap(); - state.add_gateway( - network_id, - &network.name, - hostname, - request.into_inner().name, - self.mail_tx.clone(), - version, - ); - } - - network.connected_at = Some(Utc::now().naive_utc()); - if let Err(err) = network.save(&mut *conn).await { - error!("Failed to save updated network {network_id} in the database, status: {err}"); - } - - let peers = network.get_peers(&mut *conn).await.map_err(|error| { - error!("Failed to fetch peers from the database for network {network_id}: {error}",); - Status::new( - Code::Internal, - format!("Failed to retrieve peers from the database for network: {network_id}"), - ) - })?; - let maybe_firewall_config = - network - .try_get_firewall_config(&mut conn) - .await - .map_err(|err| { - error!("Failed to generate firewall config for network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to generate firewall config for network: {network_id}"), - ) - })?; - - info!("Configuration sent to gateway client, network {network}."); - - Ok(Response::new(gen_config( - &network, - peers, - maybe_firewall_config, - ))) - } - - async fn updates(&self, request: Request<()>) -> Result, Status> { - let GatewayMetadata { - network_id, - hostname, - .. - // info, - } = Self::extract_metadata(request.metadata())?; - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - - let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) - .await - .map_err(|_| { - error!("Failed to fetch network {network_id} from the database"); - Status::new( - Code::Internal, - format!("Failed to retrieve network {network_id} from the database"), - ) - })? - else { - return Err(Status::new( - Code::Internal, - format!("Network with id {network_id} not found"), - )); - }; - - info!("New client connected to updates stream: {hostname}, network {network}",); - - let (tx, rx) = mpsc::channel(4); - let events_rx = self.wireguard_tx.subscribe(); - let mut state = self.gateway_state.lock().unwrap(); - state - .connect_gateway(network_id, &hostname, &self.pool) - .map_err(|err| { - error!("Failed to connect gateway on network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to connect gateway on network {network_id}"), - ) - })?; - - // clone here before moving into a closure - let gateway_hostname = hostname.clone(); - let handle = tokio::spawn(async move { - let mut update_handler = - GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); - update_handler.run().await; - }); - - Ok(Response::new(GatewayUpdatesStream::new( - handle, - rx, - network_id, - hostname, - Arc::clone(&self.gateway_state), - self.pool.clone(), - ))) - } -} +// #[tonic::async_trait] +// impl gateway_service_server::GatewayService for GatewayServer { +// type UpdatesStream = GatewayUpdatesStream; + +// /// Retrieve stats from gateway and save it to database +// async fn stats( +// &self, +// request: Request>, +// ) -> Result, Status> { +// let GatewayMetadata { +// network_id, +// hostname, +// .. +// } = Self::extract_metadata(request.metadata())?; +// let mut stream = request.into_inner(); +// let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); +// loop { +// // Wait for a message or update client map at least once a mninute, if no messages are +// // received. +// let stats_update = tokio::select! { +// message = stream.message() => { +// match message? { +// Some(update) => update, +// None => break, // Stream ended +// } +// } +// _ = disconnect_timer.tick() => { +// debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ +// Updating disconnected VPN clients"); +// // fetch location to get current peer disconnect threshold +// let location = self.fetch_location_from_db(network_id).await?; + +// // perform client state operations in a dedicated block to drop mutex guard +// let disconnected_clients = { +// // acquire lock on client state map +// let mut client_map = self.get_client_state_guard()?; + +// // disconnect inactive clients +// client_map.disconnect_inactive_vpn_clients_for_location(&location +// )? +// }; + +// // emit client disconnect events +// for (device, context) in disconnected_clients { +// self.emit_event(GrpcEvent::ClientDisconnected { +// context, +// location: location.clone(), +// device, +// })?; +// }; +// continue; +// } +// }; + +// debug!("Received stats message: {stats_update:?}"); +// let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { +// debug!("Received stats message is empty, skipping."); +// continue; +// }; +// let public_key = peer_stats.public_key.clone(); + +// // fetch device from DB +// // TODO: fetch only when device has changed and use client state otherwise +// let device = match self.fetch_device_from_db(&public_key).await? { +// Some(device) => device, +// None => { +// warn!( +// "Received stats update for a device which does not exist: {public_key}, skipping." +// ); +// continue; +// } +// }; + +// // copy device ID for easier reference later +// let device_id = device.id; + +// // fetch user and location from DB for activity log +// // TODO: cache usernames since they don't change +// let user = self.fetch_user_from_db(device.user_id, &public_key).await?; +// let location = self.fetch_location_from_db(network_id).await?; + +// // convert stats to DB storage format +// let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); + +// // only perform client state update if stats include an endpoint IP +// // otherwise a peer was added to the gateway interface +// // but has not connected yet +// if let Some(endpoint) = &stats.endpoint { +// // parse client endpoint IP +// let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { +// error!("Failed to parse VPN client endpoint: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to parse VPN client endpoint: {err}"), +// ) +// })?; + +// // perform client state operations in a dedicated block to drop mutex guard +// let disconnected_clients = { +// // acquire lock on client state map +// let mut client_map = self.get_client_state_guard()?; + +// // update connected clients map +// match client_map.get_vpn_client(network_id, &public_key) { +// Some(client_state) => { +// // update connected client state +// client_state.update_client_state( +// device, +// socket_addr, +// stats.latest_handshake, +// stats.upload, +// stats.download, +// ); +// } +// None => { +// // don't mark inactive peers as connected +// if (Utc::now().naive_utc() - stats.latest_handshake) +// < TimeDelta::seconds(location.peer_disconnect_threshold.into()) +// { +// // mark new VPN client as connected +// client_map.connect_vpn_client( +// network_id, +// &hostname, +// &public_key, +// &device, +// &user, +// socket_addr, +// &stats, +// )?; + +// // emit connection event +// let context = GrpcRequestContext::new( +// user.id, +// user.username.clone(), +// socket_addr.ip(), +// device.id, +// device.name.clone(), +// location.clone(), +// ); +// self.emit_event(GrpcEvent::ClientConnected { +// context, +// location: location.clone(), +// device: device.clone(), +// })?; +// } +// } +// } + +// // disconnect inactive clients +// client_map.disconnect_inactive_vpn_clients_for_location(&location)? +// }; + +// // emit client disconnect events +// for (device, context) in disconnected_clients { +// self.emit_event(GrpcEvent::ClientDisconnected { +// context, +// location: location.clone(), +// device, +// })?; +// } +// } + +// // Save stats to db +// let stats = match stats.save(&self.pool).await { +// Ok(stats) => stats, +// Err(err) => { +// error!("Saving WireGuard peer stats to db failed: {err}"); +// return Err(Status::new( +// Code::Internal, +// format!("Saving WireGuard peer stats to db failed: {err}"), +// )); +// } +// }; +// info!("Saved WireGuard peer stats to db."); +// debug!("WireGuard peer stats: {stats:?}"); +// } + +// Ok(Response::new(())) +// } + +// async fn config( +// &self, +// request: Request, +// ) -> Result, Status> { +// debug!("Sending configuration to gateway client."); +// let GatewayMetadata { +// network_id, +// hostname, +// version, +// .. +// // info, +// } = Self::extract_metadata(request.metadata())?; +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); + +// let mut conn = self.pool.acquire().await.map_err(|e| { +// error!("Failed to acquire DB connection: {e}"); +// Status::new( +// Code::Internal, +// "Failed to acquire DB connection".to_string(), +// ) +// })?; + +// let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) +// .await +// .map_err(|e| { +// error!("Network {network_id} not found"); +// Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) +// })? +// .ok_or_else(|| { +// Status::new( +// Code::Internal, +// format!("Network with id {network_id} not found"), +// ) +// })?; + +// debug!("Sending configuration to gateway client, network {network}."); + +// // store connected gateway in memory +// { +// let mut state = self.gateway_state.lock().unwrap(); +// state.add_gateway( +// network_id, +// &network.name, +// hostname, +// request.into_inner().name, +// self.mail_tx.clone(), +// version, +// ); +// } + +// network.connected_at = Some(Utc::now().naive_utc()); +// if let Err(err) = network.save(&mut *conn).await { +// error!("Failed to save updated network {network_id} in the database, status: {err}"); +// } + +// let peers = network.get_peers(&mut *conn).await.map_err(|error| { +// error!("Failed to fetch peers from the database for network {network_id}: {error}",); +// Status::new( +// Code::Internal, +// format!("Failed to retrieve peers from the database for network: {network_id}"), +// ) +// })?; +// let maybe_firewall_config = +// network +// .try_get_firewall_config(&mut conn) +// .await +// .map_err(|err| { +// error!("Failed to generate firewall config for network {network_id}: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to generate firewall config for network: {network_id}"), +// ) +// })?; + +// info!("Configuration sent to gateway client, network {network}."); + +// Ok(Response::new(gen_config( +// &network, +// peers, +// maybe_firewall_config, +// ))) +// } + +// async fn updates(&self, request: Request<()>) -> Result, Status> { +// let GatewayMetadata { +// network_id, +// hostname, +// .. +// // info, +// } = Self::extract_metadata(request.metadata())?; +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); + +// let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) +// .await +// .map_err(|_| { +// error!("Failed to fetch network {network_id} from the database"); +// Status::new( +// Code::Internal, +// format!("Failed to retrieve network {network_id} from the database"), +// ) +// })? +// else { +// return Err(Status::new( +// Code::Internal, +// format!("Network with id {network_id} not found"), +// )); +// }; + +// info!("New client connected to updates stream: {hostname}, network {network}",); + +// let (tx, rx) = mpsc::channel(4); +// let events_rx = self.wireguard_tx.subscribe(); +// let mut state = self.gateway_state.lock().unwrap(); +// state +// .connect_gateway(network_id, &hostname, &self.pool) +// .map_err(|err| { +// error!("Failed to connect gateway on network {network_id}: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to connect gateway on network {network_id}"), +// ) +// })?; + +// // clone here before moving into a closure +// let gateway_hostname = hostname.clone(); +// let handle = tokio::spawn(async move { +// let mut update_handler = +// GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); +// update_handler.run().await; +// }); + +// Ok(Response::new(GatewayUpdatesStream::new( +// handle, +// rx, +// network_id, +// hostname, +// Arc::clone(&self.gateway_state), +// self.pool.clone(), +// ))) +// } +// } diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs new file mode 100644 index 0000000000..18174116e6 --- /dev/null +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -0,0 +1,87 @@ +use std::{ + io, + net::{IpAddr, Ipv4Addr}, +}; + +use ipnetwork::IpNetwork; +use tokio::{ + net::UnixListener, + sync::{broadcast, mpsc::unbounded_channel}, +}; +use tokio_stream::wrappers::UnixListenerStream; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +use super::*; + +pub(super) static TONIC_SOCKET: &str = "tonic.sock"; + +struct FakeGateway; + +#[tonic::async_trait] +impl gateway_server::Gateway for FakeGateway { + type BidiStream = UnboundedReceiverStream>; + + async fn bidi( + &self, + request: Request>, + ) -> Result, Status> { + let (_tx, rx) = mpsc::unbounded_channel(); + let mut stream = request.into_inner(); + tokio::spawn(async move { + loop { + match stream.message().await { + Ok(Some(_response)) => (), + Ok(None) => (), + Err(_err) => (), + } + } + }); + + Ok(Response::new(UnboundedReceiverStream::new(rx))) + } +} + +async fn fake_gateway() -> Result<(), io::Error> { + let gateway = FakeGateway {}; + + let uds = UnixListener::bind(TONIC_SOCKET)?; + let uds_stream = UnixListenerStream::new(uds); + + Server::builder() + .add_service(gateway_server::GatewayServer::new(gateway)) + .serve_with_incoming(uds_stream) + .await + .unwrap(); + + Ok(()) +} + +#[sqlx::test] +async fn test_gateway(pool: PgPool) { + let network = WireguardNetwork::new( + "TestNet".to_string(), + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap(), + 50051, + "0.0.0.0".to_string(), + None, + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], + false, + 0, + 0, + ) + .save(&pool) + .await + .unwrap(); + let gateway = Gateway::new(network.id, "http://[::]:50051") + .save(&pool) + .await + .unwrap(); + let (events_tx, _events_rx) = broadcast::channel::(16); + let (mail_tx, _mail_rx) = unbounded_channel::(); + + let mut gateway_handler = GatewayHandler::new(gateway, None, pool, events_tx, mail_tx).unwrap(); + let handle = tokio::spawn(async move { + gateway_handler.handle_connection().await; + }); + handle.abort(); +} diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index a4c4ba3dcf..f2e98d6ecf 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -90,7 +90,6 @@ pub mod proto { use defguard_proto::{ auth::auth_service_server::AuthServiceServer, - gateway::gateway_service_server::GatewayServiceServer, proxy::{ AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, core_request, core_response, proxy_client::ProxyClient, @@ -734,31 +733,31 @@ pub async fn build_grpc_service_router( .add_service(health_service) .add_service(auth_service); - let router = { - use crate::version::GatewayVersionInterceptor; - - let gateway_service = GatewayServiceServer::new(GatewayServer::new( - pool, - gateway_state, - client_state, - wireguard_tx, - mail_tx, - grpc_event_tx, - )); - - let own_version = Version::parse(VERSION)?; - router.add_service( - ServiceBuilder::new() - .layer(tonic::service::InterceptorLayer::new(JwtInterceptor::new( - ClaimsType::Gateway, - ))) - .layer(tonic::service::InterceptorLayer::new( - GatewayVersionInterceptor::new(MIN_GATEWAY_VERSION, incompatible_components), - )) - .layer(DefguardVersionLayer::new(own_version)) - .service(gateway_service), - ) - }; + // let router = { + // use crate::version::GatewayVersionInterceptor; + + // let gateway_service = GatewayServiceServer::new(GatewayServer::new( + // pool, + // gateway_state, + // client_state, + // wireguard_tx, + // mail_tx, + // grpc_event_tx, + // )); + + // let own_version = Version::parse(VERSION)?; + // router.add_service( + // ServiceBuilder::new() + // .layer(tonic::service::InterceptorLayer::new(JwtInterceptor::new( + // ClaimsType::Gateway, + // ))) + // .layer(tonic::service::InterceptorLayer::new( + // GatewayVersionInterceptor::new(MIN_GATEWAY_VERSION, incompatible_components), + // )) + // .layer(DefguardVersionLayer::new(own_version)) + // .service(gateway_service), + // ) + // }; let router = router.add_service(worker_service); diff --git a/migrations/20251125072923_network_gateways.down.sql b/migrations/20251125072923_network_gateways.down.sql new file mode 100644 index 0000000000..5e727c02c8 --- /dev/null +++ b/migrations/20251125072923_network_gateways.down.sql @@ -0,0 +1,3 @@ +DROP TRIGGER gateway ON gateway; +DROP FUNCTION row_change(); +DROP TABLE gateway; diff --git a/migrations/20251125072923_network_gateways.up.sql b/migrations/20251125072923_network_gateways.up.sql new file mode 100644 index 0000000000..3db149fd6e --- /dev/null +++ b/migrations/20251125072923_network_gateways.up.sql @@ -0,0 +1,20 @@ +CREATE TABLE gateway ( + id bigserial PRIMARY KEY, + network_id bigint NOT NULL, + url text NOT NULL, + hostname text NULL, + connected_at timestamp without time zone NULL, + disconnected_at timestamp without time zone NULL, + FOREIGN KEY(network_id) REFERENCES wireguard_network(id) +); +CREATE FUNCTION row_change() RETURNS trigger AS $$ +BEGIN + PERFORM pg_notify(TG_TABLE_NAME || '_change', + json_build_object('operation', TG_OP, 'old', row_to_json(OLD), 'new', row_to_json(NEW))::text + ); + RETURN NULL; +END; +$$ LANGUAGE plpgsql; +CREATE TRIGGER gateway + AFTER INSERT OR UPDATE OR DELETE ON gateway + FOR ROW EXECUTE FUNCTION row_change(); diff --git a/proto b/proto index 74d60d9171..d8a8d1b27f 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 74d60d9171048ba0ccaf8a21b05950fb7a673f09 +Subproject commit d8a8d1b27fe38f1bd71241971c90ed3852f06d5b From dcdc0f2606614dd231f809560e3a641bac0761b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 28 Nov 2025 11:25:17 +0100 Subject: [PATCH 02/17] Let it build --- Cargo.lock | 40 +++++++++---------- .../defguard_core/src/grpc/gateway/handler.rs | 4 +- crates/defguard_core/src/grpc/gateway/mod.rs | 5 ++- .../defguard_core/src/grpc/gateway/tests.rs | 4 +- crates/defguard_core/src/grpc/mod.rs | 24 +++++------ .../integration/grpc/common/mock_gateway.rs | 4 +- .../tests/integration/grpc/common/mod.rs | 12 +++--- .../tests/integration/grpc/gateway.rs | 2 +- .../defguard_core/tests/integration/main.rs | 2 +- 9 files changed, 48 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78a69433db..5703ef2bc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2569,9 +2569,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -4715,9 +4715,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.16.0" +version = "3.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10574371d41b0d9b2cff89418eda27da52bcaff2cc8741db26382a77c29131f1" +checksum = "4fa237f2807440d238e0364a218270b98f767a00d3dada77b1c53ae88940e2e7" dependencies = [ "base64 0.22.1", "chrono", @@ -4734,9 +4734,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.16.0" +version = "3.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08a72d8216842fdd57820dc78d840bef99248e35fb2554ff923319e60f2d686b" +checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -5795,9 +5795,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "log", "pin-project-lite", @@ -6171,9 +6171,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", @@ -6184,9 +6184,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" dependencies = [ "cfg-if", "js-sys", @@ -6197,9 +6197,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6207,9 +6207,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ "bumpalo", "proc-macro2", @@ -6220,9 +6220,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] @@ -6242,9 +6242,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 66b9d07aa0..497ac064a6 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -27,7 +27,7 @@ use crate::{ handlers::mail::send_gateway_disconnected_email, }; -/// One instance per connected gateway. +/// One instance per connected Gateway. pub(super) struct GatewayHandler { endpoint: Endpoint, gateway: Gateway, @@ -202,7 +202,7 @@ impl GatewayHandler { let channel = self.endpoint.connect_with_connector_lazy(tower::service_fn( |_: tonic::transport::Uri| async { Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(super::tests::TONIC_SOCKET).await?, + tokio::net::UnixStream::connect(super::TONIC_SOCKET).await?, )) }, )); diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index b94db32f13..8f1f5ec669 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -44,10 +44,11 @@ pub mod client_state; pub(crate) mod handler; pub mod map; pub(crate) mod state; -#[cfg(test)] -mod tests; +// #[cfg(test)] +// mod tests; const PEER_DISCONNECT_INTERVAL: u64 = 60; +static TONIC_SOCKET: &str = "tonic.sock"; /// Sends given `GatewayEvent` to be handled by gateway GRPC server /// diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs index 18174116e6..f79b77dba7 100644 --- a/crates/defguard_core/src/grpc/gateway/tests.rs +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -9,12 +9,10 @@ use tokio::{ sync::{broadcast, mpsc::unbounded_channel}, }; use tokio_stream::wrappers::UnixListenerStream; -use tonic::{transport::Server, Request, Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming, transport::Server}; use super::*; -pub(super) static TONIC_SOCKET: &str = "tonic.sock"; - struct FakeGateway; #[tonic::async_trait] diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index f2e98d6ecf..eedad0a010 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -680,13 +680,13 @@ pub async fn run_grpc_server( server, pool, worker_state, - gateway_state, - client_state, - wireguard_tx, - mail_tx, + // gateway_state, + // client_state, + // wireguard_tx, + // mail_tx, failed_logins, - grpc_event_tx, - incompatible_components, + // grpc_event_tx, + // incompatible_components, ) .await?; @@ -707,13 +707,13 @@ pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, + // gateway_state: Arc>, + // client_state: Arc>, + // wireguard_tx: Sender, + // mail_tx: UnboundedSender, failed_logins: Arc>, - grpc_event_tx: UnboundedSender, - incompatible_components: Arc>, + // grpc_event_tx: UnboundedSender, + // incompatible_components: Arc>, ) -> Result { let auth_service = AuthServiceServer::new(AuthServer::new(pool.clone(), failed_logins)); diff --git a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs b/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs index 6440f5e8f9..11bcdafbfd 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs @@ -2,8 +2,8 @@ use std::time::Duration; use defguard_core::grpc::{AUTHORIZATION_HEADER, HOSTNAME_HEADER}; use defguard_proto::gateway::{ - Configuration, ConfigurationRequest, StatsUpdate, Update, - gateway_service_client::GatewayServiceClient, + Configuration, ConfigurationRequest, Update, + }; use defguard_version::{Version, client::ClientVersionInterceptor}; use tokio::{ diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 96609dbfa7..b919afcd44 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -28,7 +28,7 @@ use tower::service_fn; use crate::common::{init_config, initialize_users}; -pub mod mock_gateway; +// pub mod mock_gateway; pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, @@ -156,13 +156,13 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { server, pool.clone(), worker_state, - gateway_state.clone(), - client_state.clone(), - wg_tx.clone(), + // gateway_state.clone(), + // client_state.clone(), + // wg_tx.clone(), mail_tx, failed_logins, - grpc_event_tx, - Default::default(), + // grpc_event_tx, + // Default::default(), ) .await .unwrap(); diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index d27fca1e72..75c0c83398 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -21,7 +21,7 @@ use defguard_core::{ }; use defguard_proto::{ enterprise::firewall::FirewallPolicy, - gateway::{Configuration, PeerStats, StatsUpdate, Update, stats_update::Payload, update}, + gateway::{Configuration, PeerStats, Update, stats_update::Payload, update}, }; use semver::Version; use sqlx::{ diff --git a/crates/defguard_core/tests/integration/main.rs b/crates/defguard_core/tests/integration/main.rs index f85d8d0fa3..b3793ede28 100644 --- a/crates/defguard_core/tests/integration/main.rs +++ b/crates/defguard_core/tests/integration/main.rs @@ -1,3 +1,3 @@ mod api; mod common; -mod grpc; +// mod grpc; From 9af11aa77cb89b695d9fc9bb20bef2acdeab5611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 28 Nov 2025 13:31:44 +0100 Subject: [PATCH 03/17] Database trigger --- crates/defguard/src/main.rs | 7 +- crates/defguard_common/src/config.rs | 26 +++- crates/defguard_common/src/db/mod.rs | 15 +++ .../defguard_core/src/db/models/wireguard.rs | 4 +- .../defguard_core/src/grpc/gateway/handler.rs | 11 +- crates/defguard_core/src/grpc/gateway/mod.rs | 2 +- .../defguard_core/src/grpc/gateway/state.rs | 16 ++- crates/defguard_core/src/grpc/mod.rs | 115 ++++++++++++++++-- 8 files changed, 177 insertions(+), 19 deletions(-) diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 3c7576a2ee..5188cfd41e 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -24,7 +24,7 @@ use defguard_core::{ grpc::{ WorkerState, gateway::{client_state::ClientMap, map::GatewayMap}, - run_grpc_bidi_stream, run_grpc_server, + run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, @@ -153,6 +153,11 @@ async fn main() -> Result<(), anyhow::Error> { // run services tokio::select! { + res = run_grpc_gateway_stream( + pool.clone(), + wireguard_tx.clone(), + mail_tx.clone() + ) => error!("Gateway gRPC stream returned early: {res:?}"), res = run_grpc_bidi_stream( pool.clone(), wireguard_tx.clone(), diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index 2549ce610b..97f8d59b28 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -1,4 +1,4 @@ -use std::{net::IpAddr, sync::OnceLock}; +use std::{fs::read_to_string, io, net::IpAddr, sync::OnceLock}; use clap::{Args, Parser, Subcommand}; use humantime::Duration; @@ -13,6 +13,7 @@ use rsa::{ }; use secrecy::{ExposeSecret, SecretString}; use serde::Serialize; +use tonic::transport::{Certificate, ClientTlsConfig, Identity}; pub static SERVER_CONFIG: OnceLock = OnceLock::new(); @@ -65,9 +66,11 @@ pub struct DefGuardConfig { #[arg(long, env = "DEFGUARD_GRPC_PORT", default_value_t = 50055)] pub grpc_port: u16, + // Certificate authority (CA), certificate, and key for gRPC communication over HTTPS. + #[arg(long, env = "DEFGUARD_GRPC_CA")] + pub grpc_ca: Option, #[arg(long, env = "DEFGUARD_GRPC_CERT")] pub grpc_cert: Option, - #[arg(long, env = "DEFGUARD_GRPC_KEY")] pub grpc_key: Option, @@ -298,6 +301,25 @@ impl DefGuardConfig { } url } + + /// Provide [`ClientTlsConfig`] from paths to cerfiticate, key, and cerfiticate authority (CA). + pub fn grpc_client_tls_config(&self) -> Result, io::Error> { + if self.grpc_ca.is_none() && (self.grpc_cert.is_none() || self.grpc_key.is_none()) { + return Ok(None); + } + let mut tls = ClientTlsConfig::new(); + if let (Some(cert_path), Some(key_path)) = (&self.grpc_cert, &self.grpc_key) { + let cert = read_to_string(cert_path)?; + let key = read_to_string(key_path)?; + tls = tls.identity(Identity::from_pem(cert, key)); + } + if let Some(ca_path) = &self.grpc_ca { + let ca = read_to_string(ca_path)?; + tls = tls.ca_certificate(Certificate::from_pem(ca)); + } + + Ok(Some(tls)) + } } impl Default for DefGuardConfig { diff --git a/crates/defguard_common/src/db/mod.rs b/crates/defguard_common/src/db/mod.rs index d7ca63d055..cc49e8289c 100644 --- a/crates/defguard_common/src/db/mod.rs +++ b/crates/defguard_common/src/db/mod.rs @@ -45,3 +45,18 @@ pub async fn setup_pool(options: PgConnectOptions) -> PgPool { .expect("Cannot run database migrations."); pool } + +#[derive(Deserialize)] +#[serde(rename_all = "UPPERCASE")] +pub enum TriggerOperation { + Insert, + Update, + Delete, +} + +#[derive(Deserialize)] +pub struct ChangeNotification { + pub operation: TriggerOperation, + pub old: Option, + pub new: Option, +} diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 2e559dcbf6..32c4a4e4fb 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -23,8 +23,8 @@ use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; use model_derive::Model; use rand::rngs::OsRng; use sqlx::{ - FromRow, PgConnection, PgExecutor, PgPool, Type, - postgres::types::PgInterval, query, query_as, query_scalar, + FromRow, PgConnection, PgExecutor, PgPool, Type, postgres::types::PgInterval, query, query_as, + query_scalar, }; use thiserror::Error; use tokio::sync::broadcast::Sender; diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 497ac064a6..53403fcc21 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -8,7 +8,10 @@ use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; use sqlx::PgPool; use tokio::{ - sync::mpsc::{self, Sender, UnboundedSender}, + sync::{ + broadcast::Sender, + mpsc::{self, UnboundedSender}, + }, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -28,7 +31,7 @@ use crate::{ }; /// One instance per connected Gateway. -pub(super) struct GatewayHandler { +pub(crate) struct GatewayHandler { endpoint: Endpoint, gateway: Gateway, message_id: AtomicU64, @@ -38,7 +41,7 @@ pub(super) struct GatewayHandler { } impl GatewayHandler { - pub(super) fn new( + pub(crate) fn new( gateway: Gateway, tls_config: Option, pool: PgPool, @@ -193,7 +196,7 @@ impl GatewayHandler { } /// Connect to Gateway and handle its messages through gRPC. - pub(super) async fn handle_connection(&mut self) -> ! { + pub(crate) async fn handle_connection(&mut self) -> ! { let uri = self.endpoint.uri(); loop { #[cfg(not(test))] diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 8f1f5ec669..ab6d576140 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -48,7 +48,7 @@ pub(crate) mod state; // mod tests; const PEER_DISCONNECT_INTERVAL: u64 = 60; -static TONIC_SOCKET: &str = "tonic.sock"; +pub(super) static TONIC_SOCKET: &str = "tonic.sock"; /// Sends given `GatewayEvent` to be handled by gateway GRPC server /// diff --git a/crates/defguard_core/src/grpc/gateway/state.rs b/crates/defguard_core/src/grpc/gateway/state.rs index 7219c30d16..788801106f 100644 --- a/crates/defguard_core/src/grpc/gateway/state.rs +++ b/crates/defguard_core/src/grpc/gateway/state.rs @@ -13,6 +13,7 @@ use utoipa::ToSchema; use uuid::Uuid; use crate::{ + db::models::gateway::Gateway, grpc::MIN_GATEWAY_VERSION, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, }; @@ -23,7 +24,7 @@ pub struct GatewayState { pub connected: bool, pub network_id: Id, pub network_name: String, - pub name: Option, + pub name: Option, // TODO: remove pub hostname: String, pub connected_at: Option, pub disconnected_at: Option, @@ -36,6 +37,19 @@ pub struct GatewayState { } impl GatewayState { + // pub(crate) fn from_gateway(gateway: &Gateway, network_name: &str) -> Self { + // Self { + // id: gateway.id, + // connected: gateway.is_connected(), + // network_id: gateway.network_id, + // network_name: network_name.to_owned(), + // name: None, // TODO: remove + // hostname: gateway.hostname.clone().unwrap_or_default(), + // connected_at: gateway.connected_at, + // disconnected_at: gateway.disconnected_at, + // } + // } + #[must_use] pub fn new>( network_id: Id, diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index eedad0a010..cee430a613 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,7 +10,7 @@ use axum::http::Uri; use defguard_common::{ VERSION, auth::claims::ClaimsType, - db::{Id, models::Settings}, + db::{ChangeNotification, Id, TriggerOperation, models::Settings}, }; use defguard_mail::Mail; use defguard_version::{ @@ -20,12 +20,13 @@ use defguard_version::{ use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; use serde::Serialize; -use sqlx::PgPool; +use sqlx::{PgPool, postgres::PgListener}; use tokio::{ sync::{ broadcast::Sender, mpsc::{self, UnboundedSender}, }, + task::{AbortHandle, JoinSet}, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -35,19 +36,20 @@ use tonic::{ Certificate, ClientTlsConfig, Endpoint, Identity, Server, ServerTlsConfig, server::Router, }, }; -use tower::ServiceBuilder; use self::{ auth::AuthServer, client_mfa::ClientMfaServer, enrollment::EnrollmentServer, - gateway::GatewayServer, interceptor::JwtInterceptor, password_reset::PasswordResetServer, - worker::WorkerServer, + gateway::handler::GatewayHandler, interceptor::JwtInterceptor, + password_reset::PasswordResetServer, worker::WorkerServer, }; -pub use crate::version::MIN_GATEWAY_VERSION; use crate::{ auth::failed_login::FailedLoginMap, db::{ AppEvent, GatewayEvent, - models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, + models::{ + enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, + gateway::Gateway, + }, }, enterprise::{ db::models::{ @@ -65,7 +67,10 @@ use crate::{ events::{BidiStreamEvent, GrpcEvent}, grpc::gateway::{client_state::ClientMap, map::GatewayMap}, server_config, - version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, + version::{ + IncompatibleComponents, IncompatibleProxyData, MIN_GATEWAY_VERSION, + is_proxy_version_supported, + }, }; static VERSION_ZERO: Version = Version::new(0, 0, 0); @@ -546,6 +551,100 @@ async fn handle_proxy_message_loop( Ok(()) } +const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; + +/// Bi-directional gRPC stream for comminication with Defguard Gateway. +pub async fn run_grpc_gateway_stream( + pool: PgPool, + events_tx: Sender, + mail_tx: UnboundedSender, +) -> Result<(), anyhow::Error> { + let config = server_config(); + let tls_config = config.grpc_client_tls_config()?; + + let mut abort_handles = HashMap::new(); + + let mut tasks = JoinSet::new(); + // Helper closure to launch `GatewayHandler`. + let mut launch_gateway_handler = + |gateway: Gateway| -> Result { + let mut gateway_handler = GatewayHandler::new( + gateway, + tls_config.clone(), + pool.clone(), + events_tx.clone(), + mail_tx.clone(), + )?; + let abort_handle = tasks.spawn(async move { + gateway_handler.handle_connection().await; + }); + Ok(abort_handle) + }; + + let gateways = Gateway::all(&pool).await?; + for gateway in gateways { + let id = gateway.id; + let abort_handle = launch_gateway_handler(gateway)?; + abort_handles.insert(id, abort_handle); + } + + // Observe gateway URL changes. + let mut listener = PgListener::connect_with(&pool).await?; + listener.listen(GATEWAY_TABLE_TRIGGER).await?; + while let Ok(notification) = listener.recv().await { + let payload = notification.payload(); + match serde_json::from_str::>>(payload) { + Ok(gateway_notification) => match gateway_notification.operation { + TriggerOperation::Insert => { + if let Some(new) = gateway_notification.new { + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } + } + TriggerOperation::Update => { + if let (Some(old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + { + if old.url == new.url { + debug!( + "Gateway URL didn't change. Keeping the current gateway handler" + ); + } else if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!("Aborting connection to {old}, it has changed in the database"); + abort_handle.abort(); + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + TriggerOperation::Delete => { + if let Some(old) = gateway_notification.old { + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to {old}, it has disappeard from the database" + ); + abort_handle.abort(); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + }, + Err(err) => error!("Failed to de-serialize database notification object: {err}"), + } + } + + while let Some(Ok(_result)) = tasks.join_next().await { + debug!("Gateway gRPC task has ended"); + } + + Ok(()) +} + /// Bi-directional gRPC stream for communication with Defguard Proxy. #[instrument(skip_all)] pub async fn run_grpc_bidi_stream( From 20f1c6b05d730b0d78cc69ba752741cd99c4036d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Mon, 1 Dec 2025 11:17:37 +0100 Subject: [PATCH 04/17] Handle gateway stats --- Cargo.lock | 70 ++- crates/defguard/src/main.rs | 10 +- .../src/grpc/gateway/client_state.rs | 3 +- .../defguard_core/src/grpc/gateway/handler.rs | 277 ++++++++-- crates/defguard_core/src/grpc/gateway/mod.rs | 487 +++++++----------- .../defguard_core/src/grpc/gateway/state.rs | 2 +- crates/defguard_core/src/grpc/mod.rs | 36 +- 7 files changed, 462 insertions(+), 423 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5703ef2bc4..4e4af3ad94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.47" +version = "1.2.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" +checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" dependencies = [ "find-msvc-tools", "jobserver", @@ -616,7 +616,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -2132,13 +2132,13 @@ dependencies = [ [[package]] name = "hostname" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56f203cd1c76362b69e3863fd987520ac36cf70a8c92627449b2f64a8cf7d65" +checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd" dependencies = [ "cfg-if", "libc", - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -3520,7 +3520,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -3636,11 +3636,12 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.7.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", + "hashbrown 0.15.5", "indexmap 2.12.1", ] @@ -3904,9 +3905,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +checksum = "101fec8d036f8d9d4a1e8ebf90d566d1d798f3b1aa379d2576a54a0d9acea5bd" dependencies = [ "bytes", "prost-derive", @@ -3914,15 +3915,14 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +checksum = "528a07106a21e01f4880c09818d0b7e73d0f0993536ddfff161754b5c20a086c" dependencies = [ "heck", "itertools 0.14.0", "log", "multimap", - "once_cell", "petgraph", "prettyplease", "prost", @@ -3936,9 +3936,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +checksum = "d2d93e596a829ebe00afa41c3a056e6308d6b8a4c7d869edf184e2c91b1ba564" dependencies = [ "anyhow", "itertools 0.14.0", @@ -3949,9 +3949,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +checksum = "f5d7b7346e150de32340ae3390b8b3ffa37ad93ec31fb5dad86afe817619e4e7" dependencies = [ "prost", ] @@ -4424,9 +4424,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" dependencies = [ "web-time", "zeroize", @@ -5839,9 +5839,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -6409,7 +6409,7 @@ checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.2.1", + "windows-link", "windows-result", "windows-strings", ] @@ -6436,12 +6436,6 @@ dependencies = [ "syn", ] -[[package]] -name = "windows-link" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" - [[package]] name = "windows-link" version = "0.2.1" @@ -6454,7 +6448,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows-result", "windows-strings", ] @@ -6465,7 +6459,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6474,7 +6468,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6519,7 +6513,7 @@ version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6559,7 +6553,7 @@ version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows_aarch64_gnullvm 0.53.1", "windows_aarch64_msvc 0.53.1", "windows_i686_gnu 0.53.1", @@ -6801,18 +6795,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea879c944afe8a2b25fef16bb4ba234f47c694565e97383b36f3a878219065c" +checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf955aa904d6040f70dc8e9384444cb1030aed272ba3cb09bbc4ab9e7c1f34f5" +checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 5188cfd41e..11f25325a4 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -155,8 +155,10 @@ async fn main() -> Result<(), anyhow::Error> { tokio::select! { res = run_grpc_gateway_stream( pool.clone(), + client_state, wireguard_tx.clone(), - mail_tx.clone() + mail_tx.clone(), + grpc_event_tx, ) => error!("Gateway gRPC stream returned early: {res:?}"), res = run_grpc_bidi_stream( pool.clone(), @@ -168,15 +170,9 @@ async fn main() -> Result<(), anyhow::Error> { res = run_grpc_server( Arc::clone(&worker_state), pool.clone(), - Arc::clone(&gateway_state), - client_state, - wireguard_tx.clone(), - mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone(), - grpc_event_tx, - Arc::clone(&incompatible_components), ) => error!("gRPC server returned early: {res:?}"), res = run_web_server( worker_state, diff --git a/crates/defguard_core/src/grpc/gateway/client_state.rs b/crates/defguard_core/src/grpc/gateway/client_state.rs index 1bc49a404c..8f0f5ecd48 100644 --- a/crates/defguard_core/src/grpc/gateway/client_state.rs +++ b/crates/defguard_core/src/grpc/gateway/client_state.rs @@ -117,7 +117,8 @@ impl ClientMap { stats: &WireguardPeerStats, ) -> Result<(), ClientMapError> { info!( - "VPN client {} with public key {public_key} connected to location {location_id} through gateway {gateway_hostname}", + "VPN client {} with public key {public_key} connected to location {location_id} \ + through Gateway {gateway_hostname}", device.name ); diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 53403fcc21..30c6ecc6f2 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -1,8 +1,13 @@ use std::{ + net::SocketAddr, str::FromStr, - sync::atomic::{AtomicU64, Ordering}, + sync::{ + Arc, Mutex, + atomic::{AtomicU64, Ordering}, + }, }; +use chrono::{TimeDelta, Utc}; use defguard_common::{auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; @@ -23,10 +28,10 @@ use tonic::{ use crate::{ ClaimsType, db::{ - Device, GatewayEvent, WireguardNetwork, + Device, GatewayEvent, User, WireguardNetwork, models::{gateway::Gateway, wireguard_peer_stats::WireguardPeerStats}, }, - grpc::TEN_SECS, + grpc::{ClientMap, GrpcEvent, TEN_SECS, gateway::GrpcRequestContext}, handlers::mail::send_gateway_disconnected_email, }; @@ -36,8 +41,10 @@ pub(crate) struct GatewayHandler { gateway: Gateway, message_id: AtomicU64, pool: PgPool, + client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, } impl GatewayHandler { @@ -45,8 +52,10 @@ impl GatewayHandler { gateway: Gateway, tls_config: Option, pool: PgPool, + client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, ) -> Result { let endpoint = Endpoint::from_shared(gateway.url.to_string())? .http2_keep_alive_interval(TEN_SECS) @@ -63,8 +72,10 @@ impl GatewayHandler { gateway, message_id: AtomicU64::new(0), pool, + client_state, events_tx, mail_tx, + grpc_event_tx, }) } @@ -195,6 +206,79 @@ impl GatewayHandler { }; } + /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors + async fn fetch_device_from_db(&self, public_key: &str) -> Result>, Status> { + let device = Device::find_by_pubkey(&self.pool, public_key) + .await + .map_err(|err| { + error!("Failed to retrieve device with public key {public_key}: {err}",); + Status::new( + Code::Internal, + format!("Failed to retrieve device with public key {public_key}: {err}",), + ) + })?; + + Ok(device) + } + + /// Helper method to fetch `WireguardNetwork` info from DB and return appropriate errors + async fn fetch_location_from_db( + &self, + location_id: Id, + ) -> Result, Status> { + let location = match WireguardNetwork::find_by_id(&self.pool, location_id).await { + Ok(Some(location)) => location, + Ok(None) => { + error!("Location {location_id} not found"); + return Err(Status::new( + Code::Internal, + format!("Location {location_id} not found"), + )); + } + Err(err) => { + error!("Failed to retrieve location {location_id}: {err}",); + return Err(Status::new( + Code::Internal, + format!("Failed to retrieve location {location_id}: {err}",), + )); + } + }; + Ok(location) + } + + /// Helper method to fetch `User` info from DB and return appropriate errors + async fn fetch_user_from_db(&self, user_id: Id, public_key: &str) -> Result, Status> { + let user = match User::find_by_id(&self.pool, user_id).await { + Ok(Some(user)) => user, + Ok(None) => { + error!("User {user_id} assigned to device with public key {public_key} not found"); + return Err(Status::new( + Code::Internal, + format!("User assigned to device with public key {public_key} not found"), + )); + } + Err(err) => { + error!( + "Failed to retrieve user {user_id} for device with public key {public_key}: {err}", + ); + return Err(Status::new( + Code::Internal, + format!( + "Failed to retrieve user for device with public key {public_key}: {err}", + ), + )); + } + }; + + Ok(user) + } + + fn emit_event(&self, event: GrpcEvent) { + if self.grpc_event_tx.send(event).is_err() { + warn!("Failed to send gRPC event"); + } + } + /// Connect to Gateway and handle its messages through gRPC. pub(crate) async fn handle_connection(&mut self) -> ! { let uri = self.endpoint.uri(); @@ -229,11 +313,11 @@ impl GatewayHandler { 'message: loop { match resp_stream.message().await { Ok(None) => { - info!("stream was closed by the sender"); + info!("Stream was closed by the sender."); break 'message; } Ok(Some(received)) => { - info!("Received message from gateway."); + info!("Received message from Gateway."); debug!("Message from Gateway {uri}"); match received.payload { Some(core_request::Payload::ConfigRequest(config_request)) => { @@ -307,6 +391,7 @@ impl GatewayHandler { }; // tokio::spawn(super::handle_events( // network, + // self.gateway.hostname.unwrap_or_default().clone(), // tx.clone(), // self.events_tx.subscribe(), // )); @@ -314,43 +399,163 @@ impl GatewayHandler { Some(core_request::Payload::PeerStats(peer_stats)) => { if !config_sent { warn!( - "Ignoring peer statistics from {} because it didn't \ + "Ignoring peer statistics from {} because it hasn't \ authorize itself", self.gateway ); continue; } - // let public_key = peer_stats.public_key.clone(); - // let mut stats = WireguardPeerStats::from_peer_stats( - // peer_stats, - // self.gateway.network_id, - - // ); - // // Get device by public key and fill in stats.device_id - // match Device::find_by_pubkey(&self.pool, &public_key).await { - // Ok(Some(device)) => { - // stats.device_id = device.id; - // match stats.save(&self.pool).await { - // Ok(_) => { - // info!("Saved WireGuard peer stats to database.") - // } - // Err(err) => error!( - // "Failed to save WireGuard peer stats to database: \ - // {err}" - // ), - // } - // } - // Ok(None) => { - // error!("Device with public key {public_key} not found"); - // } - // Err(err) => { - // error!( - // "Failed to retrieve device with public key \ - // {public_key}: {err}", - // ); - // } - // }; + let public_key = peer_stats.public_key.clone(); + + // fetch device from DB + // TODO: fetch only when device has changed and use client state + // otherwise + let Ok(Some(device)) = self.fetch_device_from_db(&public_key).await + else { + warn!( + "Received stats update for a device which does not \ + exist: {public_key}, skipping." + ); + continue; + }; + + // copy device ID for easier reference later + let device_id = device.id; + + // fetch user and location from DB for activity log + // TODO: cache usernames since they don't change + let Ok(user) = + self.fetch_user_from_db(device.user_id, &public_key).await + else { + continue; + }; + let Ok(location) = + self.fetch_location_from_db(self.gateway.network_id).await + else { + continue; + }; + + // Convert stats to database storage format. + let stats = WireguardPeerStats::from_peer_stats( + peer_stats, + self.gateway.network_id, + device_id, + ); + + // Only perform client state update if stats include an endpoint IP. + // Otherwise, a peer was added to the gateway interface, but hasn't + // connected yet. + if let Some(endpoint) = &stats.endpoint { + // parse client endpoint IP + let Ok(socket_addr) = endpoint.clone().parse::() + else { + error!("Failed to parse VPN client endpoint"); + continue; + }; + + // Perform client state operations in a dedicated block to drop + // mutex guard. + let disconnected_clients = { + // acquire lock on client state map + let mut client_map = self.client_state.lock().unwrap(); + + // update connected clients map + match client_map + .get_vpn_client(self.gateway.network_id, &public_key) + { + Some(client_state) => { + // update connected client state + client_state.update_client_state( + device, + socket_addr, + stats.latest_handshake, + stats.upload, + stats.download, + ); + } + None => { + // don't mark inactive peers as connected + if (Utc::now().naive_utc() - stats.latest_handshake) + < TimeDelta::seconds( + location.peer_disconnect_threshold.into(), + ) + { + // mark new VPN client as connected + if client_map + .connect_vpn_client( + self.gateway.network_id, + // Hostname is for logging only. + &self + .gateway + .hostname + .as_ref() + .cloned() + .unwrap_or_default(), + &public_key, + &device, + &user, + socket_addr, + &stats, + ) + .is_err() + { + // TODO: log message + continue; + } + + // emit connection event + let context = GrpcRequestContext::new( + user.id, + user.username.clone(), + socket_addr.ip(), + device.id, + device.name.clone(), + location.clone(), + ); + self.emit_event(GrpcEvent::ClientConnected { + context, + location: location.clone(), + device: device.clone(), + }); + } + } + } + + // disconnect inactive clients + let Ok(clients) = client_map + .disconnect_inactive_vpn_clients_for_location( + &location, + ) + else { + // TODO: log message + continue; + }; + clients + }; + + // emit client disconnect events + for (device, context) in disconnected_clients { + self.emit_event(GrpcEvent::ClientDisconnected { + context, + location: location.clone(), + device, + }); + } + } + + // Save stats to database. + let stats = match stats.save(&self.pool).await { + Ok(stats) => stats, + Err(err) => { + error!( + "Saving WireGuard peer stats to database failed: {err}" + ); + continue; + } + }; + info!("Saved WireGuard peer stats to database."); + debug!("WireGuard peer stats: {stats:?}"); } None => (), }; diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ab6d576140..329a67f85f 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,40 +1,30 @@ use std::{ - net::{IpAddr, SocketAddr}, - pin::Pin, - sync::{Arc, Mutex, MutexGuard}, - task::{Context, Poll}, + net::IpAddr, + sync::{Arc, Mutex}, }; -use chrono::{DateTime, TimeDelta, Utc}; +use chrono::{DateTime, Utc}; use client_state::ClientMap; use defguard_common::db::{Id, NoId}; use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{ - Configuration, ConfigurationRequest, CoreResponse, Peer, PeerStats, Update, UpdateType, - core_response, update, - }, + gateway::{Configuration, Peer, PeerStats, Update, update}, }; use defguard_version::version_info_from_metadata; use semver::Version; use sqlx::PgPool; use thiserror::Error; -use tokio::{ - sync::{ - broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{self, Receiver, UnboundedSender, error::SendError}, - }, - task::JoinHandle, - time::{Duration, interval}, +use tokio::sync::{ + broadcast::{Receiver as BroadcastReceiver, Sender}, + mpsc::{self, UnboundedSender, error::SendError}, }; -use tokio_stream::Stream; -use tonic::{Code, Request, Response, Status, metadata::MetadataMap}; +use tonic::{Code, Status, metadata::MetadataMap}; use self::map::GatewayMap; use crate::{ db::{ - Device, GatewayEvent, User, + GatewayEvent, models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, }, events::{GrpcEvent, GrpcRequestContext}, @@ -164,86 +154,6 @@ impl GatewayServer { } } - pub fn get_client_state_guard(&self) -> Result, GatewayServerError> { - let client_state = self - .client_state - .lock() - .map_err(|_| GatewayServerError::ClientStateMutexError)?; - debug!("Current VPN client state map: {client_state:?}"); - Ok(client_state) - } - - fn emit_event(&self, event: GrpcEvent) -> Result<(), GatewayServerError> { - Ok(self.grpc_event_tx.send(event)?) - } - - /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors - async fn fetch_device_from_db(&self, public_key: &str) -> Result>, Status> { - let device = Device::find_by_pubkey(&self.pool, public_key) - .await - .map_err(|err| { - error!("Failed to retrieve device with public key {public_key}: {err}",); - Status::new( - Code::Internal, - format!("Failed to retrieve device with public key {public_key}: {err}",), - ) - })?; - - Ok(device) - } - - /// Helper method to fetch `WireguardNetwork` info from DB and return appropriate errors - async fn fetch_location_from_db( - &self, - location_id: Id, - ) -> Result, Status> { - let location = match WireguardNetwork::find_by_id(&self.pool, location_id).await { - Ok(Some(location)) => location, - Ok(None) => { - error!("Location {location_id} not found"); - return Err(Status::new( - Code::Internal, - format!("Location {location_id} not found"), - )); - } - Err(err) => { - error!("Failed to retrieve location {location_id}: {err}",); - return Err(Status::new( - Code::Internal, - format!("Failed to retrieve location {location_id}: {err}",), - )); - } - }; - Ok(location) - } - - /// Helper method to fetch `User` info from DB and return appropriate errors - async fn fetch_user_from_db(&self, user_id: Id, public_key: &str) -> Result, Status> { - let user = match User::find_by_id(&self.pool, user_id).await { - Ok(Some(user)) => user, - Ok(None) => { - error!("User {user_id} assigned to device with public key {public_key} not found"); - return Err(Status::new( - Code::Internal, - format!("User assigned to device with public key {public_key} not found"), - )); - } - Err(err) => { - error!( - "Failed to retrieve user {user_id} for device with public key {public_key}: {err}", - ); - return Err(Status::new( - Code::Internal, - format!( - "Failed to retrieve user for device with public key {public_key}: {err}", - ), - )); - } - }; - - Ok(user) - } - /// Utility function extracting metadata fields during gRPC communication. fn extract_metadata(metadata: &MetadataMap) -> Result { let (version, _info) = version_info_from_metadata(metadata); @@ -292,166 +202,163 @@ impl WireguardPeerStats { } } -/* - -/// Process received Gateway events -/// -/// Main gRPC server uses a shared channel for broadcasting all Gateway events, -/// so the handler must determine if an event is relevant for the network being serviced. -async fn handle_events( - mut current_network: WireguardNetwork, - tx: UnboundedSender, - mut events_rx: Receiver, -) { - info!("Starting update stream network {current_network}"); - while let Some(event) = events_rx.recv().await { - debug!("Received networking state update event: {event:?}"); - let (update_type, update) = match event { - GatewayEvent::NetworkCreated(network, _fixme) => { - if network.id != current_network.id { - continue; - } - ( - UpdateType::Create, - update::Update::Network(Configuration { - name: network.name.clone(), - prvkey: network.prvkey.clone(), - addresses: network.address.to_string(), - port: network.port as u32, - peers: Vec::new(), - }), - ) - } - GatewayEvent::NetworkModified(network, peers, _fixme) => { - if network.id != current_network.id { - continue; - } - // update stored network data - current_network = network.clone(); - ( - UpdateType::Modify, - update::Update::Network(Configuration { - name: network.name, - prvkey: network.prvkey, - addresses: network.address.to_string(), - port: network.port as u32, - peers, - }), - ) - } - GatewayEvent::NetworkDeleted(network_id, network_name) => { - if network_id != current_network.id { - continue; - } - ( - UpdateType::Delete, - update::Update::Network(Configuration { - name: network_name.to_string(), - prvkey: String::new(), - addresses: Vec::new(), - port: 0, - peers: Vec::new(), - firewall_config: None, - }), - ) - } - GatewayEvent::DeviceCreated(device) => { - // check if a peer has to be added in the current network - match device - .network_info - .iter() - .find(|info| info.network_id == current_network.id) - { - Some(network_info) => { - if current_network.mfa_enabled && !network_info.is_authorized { - debug!( - "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", - device.device.name, current_network.name - ); - continue; - }; - let peer = Peer { - pubkey: device.device.wireguard_pubkey, - allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - preshared_key: network_info.preshared_key.clone(), - keepalive_interval: Some(current_network.keepalive_interval as u32), - }; - (UpdateType::Create, update::Update::Peer(peer)) - } - None => continue, - } - } - GatewayEvent::DeviceModified(device) => { - // check if a peer has to be updated in the current network - match device - .network_info - .iter() - .find(|info| info.network_id == current_network.id) - { - Some(network_info) => { - if current_network.mfa_enabled && !network_info.is_authorized { - debug!( - "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", - device.device.name, current_network.name - ); - continue; - }; - let peer = Peer { - pubkey: device.device.wireguard_pubkey, - allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - preshared_key: network_info.preshared_key.clone(), - keepalive_interval: Some(current_network.keepalive_interval as u32), - }; - (UpdateType::Modify, update::Update::Peer(peer)) - } - None => continue, - } - } - GatewayEvent::DeviceDeleted(device) => { - // check if a peer has to be updated in the current network - match device - .network_info - .iter() - .find(|info| info.network_id == current_network.id) - { - Some(_) => ( - UpdateType::Delete, - update::Update::Peer(Peer { - pubkey: device.device.wireguard_pubkey, - allowed_ips: Vec::new(), - preshared_key: None, - keepalive_interval: None, - }), - ), - None => continue, - } - } - GatewayEvent::FirewallConfigChanged(_fixme, _) => (), - GatewayEvent::FirewallDisabled(_id) => (), - }; +// /// Process received Gateway events +// /// +// /// Main gRPC server uses a shared channel for broadcasting all Gateway events, +// /// so the handler must determine if an event is relevant for the network being serviced. +// async fn handle_events( +// mut current_network: WireguardNetwork, +// tx: UnboundedSender, +// mut events_rx: BroadcastReceiver, +// ) { +// info!("Starting update stream network {current_network}"); +// while let Some(event) = events_rx.recv().await { +// debug!("Received networking state update event: {event:?}"); +// let (update_type, update) = match event { +// GatewayEvent::NetworkCreated(network, _fixme) => { +// if network.id != current_network.id { +// continue; +// } +// ( +// UpdateType::Create, +// update::Update::Network(Configuration { +// name: network.name.clone(), +// prvkey: network.prvkey.clone(), +// addresses: network.address.to_string(), +// port: network.port as u32, +// peers: Vec::new(), +// }), +// ) +// } +// GatewayEvent::NetworkModified(network, peers, _fixme) => { +// if network.id != current_network.id { +// continue; +// } +// // update stored network data +// current_network = network.clone(); +// ( +// UpdateType::Modify, +// update::Update::Network(Configuration { +// name: network.name, +// prvkey: network.prvkey, +// addresses: network.address.to_string(), +// port: network.port as u32, +// peers, +// }), +// ) +// } +// GatewayEvent::NetworkDeleted(network_id, network_name) => { +// if network_id != current_network.id { +// continue; +// } +// ( +// UpdateType::Delete, +// update::Update::Network(Configuration { +// name: network_name.to_string(), +// prvkey: String::new(), +// addresses: Vec::new(), +// port: 0, +// peers: Vec::new(), +// firewall_config: None, +// }), +// ) +// } +// GatewayEvent::DeviceCreated(device) => { +// // check if a peer has to be added in the current network +// match device +// .network_info +// .iter() +// .find(|info| info.network_id == current_network.id) +// { +// Some(network_info) => { +// if current_network.mfa_enabled && !network_info.is_authorized { +// debug!( +// "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", +// device.device.name, current_network.name +// ); +// continue; +// }; +// let peer = Peer { +// pubkey: device.device.wireguard_pubkey, +// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], +// preshared_key: network_info.preshared_key.clone(), +// keepalive_interval: Some(current_network.keepalive_interval as u32), +// }; +// (UpdateType::Create, update::Update::Peer(peer)) +// } +// None => continue, +// } +// } +// GatewayEvent::DeviceModified(device) => { +// // check if a peer has to be updated in the current network +// match device +// .network_info +// .iter() +// .find(|info| info.network_id == current_network.id) +// { +// Some(network_info) => { +// if current_network.mfa_enabled && !network_info.is_authorized { +// debug!( +// "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", +// device.device.name, current_network.name +// ); +// continue; +// }; +// let peer = Peer { +// pubkey: device.device.wireguard_pubkey, +// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], +// preshared_key: network_info.preshared_key.clone(), +// keepalive_interval: Some(current_network.keepalive_interval as u32), +// }; +// (UpdateType::Modify, update::Update::Peer(peer)) +// } +// None => continue, +// } +// } +// GatewayEvent::DeviceDeleted(device) => { +// // check if a peer has to be updated in the current network +// match device +// .network_info +// .iter() +// .find(|info| info.network_id == current_network.id) +// { +// Some(_) => ( +// UpdateType::Delete, +// update::Update::Peer(Peer { +// pubkey: device.device.wireguard_pubkey, +// allowed_ips: Vec::new(), +// preshared_key: None, +// keepalive_interval: None, +// }), +// ), +// None => continue, +// } +// } +// GatewayEvent::FirewallConfigChanged(_fixme, _) => (), +// GatewayEvent::FirewallDisabled(_id) => (), +// }; - let req = CoreResponse { - id: 0, - payload: Some(core_response::Payload::Update(Update { - update_type: update_type as i32, - update: Some(update), - })), - }; - if let Err(err) = tx.send(req) { - error!( - "Failed to send network update, network {current_network}, update type: {}, error: \ - {err}", - update_type.as_str_name() - ); - break; - } - debug!( - "Network update sent for network {current_network}, update type: {}", - update_type.as_str_name() - ); - } -} -*/ +// let req = CoreResponse { +// id: 0, +// payload: Some(core_response::Payload::Update(Update { +// update_type: update_type as i32, +// update: Some(update), +// })), +// }; +// if let Err(err) = tx.send(req) { +// error!( +// "Failed to send network update, network {current_network}, update type: {}, error: \ +// {err}", +// update_type.as_str_name() +// ); +// break; +// } +// debug!( +// "Network update sent for network {current_network}, update type: {}", +// update_type.as_str_name() +// ); +// } +// } /// Helper struct for handling gateway events struct GatewayUpdatesHandler { @@ -479,7 +386,7 @@ impl GatewayUpdatesHandler { } } - /// Process incoming gateway events + /// Process incoming Gateway events /// /// Main gRPC server uses a shared channel for broadcasting all gateway events /// so the handler must determine if an event is relevant for the network being serviced @@ -797,58 +704,14 @@ impl GatewayUpdatesHandler { } } -pub struct GatewayUpdatesStream { - task_handle: JoinHandle<()>, - rx: Receiver>, - network_id: Id, - gateway_hostname: String, - gateway_state: Arc>, - pool: PgPool, -} - -impl GatewayUpdatesStream { - #[must_use] - pub fn new( - task_handle: JoinHandle<()>, - rx: Receiver>, - network_id: Id, - gateway_hostname: String, - gateway_state: Arc>, - pool: PgPool, - ) -> Self { - Self { - task_handle, - rx, - network_id, - gateway_hostname, - gateway_state, - pool, - } - } -} - -impl Stream for GatewayUpdatesStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.rx).poll_recv(cx) - } -} - -impl Drop for GatewayUpdatesStream { - fn drop(&mut self) { - info!("Client disconnected"); - // terminate update task - self.task_handle.abort(); - // update gateway state - // TODO: possibly use a oneshot channel instead - self.gateway_state - .lock() - .unwrap() - .disconnect_gateway(self.network_id, self.gateway_hostname.clone(), &self.pool) - .expect("Unable to disconnect gateway."); - } -} +// pub struct GatewayUpdatesStream { +// task_handle: JoinHandle<()>, +// rx: Receiver>, +// network_id: Id, +// gateway_hostname: String, +// gateway_state: Arc>, +// pool: PgPool, +// } // #[tonic::async_trait] // impl gateway_service_server::GatewayService for GatewayServer { @@ -871,7 +734,7 @@ impl Drop for GatewayUpdatesStream { // // version = version.to_string(), info); // // let _guard = span.enter(); // loop { -// // Wait for a message or update client map at least once a mninute, if no messages are +// // Wait for a message or update client map at least once a minute, if no messages are // // received. // let stats_update = tokio::select! { // message = stream.message() => { diff --git a/crates/defguard_core/src/grpc/gateway/state.rs b/crates/defguard_core/src/grpc/gateway/state.rs index 788801106f..0f9b10f629 100644 --- a/crates/defguard_core/src/grpc/gateway/state.rs +++ b/crates/defguard_core/src/grpc/gateway/state.rs @@ -13,7 +13,7 @@ use utoipa::ToSchema; use uuid::Uuid; use crate::{ - db::models::gateway::Gateway, + // db::models::gateway::Gateway, grpc::MIN_GATEWAY_VERSION, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, }; diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index cee430a613..8b9e126eef 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -15,7 +15,7 @@ use defguard_common::{ use defguard_mail::Mail; use defguard_version::{ ComponentInfo, DefguardComponent, Version, client::ClientVersionInterceptor, - get_tracing_variables, server::DefguardVersionLayer, + get_tracing_variables, }; use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; @@ -65,7 +65,7 @@ use crate::{ ldap::utils::ldap_update_user_state, }, events::{BidiStreamEvent, GrpcEvent}, - grpc::gateway::{client_state::ClientMap, map::GatewayMap}, + grpc::gateway::client_state::ClientMap, server_config, version::{ IncompatibleComponents, IncompatibleProxyData, MIN_GATEWAY_VERSION, @@ -556,8 +556,10 @@ const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; /// Bi-directional gRPC stream for comminication with Defguard Gateway. pub async fn run_grpc_gateway_stream( pool: PgPool, + client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, ) -> Result<(), anyhow::Error> { let config = server_config(); let tls_config = config.grpc_client_tls_config()?; @@ -572,8 +574,10 @@ pub async fn run_grpc_gateway_stream( gateway, tls_config.clone(), pool.clone(), + Arc::clone(&client_state), events_tx.clone(), mail_tx.clone(), + grpc_event_tx.clone(), )?; let abort_handle = tasks.spawn(async move { gateway_handler.handle_connection().await; @@ -581,8 +585,7 @@ pub async fn run_grpc_gateway_stream( Ok(abort_handle) }; - let gateways = Gateway::all(&pool).await?; - for gateway in gateways { + for gateway in Gateway::all(&pool).await? { let id = gateway.id; let abort_handle = launch_gateway_handler(gateway)?; abort_handles.insert(id, abort_handle); @@ -757,15 +760,9 @@ pub async fn run_grpc_bidi_stream( pub async fn run_grpc_server( worker_state: Arc>, pool: PgPool, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, grpc_cert: Option, grpc_key: Option, failed_logins: Arc>, - grpc_event_tx: UnboundedSender, - incompatible_components: Arc>, ) -> Result<(), anyhow::Error> { // Build gRPC services let server = if let (Some(cert), Some(key)) = (grpc_cert, grpc_key) { @@ -775,19 +772,7 @@ pub async fn run_grpc_server( Server::builder() }; - let router = build_grpc_service_router( - server, - pool, - worker_state, - // gateway_state, - // client_state, - // wireguard_tx, - // mail_tx, - failed_logins, - // grpc_event_tx, - // incompatible_components, - ) - .await?; + let router = build_grpc_service_router(server, pool, worker_state, failed_logins).await?; // Run gRPC server let addr = SocketAddr::new( @@ -806,12 +791,7 @@ pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, - // gateway_state: Arc>, - // client_state: Arc>, - // wireguard_tx: Sender, - // mail_tx: UnboundedSender, failed_logins: Arc>, - // grpc_event_tx: UnboundedSender, // incompatible_components: Arc>, ) -> Result { let auth_service = AuthServiceServer::new(AuthServer::new(pool.clone(), failed_logins)); From febd7e628f3e4d679657642a5a4732e7874a5b0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Mon, 1 Dec 2025 13:55:58 +0100 Subject: [PATCH 05/17] Unclog GatewayUpdatesHandler --- .../defguard_core/src/grpc/gateway/handler.rs | 52 +- crates/defguard_core/src/grpc/gateway/mod.rs | 655 +++++------------- 2 files changed, 219 insertions(+), 488 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 30c6ecc6f2..e5527ee16e 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -80,7 +80,10 @@ impl GatewayHandler { } /// Send network and VPN configuration to Gateway. - async fn send_configuration(&self, tx: &UnboundedSender) -> Result<(), Status> { + async fn send_configuration( + &self, + tx: &UnboundedSender, + ) -> Result, Status> { debug!("Sending configuration to Gateway"); let network_id = self.gateway.network_id; // let hostname = Self::get_gateway_hostname(request.metadata())?; @@ -146,7 +149,7 @@ impl GatewayHandler { match tx.send(req) { Ok(()) => { info!("Configuration sent to {}, network {network}", self.gateway); - Ok(()) + Ok(network) } Err(err) => { error!("Failed to send configuration sent to {}", self.gateway); @@ -360,13 +363,32 @@ impl GatewayHandler { // Send network configuration to Gateway. match self.send_configuration(&tx).await { - Ok(()) => { + Ok(network) => { info!("Sent configuration to {}", self.gateway); config_sent = true; let _ = self .gateway .touch_connected(&self.pool, config_request.hostname) .await; + let guh = super::GatewayUpdatesHandler::new( + self.gateway.network_id, + network, + self + .gateway + .hostname + .as_ref() + .cloned() + .unwrap_or_default() + .clone(), + self.events_tx.subscribe(), + tx.clone(), + ); + // tokio::spawn(super::handle_events( + // network, + // // self.gateway.hostname.unwrap_or_default().clone(), + // tx.clone(), + // self.events_tx.subscribe(), + // )); } Err(err) => { error!( @@ -375,26 +397,6 @@ impl GatewayHandler { ); } } - - // Start observing configuration changes. - let Ok(Some(network)) = WireguardNetwork::find_by_id( - &self.pool, - self.gateway.network_id, - ) - .await - else { - error!( - "Failed to fetch network ID {} from the database", - self.gateway.network_id - ); - continue; - }; - // tokio::spawn(super::handle_events( - // network, - // self.gateway.hostname.unwrap_or_default().clone(), - // tx.clone(), - // self.events_tx.subscribe(), - // )); } Some(core_request::Payload::PeerStats(peer_stats)) => { if !config_sent { @@ -408,7 +410,7 @@ impl GatewayHandler { let public_key = peer_stats.public_key.clone(); - // fetch device from DB + // Fetch device from database. // TODO: fetch only when device has changed and use client state // otherwise let Ok(Some(device)) = self.fetch_device_from_db(&public_key).await @@ -561,7 +563,7 @@ impl GatewayHandler { }; } Err(err) => { - error!("Disconnected from gateway at {uri}, error: {err}"); + error!("Disconnected from Gateway at {uri}, error: {err}"); // Important: call this funtion before setting disconnection time. self.send_disconnect_notification().await; let _ = self.gateway.touch_disconnected(&self.pool).await; diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 329a67f85f..ae2ff3ce5e 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -9,7 +9,7 @@ use defguard_common::db::{Id, NoId}; use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{Configuration, Peer, PeerStats, Update, update}, + gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; use defguard_version::version_info_from_metadata; use semver::Version; @@ -202,163 +202,164 @@ impl WireguardPeerStats { } } -// /// Process received Gateway events -// /// -// /// Main gRPC server uses a shared channel for broadcasting all Gateway events, -// /// so the handler must determine if an event is relevant for the network being serviced. -// async fn handle_events( -// mut current_network: WireguardNetwork, -// tx: UnboundedSender, -// mut events_rx: BroadcastReceiver, -// ) { -// info!("Starting update stream network {current_network}"); -// while let Some(event) = events_rx.recv().await { -// debug!("Received networking state update event: {event:?}"); -// let (update_type, update) = match event { -// GatewayEvent::NetworkCreated(network, _fixme) => { -// if network.id != current_network.id { -// continue; -// } -// ( -// UpdateType::Create, -// update::Update::Network(Configuration { -// name: network.name.clone(), -// prvkey: network.prvkey.clone(), -// addresses: network.address.to_string(), -// port: network.port as u32, -// peers: Vec::new(), -// }), -// ) -// } -// GatewayEvent::NetworkModified(network, peers, _fixme) => { -// if network.id != current_network.id { -// continue; -// } -// // update stored network data -// current_network = network.clone(); -// ( -// UpdateType::Modify, -// update::Update::Network(Configuration { -// name: network.name, -// prvkey: network.prvkey, -// addresses: network.address.to_string(), -// port: network.port as u32, -// peers, -// }), -// ) -// } -// GatewayEvent::NetworkDeleted(network_id, network_name) => { -// if network_id != current_network.id { -// continue; -// } -// ( -// UpdateType::Delete, -// update::Update::Network(Configuration { -// name: network_name.to_string(), -// prvkey: String::new(), -// addresses: Vec::new(), -// port: 0, -// peers: Vec::new(), -// firewall_config: None, -// }), -// ) -// } -// GatewayEvent::DeviceCreated(device) => { -// // check if a peer has to be added in the current network -// match device -// .network_info -// .iter() -// .find(|info| info.network_id == current_network.id) -// { -// Some(network_info) => { -// if current_network.mfa_enabled && !network_info.is_authorized { -// debug!( -// "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", -// device.device.name, current_network.name -// ); -// continue; -// }; -// let peer = Peer { -// pubkey: device.device.wireguard_pubkey, -// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], -// preshared_key: network_info.preshared_key.clone(), -// keepalive_interval: Some(current_network.keepalive_interval as u32), -// }; -// (UpdateType::Create, update::Update::Peer(peer)) -// } -// None => continue, -// } -// } -// GatewayEvent::DeviceModified(device) => { -// // check if a peer has to be updated in the current network -// match device -// .network_info -// .iter() -// .find(|info| info.network_id == current_network.id) -// { -// Some(network_info) => { -// if current_network.mfa_enabled && !network_info.is_authorized { -// debug!( -// "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", -// device.device.name, current_network.name -// ); -// continue; -// }; -// let peer = Peer { -// pubkey: device.device.wireguard_pubkey, -// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], -// preshared_key: network_info.preshared_key.clone(), -// keepalive_interval: Some(current_network.keepalive_interval as u32), -// }; -// (UpdateType::Modify, update::Update::Peer(peer)) -// } -// None => continue, -// } -// } -// GatewayEvent::DeviceDeleted(device) => { -// // check if a peer has to be updated in the current network -// match device -// .network_info -// .iter() -// .find(|info| info.network_id == current_network.id) -// { -// Some(_) => ( -// UpdateType::Delete, -// update::Update::Peer(Peer { -// pubkey: device.device.wireguard_pubkey, -// allowed_ips: Vec::new(), -// preshared_key: None, -// keepalive_interval: None, -// }), -// ), -// None => continue, -// } -// } -// GatewayEvent::FirewallConfigChanged(_fixme, _) => (), -// GatewayEvent::FirewallDisabled(_id) => (), -// }; - -// let req = CoreResponse { -// id: 0, -// payload: Some(core_response::Payload::Update(Update { -// update_type: update_type as i32, -// update: Some(update), -// })), -// }; -// if let Err(err) = tx.send(req) { -// error!( -// "Failed to send network update, network {current_network}, update type: {}, error: \ -// {err}", -// update_type.as_str_name() -// ); -// break; -// } -// debug!( -// "Network update sent for network {current_network}, update type: {}", -// update_type.as_str_name() -// ); -// } -// } +/// Process received Gateway events +/// +/// Main gRPC server uses a shared channel for broadcasting all Gateway events, +/// so the handler must determine if an event is relevant for the network being serviced. +async fn handle_events( + mut current_network: WireguardNetwork, + // gateway_hostname: String, + tx: UnboundedSender, + mut events_rx: BroadcastReceiver, +) { + info!("Starting update stream network {current_network}"); + // while let Some(event) = events_rx.recv().await { + // debug!("Received networking state update event: {event:?}"); + // let (update_type, update) = match event { + // GatewayEvent::NetworkCreated(network, _fixme) => { + // if network.id != current_network.id { + // continue; + // } + // ( + // UpdateType::Create, + // update::Update::Network(Configuration { + // name: network.name.clone(), + // prvkey: network.prvkey.clone(), + // addresses: network.address.to_string(), + // port: network.port as u32, + // peers: Vec::new(), + // }), + // ) + // } + // GatewayEvent::NetworkModified(network, peers, _fixme) => { + // if network.id != current_network.id { + // continue; + // } + // // update stored network data + // current_network = network.clone(); + // ( + // UpdateType::Modify, + // update::Update::Network(Configuration { + // name: network.name, + // prvkey: network.prvkey, + // addresses: network.address.to_string(), + // port: network.port as u32, + // peers, + // }), + // ) + // } + // GatewayEvent::NetworkDeleted(network_id, network_name) => { + // if network_id != current_network.id { + // continue; + // } + // ( + // UpdateType::Delete, + // update::Update::Network(Configuration { + // name: network_name.to_string(), + // prvkey: String::new(), + // addresses: Vec::new(), + // port: 0, + // peers: Vec::new(), + // firewall_config: None, + // }), + // ) + // } + // GatewayEvent::DeviceCreated(device) => { + // // check if a peer has to be added in the current network + // match device + // .network_info + // .iter() + // .find(|info| info.network_id == current_network.id) + // { + // Some(network_info) => { + // if current_network.mfa_enabled && !network_info.is_authorized { + // debug!( + // "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", + // device.device.name, current_network.name + // ); + // continue; + // }; + // let peer = Peer { + // pubkey: device.device.wireguard_pubkey, + // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + // preshared_key: network_info.preshared_key.clone(), + // keepalive_interval: Some(current_network.keepalive_interval as u32), + // }; + // (UpdateType::Create, update::Update::Peer(peer)) + // } + // None => continue, + // } + // } + // GatewayEvent::DeviceModified(device) => { + // // check if a peer has to be updated in the current network + // match device + // .network_info + // .iter() + // .find(|info| info.network_id == current_network.id) + // { + // Some(network_info) => { + // if current_network.mfa_enabled && !network_info.is_authorized { + // debug!( + // "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", + // device.device.name, current_network.name + // ); + // continue; + // }; + // let peer = Peer { + // pubkey: device.device.wireguard_pubkey, + // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + // preshared_key: network_info.preshared_key.clone(), + // keepalive_interval: Some(current_network.keepalive_interval as u32), + // }; + // (UpdateType::Modify, update::Update::Peer(peer)) + // } + // None => continue, + // } + // } + // GatewayEvent::DeviceDeleted(device) => { + // // check if a peer has to be updated in the current network + // match device + // .network_info + // .iter() + // .find(|info| info.network_id == current_network.id) + // { + // Some(_) => ( + // UpdateType::Delete, + // update::Update::Peer(Peer { + // pubkey: device.device.wireguard_pubkey, + // allowed_ips: Vec::new(), + // preshared_key: None, + // keepalive_interval: None, + // }), + // ), + // None => continue, + // } + // } + // GatewayEvent::FirewallConfigChanged(_fixme, _) => (), + // GatewayEvent::FirewallDisabled(_id) => (), + // }; + + // let req = CoreResponse { + // id: 0, + // payload: Some(core_response::Payload::Update(Update { + // update_type: update_type as i32, + // update: Some(update), + // })), + // }; + // if let Err(err) = tx.send(req) { + // error!( + // "Failed to send network update, network {current_network}, update type: {}, error: \ + // {err}", + // update_type.as_str_name() + // ); + // break; + // } + // debug!( + // "Network update sent for network {current_network}, update type: {}", + // update_type.as_str_name() + // ); + // } +} /// Helper struct for handling gateway events struct GatewayUpdatesHandler { @@ -366,7 +367,7 @@ struct GatewayUpdatesHandler { network: WireguardNetwork, gateway_hostname: String, events_rx: BroadcastReceiver, - tx: mpsc::Sender>, + tx: UnboundedSender, } impl GatewayUpdatesHandler { @@ -375,7 +376,7 @@ impl GatewayUpdatesHandler { network: WireguardNetwork, gateway_hostname: String, events_rx: BroadcastReceiver, - tx: mpsc::Sender>, + tx: UnboundedSender, ) -> Self { Self { network_id, @@ -545,9 +546,9 @@ impl GatewayUpdatesHandler { update_type: i32, ) -> Result<(), Status> { debug!("Sending network update for network {network}"); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type, update: Some(update::Update::Network(Configuration { name: network.name.clone(), @@ -557,9 +558,8 @@ impl GatewayUpdatesHandler { peers, firewall_config, })), - })) - .await - { + })), + }) { let msg = format!( "Failed to send network update, network {network}, update type: {update_type} ({}), error: {err}", if update_type == 0 { "CREATE" } else { "MODIFY" }, @@ -577,9 +577,9 @@ impl GatewayUpdatesHandler { "Sending network delete command for network {}", self.network ); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 2, update: Some(update::Update::Network(Configuration { name: network_name.to_string(), @@ -589,9 +589,8 @@ impl GatewayUpdatesHandler { peers: Vec::new(), firewall_config: None, })), - })) - .await - { + })), + }) { let msg = format!( "Failed to send network update, network {}, update type: 2 (DELETE), error: {err}", self.network, @@ -606,14 +605,13 @@ impl GatewayUpdatesHandler { /// Send update peer command to gateway async fn send_peer_update(&self, peer: Peer, update_type: i32) -> Result<(), Status> { debug!("Sending peer update for network {}", self.network); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type, update: Some(update::Update::Peer(peer)), - })) - .await - { + })), + }) { let msg = format!( "Failed to send peer update for network {}, update type: {update_type} ({}), error: {err}", self.network, @@ -629,9 +627,9 @@ impl GatewayUpdatesHandler { /// Send delete peer command to gateway async fn send_peer_delete(&self, peer_pubkey: &str) -> Result<(), Status> { debug!("Sending peer delete for network {}", self.network); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 2, update: Some(update::Update::Peer(Peer { pubkey: peer_pubkey.into(), @@ -639,9 +637,8 @@ impl GatewayUpdatesHandler { preshared_key: None, keepalive_interval: None, })), - })) - .await - { + })), + }) { let msg = format!( "Failed to send peer update for network {}, peer {peer_pubkey}, update type: 2 (DELETE), error: {err}", self.network, @@ -659,14 +656,13 @@ impl GatewayUpdatesHandler { "Sending firewall config update for network {} with config {firewall_config:?}", self.network ); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 1, update: Some(update::Update::FirewallConfig(firewall_config)), - })) - .await - { + })), + }) { let msg = format!( "Failed to send firewall config update for network {}, error: {err}", self.network, @@ -684,14 +680,13 @@ impl GatewayUpdatesHandler { "Sending firewall disable command for network {}", self.network ); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 2, update: Some(update::Update::DisableFirewall(())), - })) - .await - { + })), + }) { let msg = format!( "Failed to send firewall disable command for network {}, error: {err}", self.network, @@ -716,273 +711,7 @@ impl GatewayUpdatesHandler { // #[tonic::async_trait] // impl gateway_service_server::GatewayService for GatewayServer { // type UpdatesStream = GatewayUpdatesStream; - -// /// Retrieve stats from gateway and save it to database -// async fn stats( -// &self, -// request: Request>, -// ) -> Result, Status> { -// let GatewayMetadata { -// network_id, -// hostname, -// .. -// } = Self::extract_metadata(request.metadata())?; -// let mut stream = request.into_inner(); -// let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); -// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. -// // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, -// // version = version.to_string(), info); -// // let _guard = span.enter(); -// loop { -// // Wait for a message or update client map at least once a minute, if no messages are -// // received. -// let stats_update = tokio::select! { -// message = stream.message() => { -// match message? { -// Some(update) => update, -// None => break, // Stream ended -// } -// } -// _ = disconnect_timer.tick() => { -// debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ -// Updating disconnected VPN clients"); -// // fetch location to get current peer disconnect threshold -// let location = self.fetch_location_from_db(network_id).await?; - -// // perform client state operations in a dedicated block to drop mutex guard -// let disconnected_clients = { -// // acquire lock on client state map -// let mut client_map = self.get_client_state_guard()?; - -// // disconnect inactive clients -// client_map.disconnect_inactive_vpn_clients_for_location(&location -// )? -// }; - -// // emit client disconnect events -// for (device, context) in disconnected_clients { -// self.emit_event(GrpcEvent::ClientDisconnected { -// context, -// location: location.clone(), -// device, -// })?; -// }; -// continue; -// } -// }; - -// debug!("Received stats message: {stats_update:?}"); -// let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { -// debug!("Received stats message is empty, skipping."); -// continue; -// }; -// let public_key = peer_stats.public_key.clone(); - -// // fetch device from DB -// // TODO: fetch only when device has changed and use client state otherwise -// let device = match self.fetch_device_from_db(&public_key).await? { -// Some(device) => device, -// None => { -// warn!( -// "Received stats update for a device which does not exist: {public_key}, skipping." -// ); -// continue; -// } -// }; - -// // copy device ID for easier reference later -// let device_id = device.id; - -// // fetch user and location from DB for activity log -// // TODO: cache usernames since they don't change -// let user = self.fetch_user_from_db(device.user_id, &public_key).await?; -// let location = self.fetch_location_from_db(network_id).await?; - -// // convert stats to DB storage format -// let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); - -// // only perform client state update if stats include an endpoint IP -// // otherwise a peer was added to the gateway interface -// // but has not connected yet -// if let Some(endpoint) = &stats.endpoint { -// // parse client endpoint IP -// let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { -// error!("Failed to parse VPN client endpoint: {err}"); -// Status::new( -// Code::Internal, -// format!("Failed to parse VPN client endpoint: {err}"), -// ) -// })?; - -// // perform client state operations in a dedicated block to drop mutex guard -// let disconnected_clients = { -// // acquire lock on client state map -// let mut client_map = self.get_client_state_guard()?; - -// // update connected clients map -// match client_map.get_vpn_client(network_id, &public_key) { -// Some(client_state) => { -// // update connected client state -// client_state.update_client_state( -// device, -// socket_addr, -// stats.latest_handshake, -// stats.upload, -// stats.download, -// ); -// } -// None => { -// // don't mark inactive peers as connected -// if (Utc::now().naive_utc() - stats.latest_handshake) -// < TimeDelta::seconds(location.peer_disconnect_threshold.into()) -// { -// // mark new VPN client as connected -// client_map.connect_vpn_client( -// network_id, -// &hostname, -// &public_key, -// &device, -// &user, -// socket_addr, -// &stats, -// )?; - -// // emit connection event -// let context = GrpcRequestContext::new( -// user.id, -// user.username.clone(), -// socket_addr.ip(), -// device.id, -// device.name.clone(), -// location.clone(), -// ); -// self.emit_event(GrpcEvent::ClientConnected { -// context, -// location: location.clone(), -// device: device.clone(), -// })?; -// } -// } -// } - -// // disconnect inactive clients -// client_map.disconnect_inactive_vpn_clients_for_location(&location)? -// }; - -// // emit client disconnect events -// for (device, context) in disconnected_clients { -// self.emit_event(GrpcEvent::ClientDisconnected { -// context, -// location: location.clone(), -// device, -// })?; -// } -// } - -// // Save stats to db -// let stats = match stats.save(&self.pool).await { -// Ok(stats) => stats, -// Err(err) => { -// error!("Saving WireGuard peer stats to db failed: {err}"); -// return Err(Status::new( -// Code::Internal, -// format!("Saving WireGuard peer stats to db failed: {err}"), -// )); -// } -// }; -// info!("Saved WireGuard peer stats to db."); -// debug!("WireGuard peer stats: {stats:?}"); -// } - -// Ok(Response::new(())) -// } - -// async fn config( -// &self, -// request: Request, -// ) -> Result, Status> { -// debug!("Sending configuration to gateway client."); -// let GatewayMetadata { -// network_id, -// hostname, -// version, -// .. -// // info, -// } = Self::extract_metadata(request.metadata())?; -// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. -// // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, -// // version = version.to_string(), info); -// // let _guard = span.enter(); - -// let mut conn = self.pool.acquire().await.map_err(|e| { -// error!("Failed to acquire DB connection: {e}"); -// Status::new( -// Code::Internal, -// "Failed to acquire DB connection".to_string(), -// ) -// })?; - -// let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) -// .await -// .map_err(|e| { -// error!("Network {network_id} not found"); -// Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) -// })? -// .ok_or_else(|| { -// Status::new( -// Code::Internal, -// format!("Network with id {network_id} not found"), -// ) -// })?; - -// debug!("Sending configuration to gateway client, network {network}."); - -// // store connected gateway in memory -// { -// let mut state = self.gateway_state.lock().unwrap(); -// state.add_gateway( -// network_id, -// &network.name, -// hostname, -// request.into_inner().name, -// self.mail_tx.clone(), -// version, -// ); -// } - -// network.connected_at = Some(Utc::now().naive_utc()); -// if let Err(err) = network.save(&mut *conn).await { -// error!("Failed to save updated network {network_id} in the database, status: {err}"); -// } - -// let peers = network.get_peers(&mut *conn).await.map_err(|error| { -// error!("Failed to fetch peers from the database for network {network_id}: {error}",); -// Status::new( -// Code::Internal, -// format!("Failed to retrieve peers from the database for network: {network_id}"), -// ) -// })?; -// let maybe_firewall_config = -// network -// .try_get_firewall_config(&mut conn) -// .await -// .map_err(|err| { -// error!("Failed to generate firewall config for network {network_id}: {err}"); -// Status::new( -// Code::Internal, -// format!("Failed to generate firewall config for network: {network_id}"), -// ) -// })?; - -// info!("Configuration sent to gateway client, network {network}."); - -// Ok(Response::new(gen_config( -// &network, -// peers, -// maybe_firewall_config, -// ))) -// } - +// // async fn updates(&self, request: Request<()>) -> Result, Status> { // let GatewayMetadata { // network_id, From 0a759cfe208061c191a44e0bf92aa7a3cc1f8646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 2 Dec 2025 14:15:58 +0100 Subject: [PATCH 06/17] Gateway metadata --- Cargo.lock | 28 +- .../defguard_core/src/grpc/gateway/handler.rs | 87 +++++- crates/defguard_core/src/grpc/gateway/mod.rs | 266 +----------------- 3 files changed, 96 insertions(+), 285 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4e4af3ad94..009d6fab37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3636,12 +3636,11 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.8.3" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "hashbrown 0.15.5", "indexmap 2.12.1", ] @@ -3905,9 +3904,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "101fec8d036f8d9d4a1e8ebf90d566d1d798f3b1aa379d2576a54a0d9acea5bd" +checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" dependencies = [ "bytes", "prost-derive", @@ -3915,14 +3914,15 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528a07106a21e01f4880c09818d0b7e73d0f0993536ddfff161754b5c20a086c" +checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" dependencies = [ "heck", "itertools 0.14.0", "log", "multimap", + "once_cell", "petgraph", "prettyplease", "prost", @@ -3936,9 +3936,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2d93e596a829ebe00afa41c3a056e6308d6b8a4c7d869edf184e2c91b1ba564" +checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" dependencies = [ "anyhow", "itertools 0.14.0", @@ -3949,9 +3949,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5d7b7346e150de32340ae3390b8b3ffa37ad93ec31fb5dad86afe817619e4e7" +checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" dependencies = [ "prost", ] @@ -6062,13 +6062,13 @@ checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" [[package]] name = "uuid" -version = "1.18.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ "getrandom 0.3.4", "js-sys", - "serde", + "serde_core", "wasm-bindgen", ] diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index e5527ee16e..bc32135229 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -11,6 +11,8 @@ use chrono::{TimeDelta, Utc}; use defguard_common::{auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; +use defguard_version::version_info_from_metadata; +use semver::Version; use sqlx::PgPool; use tokio::{ sync::{ @@ -22,6 +24,7 @@ use tokio::{ use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{ Code, Status, + metadata::MetadataMap, transport::{ClientTlsConfig, Endpoint}, }; @@ -47,6 +50,14 @@ pub(crate) struct GatewayHandler { grpc_event_tx: UnboundedSender, } +/// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. +struct GatewayMetadata { + network_id: Id, + hostname: String, + version: Version, + // info: String, +} + impl GatewayHandler { pub(crate) fn new( gateway: Gateway, @@ -79,6 +90,57 @@ impl GatewayHandler { }) } + fn get_network_id(metadata: &MetadataMap) -> Result { + match Self::get_network_id_from_metadata(metadata) { + Some(m) => Ok(m), + None => Err(Status::new( + Code::Internal, + "Network ID was not found in metadata", + )), + } + } + + // parse network id from gateway request metadata from intercepted information from JWT token + fn get_network_id_from_metadata(metadata: &MetadataMap) -> Option { + if let Some(ascii_value) = metadata.get("gateway_network_id") { + if let Ok(slice) = ascii_value.clone().to_str() { + if let Ok(id) = slice.parse::() { + return Some(id); + } + } + } + None + } + + // extract gateway hostname from request headers + fn get_gateway_hostname(metadata: &MetadataMap) -> Result { + match metadata.get("hostname") { + Some(ascii_value) => { + let hostname = ascii_value.to_str().map_err(|_| { + Status::new( + Code::Internal, + "Failed to parse gateway hostname from request metadata", + ) + })?; + Ok(hostname.into()) + } + None => Err(Status::new( + Code::Internal, + "Gateway hostname not found in request metadata", + )), + } + } + + /// Utility function extracting metadata fields during gRPC communication. + fn extract_metadata(metadata: &MetadataMap) -> Result { + let (version, _info) = version_info_from_metadata(metadata); + Ok(GatewayMetadata { + network_id: Self::get_network_id(metadata)?, + hostname: Self::get_gateway_hostname(metadata)?, + version, + }) + } + /// Send network and VPN configuration to Gateway. async fn send_configuration( &self, @@ -86,7 +148,6 @@ impl GatewayHandler { ) -> Result, Status> { debug!("Sending configuration to Gateway"); let network_id = self.gateway.network_id; - // let hostname = Self::get_gateway_hostname(request.metadata())?; let mut conn = self.pool.acquire().await.map_err(|err| { error!("Failed to acquire DB connection: {err}"); @@ -310,6 +371,15 @@ impl GatewayHandler { }; info!("Connected to Defguard Gateway {uri}"); + let Ok(GatewayMetadata { + network_id, + hostname, + .. + // info, + }) = Self::extract_metadata(response.metadata()) else { + continue; + }; + let mut resp_stream = response.into_inner(); let mut config_sent = false; @@ -322,6 +392,7 @@ impl GatewayHandler { Ok(Some(received)) => { info!("Received message from Gateway."); debug!("Message from Gateway {uri}"); + match received.payload { Some(core_request::Payload::ConfigRequest(config_request)) => { if config_sent { @@ -370,11 +441,10 @@ impl GatewayHandler { .gateway .touch_connected(&self.pool, config_request.hostname) .await; - let guh = super::GatewayUpdatesHandler::new( + let mut guh = super::GatewayUpdatesHandler::new( self.gateway.network_id, network, - self - .gateway + self.gateway .hostname .as_ref() .cloned() @@ -383,12 +453,9 @@ impl GatewayHandler { self.events_tx.subscribe(), tx.clone(), ); - // tokio::spawn(super::handle_events( - // network, - // // self.gateway.hostname.unwrap_or_default().clone(), - // tx.clone(), - // self.events_tx.subscribe(), - // )); + tokio::spawn(async move { + guh.run().await; + }); } Err(err) => { error!( diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ae2ff3ce5e..7be29eac1e 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -11,15 +11,13 @@ use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use defguard_version::version_info_from_metadata; -use semver::Version; use sqlx::PgPool; use thiserror::Error; use tokio::sync::{ broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{self, UnboundedSender, error::SendError}, + mpsc::{UnboundedSender, error::SendError}, }; -use tonic::{Code, Status, metadata::MetadataMap}; +use tonic::{Code, Status}; use self::map::GatewayMap; use crate::{ @@ -84,87 +82,6 @@ pub struct GatewayServer { grpc_event_tx: UnboundedSender, } -/// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. -struct GatewayMetadata { - network_id: Id, - hostname: String, - version: Version, - // info: String, -} - -impl GatewayServer { - /// Create new gateway server instance - #[must_use] - pub fn new( - pool: PgPool, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, - ) -> Self { - Self { - pool, - gateway_state, - client_state, - wireguard_tx, - mail_tx, - grpc_event_tx, - } - } - - fn get_network_id(metadata: &MetadataMap) -> Result { - match Self::get_network_id_from_metadata(metadata) { - Some(m) => Ok(m), - None => Err(Status::new( - Code::Internal, - "Network ID was not found in metadata", - )), - } - } - - // parse network id from gateway request metadata from intercepted information from JWT token - fn get_network_id_from_metadata(metadata: &MetadataMap) -> Option { - if let Some(ascii_value) = metadata.get("gateway_network_id") { - if let Ok(slice) = ascii_value.clone().to_str() { - if let Ok(id) = slice.parse::() { - return Some(id); - } - } - } - None - } - - // extract gateway hostname from request headers - fn get_gateway_hostname(metadata: &MetadataMap) -> Result { - match metadata.get("hostname") { - Some(ascii_value) => { - let hostname = ascii_value.to_str().map_err(|_| { - Status::new( - Code::Internal, - "Failed to parse gateway hostname from request metadata", - ) - })?; - Ok(hostname.into()) - } - None => Err(Status::new( - Code::Internal, - "Gateway hostname not found in request metadata", - )), - } - } - - /// Utility function extracting metadata fields during gRPC communication. - fn extract_metadata(metadata: &MetadataMap) -> Result { - let (version, _info) = version_info_from_metadata(metadata); - Ok(GatewayMetadata { - network_id: Self::get_network_id(metadata)?, - hostname: Self::get_gateway_hostname(metadata)?, - version, - }) - } -} - fn gen_config( network: &WireguardNetwork, peers: Vec, @@ -202,166 +119,7 @@ impl WireguardPeerStats { } } -/// Process received Gateway events -/// -/// Main gRPC server uses a shared channel for broadcasting all Gateway events, -/// so the handler must determine if an event is relevant for the network being serviced. -async fn handle_events( - mut current_network: WireguardNetwork, - // gateway_hostname: String, - tx: UnboundedSender, - mut events_rx: BroadcastReceiver, -) { - info!("Starting update stream network {current_network}"); - // while let Some(event) = events_rx.recv().await { - // debug!("Received networking state update event: {event:?}"); - // let (update_type, update) = match event { - // GatewayEvent::NetworkCreated(network, _fixme) => { - // if network.id != current_network.id { - // continue; - // } - // ( - // UpdateType::Create, - // update::Update::Network(Configuration { - // name: network.name.clone(), - // prvkey: network.prvkey.clone(), - // addresses: network.address.to_string(), - // port: network.port as u32, - // peers: Vec::new(), - // }), - // ) - // } - // GatewayEvent::NetworkModified(network, peers, _fixme) => { - // if network.id != current_network.id { - // continue; - // } - // // update stored network data - // current_network = network.clone(); - // ( - // UpdateType::Modify, - // update::Update::Network(Configuration { - // name: network.name, - // prvkey: network.prvkey, - // addresses: network.address.to_string(), - // port: network.port as u32, - // peers, - // }), - // ) - // } - // GatewayEvent::NetworkDeleted(network_id, network_name) => { - // if network_id != current_network.id { - // continue; - // } - // ( - // UpdateType::Delete, - // update::Update::Network(Configuration { - // name: network_name.to_string(), - // prvkey: String::new(), - // addresses: Vec::new(), - // port: 0, - // peers: Vec::new(), - // firewall_config: None, - // }), - // ) - // } - // GatewayEvent::DeviceCreated(device) => { - // // check if a peer has to be added in the current network - // match device - // .network_info - // .iter() - // .find(|info| info.network_id == current_network.id) - // { - // Some(network_info) => { - // if current_network.mfa_enabled && !network_info.is_authorized { - // debug!( - // "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", - // device.device.name, current_network.name - // ); - // continue; - // }; - // let peer = Peer { - // pubkey: device.device.wireguard_pubkey, - // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - // preshared_key: network_info.preshared_key.clone(), - // keepalive_interval: Some(current_network.keepalive_interval as u32), - // }; - // (UpdateType::Create, update::Update::Peer(peer)) - // } - // None => continue, - // } - // } - // GatewayEvent::DeviceModified(device) => { - // // check if a peer has to be updated in the current network - // match device - // .network_info - // .iter() - // .find(|info| info.network_id == current_network.id) - // { - // Some(network_info) => { - // if current_network.mfa_enabled && !network_info.is_authorized { - // debug!( - // "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", - // device.device.name, current_network.name - // ); - // continue; - // }; - // let peer = Peer { - // pubkey: device.device.wireguard_pubkey, - // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - // preshared_key: network_info.preshared_key.clone(), - // keepalive_interval: Some(current_network.keepalive_interval as u32), - // }; - // (UpdateType::Modify, update::Update::Peer(peer)) - // } - // None => continue, - // } - // } - // GatewayEvent::DeviceDeleted(device) => { - // // check if a peer has to be updated in the current network - // match device - // .network_info - // .iter() - // .find(|info| info.network_id == current_network.id) - // { - // Some(_) => ( - // UpdateType::Delete, - // update::Update::Peer(Peer { - // pubkey: device.device.wireguard_pubkey, - // allowed_ips: Vec::new(), - // preshared_key: None, - // keepalive_interval: None, - // }), - // ), - // None => continue, - // } - // } - // GatewayEvent::FirewallConfigChanged(_fixme, _) => (), - // GatewayEvent::FirewallDisabled(_id) => (), - // }; - - // let req = CoreResponse { - // id: 0, - // payload: Some(core_response::Payload::Update(Update { - // update_type: update_type as i32, - // update: Some(update), - // })), - // }; - // if let Err(err) = tx.send(req) { - // error!( - // "Failed to send network update, network {current_network}, update type: {}, error: \ - // {err}", - // update_type.as_str_name() - // ); - // break; - // } - // debug!( - // "Network update sent for network {current_network}, update type: {}", - // update_type.as_str_name() - // ); - // } -} - -/// Helper struct for handling gateway events +/// Helper struct for handling gateway events. struct GatewayUpdatesHandler { network_id: Id, network: WireguardNetwork, @@ -561,7 +319,8 @@ impl GatewayUpdatesHandler { })), }) { let msg = format!( - "Failed to send network update, network {network}, update type: {update_type} ({}), error: {err}", + "Failed to send network update, network {network}, update type: {update_type} \ + ({}), error: {err}", if update_type == 0 { "CREATE" } else { "MODIFY" }, ); error!(msg); @@ -699,26 +458,11 @@ impl GatewayUpdatesHandler { } } -// pub struct GatewayUpdatesStream { -// task_handle: JoinHandle<()>, -// rx: Receiver>, -// network_id: Id, -// gateway_hostname: String, -// gateway_state: Arc>, -// pool: PgPool, -// } - // #[tonic::async_trait] // impl gateway_service_server::GatewayService for GatewayServer { // type UpdatesStream = GatewayUpdatesStream; // // async fn updates(&self, request: Request<()>) -> Result, Status> { -// let GatewayMetadata { -// network_id, -// hostname, -// .. -// // info, -// } = Self::extract_metadata(request.metadata())?; // // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. // // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, // // version = version.to_string(), info); From 640bae9a0aea1e11395f0a29fb8c84eeefd7f115 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 3 Dec 2025 10:33:53 +0100 Subject: [PATCH 07/17] Resurrect gateway test --- Cargo.lock | 31 ++++++++----- crates/defguard_core/src/grpc/gateway/mod.rs | 8 ++-- .../defguard_core/src/grpc/gateway/state.rs | 14 ------ .../defguard_core/src/grpc/gateway/tests.rs | 44 +++++++++++++++---- .../tests/integration/grpc/common/mod.rs | 2 +- 5 files changed, 62 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 009d6fab37..96317a4a62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -765,6 +765,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cookie" version = "0.18.1" @@ -1379,7 +1388,7 @@ version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ - "convert_case", + "convert_case 0.4.0", "proc-macro2", "quote", "rustc_version", @@ -1388,21 +1397,23 @@ dependencies = [ [[package]] name = "derive_more" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b" dependencies = [ + "convert_case 0.10.0", "proc-macro2", "quote", + "rustc_version", "syn", "unicode-xid", ] @@ -2705,9 +2716,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libgit2-sys" @@ -2798,9 +2809,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -3672,7 +3683,7 @@ dependencies = [ "curve25519-dalek", "cx448", "derive_builder", - "derive_more 2.0.1", + "derive_more 2.1.0", "des", "digest", "dsa", diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 7be29eac1e..afcdc81fb3 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -32,17 +32,17 @@ pub mod client_state; pub(crate) mod handler; pub mod map; pub(crate) mod state; -// #[cfg(test)] -// mod tests; +#[cfg(test)] +mod tests; -const PEER_DISCONNECT_INTERVAL: u64 = 60; +#[cfg(test)] pub(super) static TONIC_SOCKET: &str = "tonic.sock"; /// Sends given `GatewayEvent` to be handled by gateway GRPC server /// /// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { - debug!("Sending the following WireGuard event to the gateway: {event:?}"); + debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); if let Err(err) = wg_tx.send(event) { error!("Error sending WireGuard event {err}"); } diff --git a/crates/defguard_core/src/grpc/gateway/state.rs b/crates/defguard_core/src/grpc/gateway/state.rs index 0f9b10f629..a628810f3f 100644 --- a/crates/defguard_core/src/grpc/gateway/state.rs +++ b/crates/defguard_core/src/grpc/gateway/state.rs @@ -13,7 +13,6 @@ use utoipa::ToSchema; use uuid::Uuid; use crate::{ - // db::models::gateway::Gateway, grpc::MIN_GATEWAY_VERSION, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, }; @@ -37,19 +36,6 @@ pub struct GatewayState { } impl GatewayState { - // pub(crate) fn from_gateway(gateway: &Gateway, network_name: &str) -> Self { - // Self { - // id: gateway.id, - // connected: gateway.is_connected(), - // network_id: gateway.network_id, - // network_name: network_name.to_owned(), - // name: None, // TODO: remove - // hostname: gateway.hostname.clone().unwrap_or_default(), - // connected_at: gateway.connected_at, - // disconnected_at: gateway.disconnected_at, - // } - // } - #[must_use] pub fn new>( network_id: Id, diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs index f79b77dba7..b00b1f004f 100644 --- a/crates/defguard_core/src/grpc/gateway/tests.rs +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -1,18 +1,31 @@ use std::{ io, net::{IpAddr, Ipv4Addr}, + sync::{Arc, Mutex}, }; +use defguard_common::db::setup_pool; +use defguard_mail::Mail; +use defguard_proto::gateway::{CoreRequest, CoreResponse, gateway_server}; use ipnetwork::IpNetwork; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::{ net::UnixListener, sync::{broadcast, mpsc::unbounded_channel}, }; -use tokio_stream::wrappers::UnixListenerStream; +use tokio_stream::wrappers::{UnboundedReceiverStream, UnixListenerStream}; use tonic::{Request, Response, Status, Streaming, transport::Server}; -use super::*; +use super::{TONIC_SOCKET, handler::GatewayHandler}; +use crate::{ + db::models::{ + gateway::Gateway, + wireguard::{GatewayEvent, LocationMfaMode, ServiceLocationMode, WireguardNetwork}, + }, + grpc::{ClientMap, GrpcEvent}, +}; +// TODO: move to "gateway" repo. struct FakeGateway; #[tonic::async_trait] @@ -23,7 +36,7 @@ impl gateway_server::Gateway for FakeGateway { &self, request: Request>, ) -> Result, Status> { - let (_tx, rx) = mpsc::unbounded_channel(); + let (_tx, rx) = unbounded_channel(); let mut stream = request.into_inner(); tokio::spawn(async move { loop { @@ -55,17 +68,21 @@ async fn fake_gateway() -> Result<(), io::Error> { } #[sqlx::test] -async fn test_gateway(pool: PgPool) { +async fn test_gateway(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; let network = WireguardNetwork::new( "TestNet".to_string(), - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap(), + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], 50051, "0.0.0.0".to_string(), None, vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], - false, 0, 0, + false, + false, + LocationMfaMode::default(), + ServiceLocationMode::default(), ) .save(&pool) .await @@ -74,10 +91,21 @@ async fn test_gateway(pool: PgPool) { .save(&pool) .await .unwrap(); - let (events_tx, _events_rx) = broadcast::channel::(16); + let client_state = Arc::new(Mutex::new(ClientMap::new())); + let (events_tx, _events_rx) = broadcast::channel::(16); let (mail_tx, _mail_rx) = unbounded_channel::(); + let (grpc_event_tx, _grpc_event_rx) = unbounded_channel::(); - let mut gateway_handler = GatewayHandler::new(gateway, None, pool, events_tx, mail_tx).unwrap(); + let mut gateway_handler = GatewayHandler::new( + gateway, + None, + pool, + client_state, + events_tx, + mail_tx, + grpc_event_tx, + ) + .unwrap(); let handle = tokio::spawn(async move { gateway_handler.handle_connection().await; }); diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index b919afcd44..82525af4ae 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -28,7 +28,7 @@ use tower::service_fn; use crate::common::{init_config, initialize_users}; -// pub mod mock_gateway; +pub mod mock_gateway; pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, From ff62044a209a71d290f1af48a03cec23a4ed3aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 3 Dec 2025 14:03:58 +0100 Subject: [PATCH 08/17] Do not extract metadata --- .../defguard_core/src/grpc/gateway/handler.rs | 59 ++++++++----------- crates/defguard_core/src/version.rs | 8 +-- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index bc32135229..eb2c81b58a 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -90,17 +90,7 @@ impl GatewayHandler { }) } - fn get_network_id(metadata: &MetadataMap) -> Result { - match Self::get_network_id_from_metadata(metadata) { - Some(m) => Ok(m), - None => Err(Status::new( - Code::Internal, - "Network ID was not found in metadata", - )), - } - } - - // parse network id from gateway request metadata from intercepted information from JWT token + // Parse network ID from Gateway request metadata from intercepted information from JWT token. fn get_network_id_from_metadata(metadata: &MetadataMap) -> Option { if let Some(ascii_value) = metadata.get("gateway_network_id") { if let Ok(slice) = ascii_value.clone().to_str() { @@ -112,30 +102,28 @@ impl GatewayHandler { None } - // extract gateway hostname from request headers - fn get_gateway_hostname(metadata: &MetadataMap) -> Result { + // Extract Gateway hostname from request headers. + fn get_gateway_hostname(metadata: &MetadataMap) -> Option { match metadata.get("hostname") { Some(ascii_value) => { - let hostname = ascii_value.to_str().map_err(|_| { - Status::new( - Code::Internal, - "Failed to parse gateway hostname from request metadata", - ) - })?; - Ok(hostname.into()) + let Ok(hostname) = ascii_value.to_str() else { + error!("Failed to parse Gateway hostname from request metadata"); + return None; + }; + Some(hostname.into()) + } + None => { + error!("Gateway hostname not found in request metadata"); + None } - None => Err(Status::new( - Code::Internal, - "Gateway hostname not found in request metadata", - )), } } /// Utility function extracting metadata fields during gRPC communication. - fn extract_metadata(metadata: &MetadataMap) -> Result { + fn extract_metadata(metadata: &MetadataMap) -> Option { let (version, _info) = version_info_from_metadata(metadata); - Ok(GatewayMetadata { - network_id: Self::get_network_id(metadata)?, + Some(GatewayMetadata { + network_id: 0, // FIXME: not needed; was Self::get_network_id_from_metadata(metadata)?, hostname: Self::get_gateway_hostname(metadata)?, version, }) @@ -364,21 +352,20 @@ impl GatewayHandler { let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { Ok(response) => response, Err(err) => { - error!("Failed to connect to gateway {uri}, retrying: {err}"); + error!("Failed to connect to Gateway {uri}, retrying: {err}"); sleep(TEN_SECS).await; continue; } }; info!("Connected to Defguard Gateway {uri}"); - let Ok(GatewayMetadata { - network_id, - hostname, - .. - // info, - }) = Self::extract_metadata(response.metadata()) else { - continue; - }; + // Metadata isn't needed in reversed communication. TODO: remove, but only check version. + // let Some(GatewayMetadata { + // hostname, + // }) = Self::extract_metadata(response.metadata()) else { + // error!("Failed to extract metadata"); + // continue; + // }; let mut resp_stream = response.into_inner(); let mut config_sent = false; diff --git a/crates/defguard_core/src/version.rs b/crates/defguard_core/src/version.rs index 849c232337..16043976ff 100644 --- a/crates/defguard_core/src/version.rs +++ b/crates/defguard_core/src/version.rs @@ -10,7 +10,7 @@ use serde::Serialize; use tonic::{Status, service::Interceptor}; const MIN_PROXY_VERSION: Version = Version::new(1, 6, 0); -pub const MIN_GATEWAY_VERSION: Version = Version::new(1, 5, 0); +pub const MIN_GATEWAY_VERSION: Version = Version::new(1, 6, 0); static OUTDATED_COMPONENT_LIFETIME: TimeDelta = TimeDelta::hours(1); /// Checks if Defguard Proxy version meets minimum version requirements. @@ -110,7 +110,7 @@ impl Interceptor for GatewayVersionInterceptor { } } -#[derive(Debug, Default, Clone, Serialize)] +#[derive(Default, Clone, Serialize)] pub struct IncompatibleComponents { pub gateways: HashSet, pub proxy: Option, @@ -204,7 +204,7 @@ impl IncompatibleComponents { } } -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Serialize)] pub struct IncompatibleGatewayData { pub version: Option, pub hostname: Option, @@ -261,7 +261,7 @@ impl IncompatibleGatewayData { } } -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Serialize)] pub struct IncompatibleProxyData { pub version: Option, created: NaiveDateTime, From 41681df5fcc73637765730d758fe316566d60d21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 9 Dec 2025 10:47:59 +0100 Subject: [PATCH 09/17] Add version to gateway gRPC --- Cargo.lock | 48 +++++++++---------- .../defguard_core/src/grpc/gateway/handler.rs | 9 ++-- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96317a4a62..1dda62d884 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -404,9 +404,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] name = "base64urlsafedata" @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.48" +version = "1.2.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" dependencies = [ "find-msvc-tools", "jobserver", @@ -1954,9 +1954,9 @@ dependencies = [ [[package]] name = "git2" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2deb07a133b1520dc1a5690e9bd08950108873d7ed5de38dcc74d3b5ebffa110" +checksum = "3e2b37e2f62729cdada11f0e6b3b6fe383c69c29fc619e391223e12856af308c" dependencies = [ "bitflags 2.10.0", "libc", @@ -2300,9 +2300,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e9a2a24dc5c6821e71a7030e1e14b7b632acac55c40e9d2e082c621261bb56" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" dependencies = [ "base64 0.22.1", "bytes", @@ -2722,9 +2722,9 @@ checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libgit2-sys" -version = "0.18.2+1.9.1" +version = "0.18.3+1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c42fe03df2bd3c53a3a9c7317ad91d80c81cd1fb0caec8d7cc4cd2bfa10c222" +checksum = "c9b3acc4b91781bb0b3386669d325163746af5f6e4f73e6d2d630e09a35f3487" dependencies = [ "cc", "libc", @@ -2761,9 +2761,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" +checksum = "8b484ba8d4f775eeca644c452a56650e544bf7e617f1d170fe7298122ead5222" dependencies = [ "zlib-rs", ] @@ -2933,9 +2933,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "wasi", @@ -4222,9 +4222,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.24" +version = "0.12.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a" dependencies = [ "base64 0.22.1", "bytes", @@ -4882,9 +4882,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simple_asn1" @@ -5630,9 +5630,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.23.7" +version = "0.23.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" +checksum = "5d7cbc3b4b49633d57a0509303158ca50de80ae32c265093b24c414705807832" dependencies = [ "indexmap 2.12.1", "toml_datetime", @@ -5766,9 +5766,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags 2.10.0", "bytes", @@ -6914,9 +6914,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" +checksum = "36134c44663532e6519d7a6dfdbbe06f6f8192bde8ae9ed076e9b213f0e31df7" [[package]] name = "zopfli" diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index eb2c81b58a..8d6bb6b8ab 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -8,10 +8,10 @@ use std::{ }; use chrono::{TimeDelta, Utc}; -use defguard_common::{auth::claims::Claims, db::Id}; +use defguard_common::{VERSION, auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; -use defguard_version::version_info_from_metadata; +use defguard_version::{client::ClientVersionInterceptor, version_info_from_metadata}; use semver::Version; use sqlx::PgPool; use tokio::{ @@ -347,7 +347,10 @@ impl GatewayHandler { )); debug!("Connecting to Gateway {uri}"); - let mut client = gateway_client::GatewayClient::new(channel); + let interceptor = ClientVersionInterceptor::new( + Version::parse(VERSION).expect("failed to parse self version"), + ); + let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); let (tx, rx) = mpsc::unbounded_channel(); let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { Ok(response) => response, From 3b1f0973ca6ec80cb26c11f48ec2af72efaf38d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 9 Dec 2025 13:07:41 +0100 Subject: [PATCH 10/17] Cleanup --- crates/defguard_core/src/grpc/client_mfa.rs | 4 +- .../defguard_core/src/grpc/gateway/handler.rs | 67 +++-------- crates/defguard_core/src/grpc/gateway/mod.rs | 106 ++---------------- .../tests/integration/grpc/common/mod.rs | 5 - 4 files changed, 28 insertions(+), 154 deletions(-) diff --git a/crates/defguard_core/src/grpc/client_mfa.rs b/crates/defguard_core/src/grpc/client_mfa.rs index f688a41a48..90abe62f47 100644 --- a/crates/defguard_core/src/grpc/client_mfa.rs +++ b/crates/defguard_core/src/grpc/client_mfa.rs @@ -484,7 +484,7 @@ impl ClientMfaServer { } MfaMethod::Totp => { let code = if let Some(code) = request.code { - code.to_string() + code.clone() } else { error!("TOTP code not provided in request"); self.emit_event(BidiStreamEvent { @@ -518,7 +518,7 @@ impl ClientMfaServer { } MfaMethod::Email => { let code = if let Some(code) = request.code { - code.to_string() + code.clone() } else { error!("Email MFA code not provided in request"); self.emit_event(BidiStreamEvent { diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 8d6bb6b8ab..8968386312 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -11,7 +11,7 @@ use chrono::{TimeDelta, Utc}; use defguard_common::{VERSION, auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; -use defguard_version::{client::ClientVersionInterceptor, version_info_from_metadata}; +use defguard_version::client::ClientVersionInterceptor; use semver::Version; use sqlx::PgPool; use tokio::{ @@ -50,14 +50,6 @@ pub(crate) struct GatewayHandler { grpc_event_tx: UnboundedSender, } -/// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. -struct GatewayMetadata { - network_id: Id, - hostname: String, - version: Version, - // info: String, -} - impl GatewayHandler { pub(crate) fn new( gateway: Gateway, @@ -68,7 +60,7 @@ impl GatewayHandler { mail_tx: UnboundedSender, grpc_event_tx: UnboundedSender, ) -> Result { - let endpoint = Endpoint::from_shared(gateway.url.to_string())? + let endpoint = Endpoint::from_shared(gateway.url.clone())? .http2_keep_alive_interval(TEN_SECS) .tcp_keepalive(Some(TEN_SECS)) .keep_alive_while_idle(true); @@ -104,31 +96,18 @@ impl GatewayHandler { // Extract Gateway hostname from request headers. fn get_gateway_hostname(metadata: &MetadataMap) -> Option { - match metadata.get("hostname") { - Some(ascii_value) => { - let Ok(hostname) = ascii_value.to_str() else { - error!("Failed to parse Gateway hostname from request metadata"); - return None; - }; - Some(hostname.into()) - } - None => { - error!("Gateway hostname not found in request metadata"); - None - } + if let Some(ascii_value) = metadata.get("hostname") { + let Ok(hostname) = ascii_value.to_str() else { + error!("Failed to parse Gateway hostname from request metadata"); + return None; + }; + Some(hostname.into()) + } else { + error!("Gateway hostname not found in request metadata"); + None } } - /// Utility function extracting metadata fields during gRPC communication. - fn extract_metadata(metadata: &MetadataMap) -> Option { - let (version, _info) = version_info_from_metadata(metadata); - Some(GatewayMetadata { - network_id: 0, // FIXME: not needed; was Self::get_network_id_from_metadata(metadata)?, - hostname: Self::get_gateway_hostname(metadata)?, - version, - }) - } - /// Send network and VPN configuration to Gateway. async fn send_configuration( &self, @@ -179,7 +158,7 @@ impl GatewayHandler { let maybe_firewall_config = network - .try_get_firewall_config(&mut *conn) + .try_get_firewall_config(&mut conn) .await .map_err(|err| { error!("Failed to generate firewall config for network {network_id}: {err}"); @@ -255,7 +234,7 @@ impl GatewayHandler { "{} disconnected. Email notification not sent.", self.gateway ); - }; + } } /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors @@ -360,15 +339,7 @@ impl GatewayHandler { continue; } }; - info!("Connected to Defguard Gateway {uri}"); - // Metadata isn't needed in reversed communication. TODO: remove, but only check version. - // let Some(GatewayMetadata { - // hostname, - // }) = Self::extract_metadata(response.metadata()) else { - // error!("Failed to extract metadata"); - // continue; - // }; let mut resp_stream = response.into_inner(); let mut config_sent = false; @@ -431,20 +402,19 @@ impl GatewayHandler { .gateway .touch_connected(&self.pool, config_request.hostname) .await; - let mut guh = super::GatewayUpdatesHandler::new( + let mut updates_handler = super::GatewayUpdatesHandler::new( self.gateway.network_id, network, self.gateway .hostname - .as_ref() - .cloned() + .clone() .unwrap_or_default() .clone(), self.events_tx.subscribe(), tx.clone(), ); tokio::spawn(async move { - guh.run().await; + updates_handler.run().await; }); } Err(err) => { @@ -548,8 +518,7 @@ impl GatewayHandler { &self .gateway .hostname - .as_ref() - .cloned() + .clone() .unwrap_or_default(), &public_key, &device, @@ -617,7 +586,7 @@ impl GatewayHandler { debug!("WireGuard peer stats: {stats:?}"); } None => (), - }; + } } Err(err) => { error!("Disconnected from Gateway at {uri}, error: {err}"); diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index afcdc81fb3..b1e48deca5 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,31 +1,23 @@ -use std::{ - net::IpAddr, - sync::{Arc, Mutex}, -}; +use std::net::IpAddr; use chrono::{DateTime, Utc}; -use client_state::ClientMap; use defguard_common::db::{Id, NoId}; -use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use sqlx::PgPool; -use thiserror::Error; use tokio::sync::{ broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{UnboundedSender, error::SendError}, + mpsc::UnboundedSender, }; use tonic::{Code, Status}; -use self::map::GatewayMap; use crate::{ db::{ GatewayEvent, models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, }, - events::{GrpcEvent, GrpcRequestContext}, + events::GrpcRequestContext, }; pub mod client_state; @@ -52,36 +44,12 @@ pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { /// /// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { - debug!("Sending {} wireguard events", events.len()); + debug!("Sending {} WireGuard events", events.len()); for event in events { send_wireguard_event(event, wg_tx); } } -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Error)] -pub enum GatewayServerError { - #[error("Failed to acquire lock on VPN client state map")] - ClientStateMutexError, - #[error("gRPC event channel error: {0}")] - GrpcEventChannelError(#[from] SendError), -} - -impl From for Status { - fn from(value: GatewayServerError) -> Self { - Self::new(Code::Internal, value.to_string()) - } -} - -pub struct GatewayServer { - pool: PgPool, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, -} - fn gen_config( network: &WireguardNetwork, peers: Vec, @@ -199,7 +167,8 @@ impl GatewayUpdatesHandler { Some(network_info) => { if self.network.mfa_enabled() && !network_info.is_authorized { debug!( - "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", + "Created WireGuard device {} is not authorized to connect to \ + MFA enabled location {}", device.device.name, self.network.name ); continue; @@ -234,7 +203,8 @@ impl GatewayUpdatesHandler { Some(network_info) => { if self.network.mfa_enabled() && !network_info.is_authorized { debug!( - "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", + "Modified WireGuard device {} is not authorized to connect to \ + MFA enabled location {}", device.device.name, self.network.name ); continue; @@ -457,63 +427,3 @@ impl GatewayUpdatesHandler { Ok(()) } } - -// #[tonic::async_trait] -// impl gateway_service_server::GatewayService for GatewayServer { -// type UpdatesStream = GatewayUpdatesStream; -// -// async fn updates(&self, request: Request<()>) -> Result, Status> { -// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. -// // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, -// // version = version.to_string(), info); -// // let _guard = span.enter(); - -// let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) -// .await -// .map_err(|_| { -// error!("Failed to fetch network {network_id} from the database"); -// Status::new( -// Code::Internal, -// format!("Failed to retrieve network {network_id} from the database"), -// ) -// })? -// else { -// return Err(Status::new( -// Code::Internal, -// format!("Network with id {network_id} not found"), -// )); -// }; - -// info!("New client connected to updates stream: {hostname}, network {network}",); - -// let (tx, rx) = mpsc::channel(4); -// let events_rx = self.wireguard_tx.subscribe(); -// let mut state = self.gateway_state.lock().unwrap(); -// state -// .connect_gateway(network_id, &hostname, &self.pool) -// .map_err(|err| { -// error!("Failed to connect gateway on network {network_id}: {err}"); -// Status::new( -// Code::Internal, -// format!("Failed to connect gateway on network {network_id}"), -// ) -// })?; - -// // clone here before moving into a closure -// let gateway_hostname = hostname.clone(); -// let handle = tokio::spawn(async move { -// let mut update_handler = -// GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); -// update_handler.run().await; -// }); - -// Ok(Response::new(GatewayUpdatesStream::new( -// handle, -// rx, -// network_id, -// hostname, -// Arc::clone(&self.gateway_state), -// self.pool.clone(), -// ))) -// } -// } diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 82525af4ae..d4ca1d0b1c 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -156,13 +156,8 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { server, pool.clone(), worker_state, - // gateway_state.clone(), - // client_state.clone(), - // wg_tx.clone(), mail_tx, failed_logins, - // grpc_event_tx, - // Default::default(), ) .await .unwrap(); From 9e36670758fbce5f8807cd460e3bb23e0e20aad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 16 Dec 2025 14:45:16 +0100 Subject: [PATCH 11/17] Re-organise --- Cargo.lock | 44 +++--- crates/defguard/src/main.rs | 4 +- .../defguard_core/src/db/models/wireguard.rs | 4 +- crates/defguard_core/src/grpc/gateway/mod.rs | 126 +++++++++++++++++- crates/defguard_core/src/grpc/mod.rs | 110 +-------------- .../defguard_core/src/handlers/wireguard.rs | 90 +++++-------- crates/defguard_version/src/lib.rs | 2 +- 7 files changed, 181 insertions(+), 199 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1dda62d884..a401d595c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -410,9 +410,9 @@ checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] name = "base64urlsafedata" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "215ee31f8a88f588c349ce2d20108b2ed96089b96b9c2b03775dc35dd72938e8" +checksum = "42f7f6be94fa637132933fd0a68b9140bcb60e3d46164cb68e82a2bb8d102b3a" dependencies = [ "base64 0.21.7", "pastey", @@ -2396,9 +2396,9 @@ checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ "icu_collections", "icu_locale_core", @@ -2410,9 +2410,9 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" @@ -2761,9 +2761,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b484ba8d4f775eeca644c452a56650e544bf7e617f1d170fe7298122ead5222" +checksum = "15413ef615ad868d4d65dce091cb233b229419c7c0c4bcaa746c0901c49ff39c" dependencies = [ "zlib-rs", ] @@ -6285,9 +6285,9 @@ dependencies = [ [[package]] name = "webauthn-attestation-ca" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f77a2892ec44032e6c48dad9aad1b05fada09c346ada11d8d32db119b4b4f205" +checksum = "fafcf13f7dc1fb292ed4aea22cdd3757c285d7559e9748950ee390249da4da6b" dependencies = [ "base64urlsafedata", "openssl", @@ -6299,9 +6299,9 @@ dependencies = [ [[package]] name = "webauthn-authenticator-rs" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f8fe3811c8d6c6830d263452670a608fd4dcdfc481349bd4d1e6a46d6c7a0f" +checksum = "78b41ed08aba475a969094226ae0691a286686210ae497bb2c5d0ed722d8d526" dependencies = [ "async-stream", "async-trait", @@ -6332,9 +6332,9 @@ dependencies = [ [[package]] name = "webauthn-rs" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7c3a2f9c8bddd524e47bbd427bcf3a28aa074de55d74470b42a91a41937b8e" +checksum = "1b24d082d3360258fefb6ffe56123beef7d6868c765c779f97b7a2fcf06727f8" dependencies = [ "base64urlsafedata", "serde", @@ -6346,9 +6346,9 @@ dependencies = [ [[package]] name = "webauthn-rs-core" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19f1d80f3146382529fe70a3ab5d0feb2413a015204ed7843f9377cd39357fc4" +checksum = "15784340a24c170ce60567282fb956a0938742dbfbf9eff5df793a686a009b8b" dependencies = [ "base64 0.21.7", "base64urlsafedata", @@ -6357,8 +6357,8 @@ dependencies = [ "nom 7.1.3", "openssl", "openssl-sys", - "rand 0.8.5", - "rand_chacha 0.3.1", + "rand 0.9.2", + "rand_chacha 0.9.0", "serde", "serde_cbor_2 0.13.0", "serde_json", @@ -6373,9 +6373,9 @@ dependencies = [ [[package]] name = "webauthn-rs-proto" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e786894f89facb9aaf1c5f6559670236723c98382e045521c76f3d5ca5047bd" +checksum = "16a1fb2580ce73baa42d3011a24de2ceab0d428de1879ece06e02e8c416e497c" dependencies = [ "base64 0.21.7", "base64urlsafedata", @@ -6914,9 +6914,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36134c44663532e6519d7a6dfdbbe06f6f8192bde8ae9ed076e9b213f0e31df7" +checksum = "51f936044d677be1a1168fae1d03b583a285a5dd9d8cbf7b24c23aa1fc775235" [[package]] name = "zopfli" diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 11f25325a4..070c652bad 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -23,8 +23,8 @@ use defguard_core::{ events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, grpc::{ WorkerState, - gateway::{client_state::ClientMap, map::GatewayMap}, - run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, + gateway::{client_state::ClientMap, map::GatewayMap, run_grpc_gateway_stream}, + run_grpc_bidi_stream, run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 32c4a4e4fb..7ab886b6bc 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -41,7 +41,7 @@ use super::{ }; use crate::{ enterprise::{firewall::FirewallError, is_enterprise_enabled}, - grpc::gateway::{send_multiple_wireguard_events, state::GatewayState}, + grpc::gateway::send_multiple_wireguard_events, wg_config::ImportedDevice, }; @@ -1449,7 +1449,7 @@ pub struct WireguardNetworkInfo { #[serde(flatten)] pub network: WireguardNetwork, pub connected: bool, - pub gateways: Vec, + // pub gateways: Vec, pub allowed_groups: Vec, } diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index b1e48deca5..9b9809dc8f 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,23 +1,38 @@ -use std::net::IpAddr; +use std::{ + collections::HashMap, + net::IpAddr, + sync::{Arc, Mutex}, +}; use chrono::{DateTime, Utc}; -use defguard_common::db::{Id, NoId}; +use defguard_common::{ + config::server_config, + db::{ChangeNotification, Id, NoId, TriggerOperation}, +}; +use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use tokio::sync::{ - broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::UnboundedSender, +use sqlx::{PgPool, postgres::PgListener}; +use tokio::{ + sync::{ + broadcast::{Receiver as BroadcastReceiver, Sender}, + mpsc::UnboundedSender, + }, + task::{AbortHandle, JoinSet}, }; use tonic::{Code, Status}; use crate::{ db::{ GatewayEvent, - models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, + models::{ + gateway::Gateway, wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats, + }, }, - events::GrpcRequestContext, + events::{GrpcEvent, GrpcRequestContext}, + grpc::gateway::{client_state::ClientMap, handler::GatewayHandler}, }; pub mod client_state; @@ -87,6 +102,103 @@ impl WireguardPeerStats { } } +const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; + +/// Bi-directional gRPC stream for comminication with Defguard Gateway. +pub async fn run_grpc_gateway_stream( + pool: PgPool, + client_state: Arc>, + events_tx: Sender, + mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, +) -> Result<(), anyhow::Error> { + let config = server_config(); + let tls_config = config.grpc_client_tls_config()?; + + let mut abort_handles = HashMap::new(); + + let mut tasks = JoinSet::new(); + // Helper closure to launch `GatewayHandler`. + let mut launch_gateway_handler = + |gateway: Gateway| -> Result { + let mut gateway_handler = GatewayHandler::new( + gateway, + tls_config.clone(), + pool.clone(), + Arc::clone(&client_state), + events_tx.clone(), + mail_tx.clone(), + grpc_event_tx.clone(), + )?; + let abort_handle = tasks.spawn(async move { + gateway_handler.handle_connection().await; + }); + Ok(abort_handle) + }; + + for gateway in Gateway::all(&pool).await? { + let id = gateway.id; + let abort_handle = launch_gateway_handler(gateway)?; + abort_handles.insert(id, abort_handle); + } + + // Observe gateway URL changes. + let mut listener = PgListener::connect_with(&pool).await?; + listener.listen(GATEWAY_TABLE_TRIGGER).await?; + while let Ok(notification) = listener.recv().await { + let payload = notification.payload(); + match serde_json::from_str::>>(payload) { + Ok(gateway_notification) => match gateway_notification.operation { + TriggerOperation::Insert => { + if let Some(new) = gateway_notification.new { + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } + } + TriggerOperation::Update => { + if let (Some(old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + { + if old.url == new.url { + debug!( + "Gateway URL didn't change. Keeping the current gateway handler" + ); + } else if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!("Aborting connection to {old}, it has changed in the database"); + abort_handle.abort(); + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + TriggerOperation::Delete => { + if let Some(old) = gateway_notification.old { + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to {old}, it has disappeard from the database" + ); + abort_handle.abort(); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + }, + Err(err) => error!("Failed to de-serialize database notification object: {err}"), + } + } + + while let Some(Ok(_result)) = tasks.join_next().await { + debug!("Gateway gRPC task has ended"); + } + + Ok(()) +} + /// Helper struct for handling gateway events. struct GatewayUpdatesHandler { network_id: Id, diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 8b9e126eef..fcea09feb9 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,7 +10,7 @@ use axum::http::Uri; use defguard_common::{ VERSION, auth::claims::ClaimsType, - db::{ChangeNotification, Id, TriggerOperation, models::Settings}, + db::{Id, models::Settings}, }; use defguard_mail::Mail; use defguard_version::{ @@ -20,13 +20,12 @@ use defguard_version::{ use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; use serde::Serialize; -use sqlx::{PgPool, postgres::PgListener}; +use sqlx::PgPool; use tokio::{ sync::{ broadcast::Sender, mpsc::{self, UnboundedSender}, }, - task::{AbortHandle, JoinSet}, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -39,17 +38,13 @@ use tonic::{ use self::{ auth::AuthServer, client_mfa::ClientMfaServer, enrollment::EnrollmentServer, - gateway::handler::GatewayHandler, interceptor::JwtInterceptor, - password_reset::PasswordResetServer, worker::WorkerServer, + interceptor::JwtInterceptor, password_reset::PasswordResetServer, worker::WorkerServer, }; use crate::{ auth::failed_login::FailedLoginMap, db::{ AppEvent, GatewayEvent, - models::{ - enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, - gateway::Gateway, - }, + models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, }, enterprise::{ db::models::{ @@ -551,103 +546,6 @@ async fn handle_proxy_message_loop( Ok(()) } -const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; - -/// Bi-directional gRPC stream for comminication with Defguard Gateway. -pub async fn run_grpc_gateway_stream( - pool: PgPool, - client_state: Arc>, - events_tx: Sender, - mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, -) -> Result<(), anyhow::Error> { - let config = server_config(); - let tls_config = config.grpc_client_tls_config()?; - - let mut abort_handles = HashMap::new(); - - let mut tasks = JoinSet::new(); - // Helper closure to launch `GatewayHandler`. - let mut launch_gateway_handler = - |gateway: Gateway| -> Result { - let mut gateway_handler = GatewayHandler::new( - gateway, - tls_config.clone(), - pool.clone(), - Arc::clone(&client_state), - events_tx.clone(), - mail_tx.clone(), - grpc_event_tx.clone(), - )?; - let abort_handle = tasks.spawn(async move { - gateway_handler.handle_connection().await; - }); - Ok(abort_handle) - }; - - for gateway in Gateway::all(&pool).await? { - let id = gateway.id; - let abort_handle = launch_gateway_handler(gateway)?; - abort_handles.insert(id, abort_handle); - } - - // Observe gateway URL changes. - let mut listener = PgListener::connect_with(&pool).await?; - listener.listen(GATEWAY_TABLE_TRIGGER).await?; - while let Ok(notification) = listener.recv().await { - let payload = notification.payload(); - match serde_json::from_str::>>(payload) { - Ok(gateway_notification) => match gateway_notification.operation { - TriggerOperation::Insert => { - if let Some(new) = gateway_notification.new { - let id = new.id; - let abort_handle = launch_gateway_handler(new)?; - abort_handles.insert(id, abort_handle); - } - } - TriggerOperation::Update => { - if let (Some(old), Some(new)) = - (gateway_notification.old, gateway_notification.new) - { - if old.url == new.url { - debug!( - "Gateway URL didn't change. Keeping the current gateway handler" - ); - } else if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!("Aborting connection to {old}, it has changed in the database"); - abort_handle.abort(); - let id = new.id; - let abort_handle = launch_gateway_handler(new)?; - abort_handles.insert(id, abort_handle); - } else { - warn!("Cannot find {old} on the list of connected gateways"); - } - } - } - TriggerOperation::Delete => { - if let Some(old) = gateway_notification.old { - if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!( - "Aborting connection to {old}, it has disappeard from the database" - ); - abort_handle.abort(); - } else { - warn!("Cannot find {old} on the list of connected gateways"); - } - } - } - }, - Err(err) => error!("Failed to de-serialize database notification object: {err}"), - } - } - - while let Some(Ok(_result)) = tasks.join_next().await { - debug!("Gateway gRPC task has ended"); - } - - Ok(()) -} - /// Bi-directional gRPC stream for communication with Defguard Proxy. #[instrument(skip_all)] pub async fn run_grpc_bidi_stream( diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 9410134bdd..091ebd439f 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -1,12 +1,6 @@ -use std::{ - collections::HashSet, - net::IpAddr, - str::FromStr, - sync::{Arc, Mutex}, -}; +use std::{collections::HashSet, net::IpAddr, str::FromStr}; use axum::{ - Extension, extract::{Json, Path, Query, State}, http::StatusCode, }; @@ -17,7 +11,6 @@ use ipnetwork::IpNetwork; use serde_json::{Value, json}; use sqlx::PgPool; use utoipa::ToSchema; -use uuid::Uuid; use super::{ApiResponse, ApiResult, WebError, device_for_admin_or_self, user_for_admin_or_self}; use crate::{ @@ -44,7 +37,6 @@ use crate::{ limits::update_counts, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::gateway::map::GatewayMap, handlers::mail::send_new_device_added_email, server_config, wg_config::{ImportedDevice, parse_wireguard_config}, @@ -445,11 +437,7 @@ pub(crate) async fn delete_network( ("api_token" = []) ) )] -pub(crate) async fn list_networks( - _role: AdminRole, - State(appstate): State, - Extension(gateway_state): Extension>>, -) -> ApiResult { +pub(crate) async fn list_networks(_role: AdminRole, State(appstate): State) -> ApiResult { debug!("Listing WireGuard networks"); let mut network_info = Vec::new(); let networks = WireguardNetwork::all(&appstate.pool).await?; @@ -458,13 +446,10 @@ pub(crate) async fn list_networks( let network_id = network.id; let allowed_groups = network.fetch_allowed_groups(&appstate.pool).await?; { - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); network_info.push(WireguardNetworkInfo { network, - connected: gateway_state.connected(network_id), - gateways: gateway_state.get_network_gateway_status(network_id), + connected: false, // FIXME: was: gateway_state.connected(network_id), + // gateways: gateway_state.get_network_gateway_status(network_id), allowed_groups, }); } @@ -504,20 +489,16 @@ pub(crate) async fn network_details( Path(network_id): Path, _role: AdminRole, State(appstate): State, - Extension(gateway_state): Extension>>, ) -> ApiResult { debug!("Displaying network details for network {network_id}"); let network = WireguardNetwork::find_by_id(&appstate.pool, network_id).await?; let response = match network { Some(network) => { let allowed_groups = network.fetch_allowed_groups(&appstate.pool).await?; - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); let network_info = WireguardNetworkInfo { network, - connected: gateway_state.connected(network_id), - gateways: gateway_state.get_network_gateway_status(network_id), + connected: false, // FIXME: was: gateway_state.connected(network_id), + // gateways: gateway_state.get_network_gateway_status(network_id), allowed_groups, }; ApiResponse { @@ -539,56 +520,47 @@ pub(crate) async fn network_details( /// /// # Returns /// Returns `Vec` for requested network -pub(crate) async fn gateway_status( - Path(network_id): Path, - _role: AdminRole, - Extension(gateway_state): Extension>>, -) -> ApiResult { +pub(crate) async fn gateway_status(Path(network_id): Path, _role: AdminRole) -> ApiResult { debug!("Displaying gateway status for network {network_id}"); - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); + + // TODO: fetch gateways from db + debug!("Displayed gateway status for network {network_id}"); - Ok(ApiResponse { - json: json!(gateway_state.get_network_gateway_status(network_id)), - status: StatusCode::OK, - }) + // Ok(ApiResponse { + // json: json!(gateway_state.get_network_gateway_status(network_id)), + // status: StatusCode::OK, + // }) + Ok(ApiResponse::default()) } /// Returns state of gateways for all networks /// /// Returns current state of gateways as `HashMap>` where key is an id of `WireguardNetwork` -pub(crate) async fn all_gateways_status( - _role: AdminRole, - Extension(gateway_state): Extension>>, -) -> ApiResult { +pub(crate) async fn all_gateways_status(_role: AdminRole) -> ApiResult { debug!("Displaying gateways status for all networks."); - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); - let flattened = (*gateway_state).as_flattened(); - Ok(ApiResponse { - json: json!(flattened), - status: StatusCode::OK, - }) + + // let flattened = (*gateway_state).as_flattened(); + // Ok(ApiResponse { + // json: json!(flattened), + // status: StatusCode::OK, + // }) + Ok(ApiResponse::default()) } pub(crate) async fn remove_gateway( Path((network_id, gateway_id)): Path<(i64, String)>, _role: AdminRole, - Extension(gateway_state): Extension>>, ) -> ApiResult { debug!("Removing gateway {gateway_id} in network {network_id}"); - let mut gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); - - gateway_state.remove_gateway( - network_id, - Uuid::from_str(&gateway_id) - .map_err(|_| WebError::Http(StatusCode::INTERNAL_SERVER_ERROR))?, - )?; + + // TODO: fetch gateways from db + + // gateway_state.remove_gateway( + // network_id, + // Uuid::from_str(&gateway_id) + // .map_err(|_| WebError::Http(StatusCode::INTERNAL_SERVER_ERROR))?, + // )?; info!("Removed gateway {gateway_id} in network {network_id}"); diff --git a/crates/defguard_version/src/lib.rs b/crates/defguard_version/src/lib.rs index 05f177b24f..8d6881440e 100644 --- a/crates/defguard_version/src/lib.rs +++ b/crates/defguard_version/src/lib.rs @@ -62,7 +62,7 @@ use std::{cmp::Ordering, fmt, str::FromStr}; -use ::tracing::{error, warn}; +use ::tracing::warn; pub use semver::{BuildMetadata, Error as SemverError, Prerelease, Version}; use serde::Serialize; use thiserror::Error; From ce598fb93d1ccf5a1d27c9b1c6d550d8c29435bd Mon Sep 17 00:00:00 2001 From: Jacek Chmielewski Date: Fri, 19 Dec 2025 10:24:17 +0100 Subject: [PATCH 12/17] restore WireguardNetwork::get_peers method --- .../src/db/models/wireguard.rs | 64 ++++++++++++++++++- .../defguard_core/src/grpc/gateway/handler.rs | 11 ++-- crates/defguard_core/src/grpc/gateway/mod.rs | 8 +-- 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 28e245fcc8..ee807fb8af 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -537,8 +537,70 @@ impl WireguardNetwork { Ok(connected_at) } + /// Get a list of all allowed peers + /// + /// Each device is marked as allowed or not allowed in a given network, + /// which enables enforcing peer disconnect in MFA-protected networks. + /// + /// If the location is a service location, only returns peers if enterprise features are enabled. + pub async fn get_peers<'e, E>(&self, executor: E) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { + debug!("Fetching all peers for network {}", self.id); + + if self.should_prevent_service_location_usage() { + warn!( + "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", + self.name + ); + return Ok(Vec::new()); + } + + let rows = query!( + "SELECT d.wireguard_pubkey pubkey, preshared_key, \ + -- TODO possible to not use ARRAY-unnest here? + ARRAY( + SELECT host(ip) + FROM unnest(wnd.wireguard_ips) AS ip + ) \"allowed_ips!: Vec\" \ + FROM wireguard_network_device wnd \ + JOIN device d ON wnd.device_id = d.id \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ + AND d.configured = true \ + AND u.is_active = true \ + ORDER BY d.id ASC", + self.id, + self.mfa_enabled() + ) + .fetch_all(executor) + .await?; + + // keepalive has to be added manually because Postgres + // doesn't support unsigned integers + let result = rows + .into_iter() + .map(|row| Peer { + pubkey: row.pubkey, + allowed_ips: row.allowed_ips, + // Don't send preshared key if MFA is not enabled, it can't be used and may + // cause issues with clients connecting if they expect no preshared key + // e.g. when you disable MFA on a location + preshared_key: if self.mfa_enabled() { + row.preshared_key + } else { + None + }, + keepalive_interval: Some(self.keepalive_interval as u32), + }) + .collect(); + + Ok(result) + } + /// Update `connected_at` to the current time and save it to the database. - pub(crate) async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + pub async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 7273f5b658..248267d813 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -7,8 +7,8 @@ use std::{ }, }; -use chrono::{TimeDelta, Utc}; -use defguard_common::{VERSION, auth::claims::Claims, db::{Id, models::wireguard_peer_stats::WireguardPeerStats}}; +use chrono::{DateTime, TimeDelta, Utc}; +use defguard_common::{VERSION, auth::claims::Claims, db::{Id, NoId, models::{WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}}}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, PeerStats, core_request, core_response, gateway_client}; use defguard_version::client::ClientVersionInterceptor; @@ -30,11 +30,8 @@ use tonic::{ use crate::{ ClaimsType, - db::{ - Device, GatewayEvent, User, WireguardNetwork, - models::{gateway::Gateway, wireguard_peer_stats::WireguardPeerStats}, - }, - grpc::{ClientMap, GrpcEvent, TEN_SECS, gateway::GrpcRequestContext}, + db::models::gateway::Gateway, + grpc::{ClientMap, GrpcEvent, TEN_SECS, gateway::{GrpcRequestContext, events::GatewayEvent}}, handlers::mail::send_gateway_disconnected_email, }; diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 94b29ba2b2..3337774c16 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,7 +1,7 @@ use std::{ collections::HashMap, net::IpAddr, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, mpsc::Receiver}, thread::JoinHandle, }; use chrono::{DateTime, TimeDelta, Utc}; @@ -17,7 +17,7 @@ models::{ use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{Configuration, Peer, PeerStats, Update, update}, proxy::{CoreResponse, core_response}, + gateway::{Configuration, ConfigurationRequest, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; use sqlx::{PgPool, postgres::PgListener}; use defguard_version::version_info_from_metadata; @@ -26,11 +26,11 @@ use thiserror::Error; use tokio::{ sync::{ broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::UnboundedSender, + mpsc::{self, UnboundedSender}, }, task::{AbortHandle, JoinSet}, }; -use tonic::{Code, Status, metadata::MetadataMap}; +use tonic::{Request, Response, Code, Status, metadata::MetadataMap}; use crate::{ db::{ From 13e7d52aff04115f65af9831d0fc3db58a8394a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 19 Dec 2025 11:43:56 +0100 Subject: [PATCH 13/17] Let it build --- Cargo.lock | 57 +- crates/defguard/src/main.rs | 13 +- .../src/db/models/wireguard.rs | 90 +- .../defguard_core/src/grpc/gateway/handler.rs | 75 +- crates/defguard_core/src/grpc/gateway/mod.rs | 821 ++++++++++-------- .../defguard_core/src/grpc/gateway/tests.rs | 17 +- crates/defguard_core/src/grpc/mod.rs | 34 +- 7 files changed, 553 insertions(+), 554 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bed4ea9785..8aa34e9b9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -534,9 +534,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" [[package]] name = "byteorder" @@ -791,9 +791,9 @@ dependencies = [ [[package]] name = "cookie_store" -version = "0.21.1" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eac901828f88a5241ee0600950ab981148a18f2f756900ffba1b125ca6a3ef9" +checksum = "3fc4bff745c9b4c7fb1e97b25d13153da2bc7796260141df62378998d070207f" dependencies = [ "cookie", "document-features", @@ -2782,13 +2782,13 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +checksum = "df15f6eac291ed1cf25865b1ee60399f57e7c227e7f51bdbd4c5270396a9ed50" dependencies = [ "bitflags 2.10.0", "libc", - "redox_syscall", + "redox_syscall 0.6.0", ] [[package]] @@ -3495,9 +3495,9 @@ dependencies = [ [[package]] name = "os_info" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c39b5918402d564846d5aba164c09a66cc88d232179dfd3e3c619a25a268392" +checksum = "e4022a17595a00d6a369236fdae483f0de7f0a339960a53118b818238e132224" dependencies = [ "android_system_properties", "log", @@ -3571,7 +3571,7 @@ checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.18", "smallvec", "windows-link", ] @@ -4213,6 +4213,15 @@ dependencies = [ "bitflags 2.10.0", ] +[[package]] +name = "redox_syscall" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec96166dafa0886eb81fe1c0a388bece180fbef2135f97c1e2cf8302e74b43b5" +dependencies = [ + "bitflags 2.10.0", +] + [[package]] name = "ref-cast" version = "1.0.25" @@ -4264,9 +4273,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.25" +version = "0.12.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a" +checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f" dependencies = [ "base64 0.22.1", "bytes", @@ -4477,9 +4486,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" +checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" dependencies = [ "web-time", "zeroize", @@ -5663,18 +5672,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.3" +version = "0.7.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2cdb639ebbc97961c51720f858597f7f24c4fc295327923af55b74c3c724533" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.23.9" +version = "0.23.10+spec-1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d7cbc3b4b49633d57a0509303158ca50de80ae32c265093b24c414705807832" +checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269" dependencies = [ "indexmap 2.12.1", "toml_datetime", @@ -5684,9 +5693,9 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.4" +version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0cbe268d35bdb4bb5a56a2de88d0ad0eb70af5384a99d648cd4b3d04039800e" +checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" dependencies = [ "winnow", ] @@ -5848,9 +5857,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.43" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -5871,9 +5880,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.35" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 2a51b1dd3a..0a7d54c499 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -28,8 +28,10 @@ use defguard_core::{ events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, grpc::{ WorkerState, - gateway::{client_state::ClientMap, map::GatewayMap, run_grpc_gateway_stream}, - run_grpc_bidi_stream, run_grpc_server, + gateway::{ + client_state::ClientMap, events::GatewayEvent, map::GatewayMap, run_grpc_gateway_stream, + }, + run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, @@ -173,13 +175,6 @@ async fn main() -> Result<(), anyhow::Error> { mail_tx.clone(), grpc_event_tx, ) => error!("Gateway gRPC stream returned early: {res:?}"), - res = run_grpc_bidi_stream( - pool.clone(), - wireguard_tx.clone(), - mail_tx.clone(), - bidi_event_tx, - Arc::clone(&incompatible_components), - ), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:?}"), res = run_grpc_server( Arc::clone(&worker_state), pool.clone(), diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index ee807fb8af..82b9076097 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -5,18 +5,6 @@ use std::{ net::{IpAddr, Ipv4Addr}, }; -use crate::{ - auth::claims::{Claims, ClaimsType}, - db::{ - Id, NoId, - models::{ - ModelError, - group::{Group, Permission}, - wireguard_peer_stats::WireguardPeerStats, - }, - }, - types::user_info::UserInfo, -}; use base64::prelude::{BASE64_STANDARD, Engine}; use chrono::{NaiveDateTime, TimeDelta, Utc}; use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; @@ -28,7 +16,7 @@ use sqlx::{ query_scalar, }; use thiserror::Error; -use tracing::{debug, info, warn}; +use tracing::{debug, info}; use utoipa::ToSchema; use x25519_dalek::{PublicKey, StaticSecret}; @@ -36,11 +24,23 @@ use super::{ device::{Device, DeviceError, DeviceType, WireguardNetworkDevice}, user::User, }; +use crate::{ + auth::claims::{Claims, ClaimsType}, + db::{ + Id, NoId, + models::{ + ModelError, + group::{Group, Permission}, + wireguard_peer_stats::WireguardPeerStats, + }, + }, + types::user_info::UserInfo, +}; pub const DEFAULT_KEEPALIVE_INTERVAL: i32 = 25; pub const DEFAULT_DISCONNECT_THRESHOLD: i32 = 300; -// Used in process of importing network from wireguard config +// Used in process of importing network from WireGuard config. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct MappedDevice { pub user_id: Id, @@ -537,68 +537,6 @@ impl WireguardNetwork { Ok(connected_at) } - /// Get a list of all allowed peers - /// - /// Each device is marked as allowed or not allowed in a given network, - /// which enables enforcing peer disconnect in MFA-protected networks. - /// - /// If the location is a service location, only returns peers if enterprise features are enabled. - pub async fn get_peers<'e, E>(&self, executor: E) -> Result, sqlx::Error> - where - E: PgExecutor<'e>, - { - debug!("Fetching all peers for network {}", self.id); - - if self.should_prevent_service_location_usage() { - warn!( - "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", - self.name - ); - return Ok(Vec::new()); - } - - let rows = query!( - "SELECT d.wireguard_pubkey pubkey, preshared_key, \ - -- TODO possible to not use ARRAY-unnest here? - ARRAY( - SELECT host(ip) - FROM unnest(wnd.wireguard_ips) AS ip - ) \"allowed_ips!: Vec\" \ - FROM wireguard_network_device wnd \ - JOIN device d ON wnd.device_id = d.id \ - JOIN \"user\" u ON d.user_id = u.id \ - WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ - AND d.configured = true \ - AND u.is_active = true \ - ORDER BY d.id ASC", - self.id, - self.mfa_enabled() - ) - .fetch_all(executor) - .await?; - - // keepalive has to be added manually because Postgres - // doesn't support unsigned integers - let result = rows - .into_iter() - .map(|row| Peer { - pubkey: row.pubkey, - allowed_ips: row.allowed_ips, - // Don't send preshared key if MFA is not enabled, it can't be used and may - // cause issues with clients connecting if they expect no preshared key - // e.g. when you disable MFA on a location - preshared_key: if self.mfa_enabled() { - row.preshared_key - } else { - None - }, - keepalive_interval: Some(self.keepalive_interval as u32), - }) - .collect(); - - Ok(result) - } - /// Update `connected_at` to the current time and save it to the database. pub async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> where diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 248267d813..42921ff880 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -8,9 +8,18 @@ use std::{ }; use chrono::{DateTime, TimeDelta, Utc}; -use defguard_common::{VERSION, auth::claims::Claims, db::{Id, NoId, models::{WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}}}; +use defguard_common::{ + VERSION, + auth::claims::Claims, + db::{ + Id, NoId, + models::{Device, User, WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, + }, +}; use defguard_mail::Mail; -use defguard_proto::gateway::{CoreResponse, PeerStats, core_request, core_response, gateway_client}; +use defguard_proto::gateway::{ + CoreResponse, PeerStats, core_request, core_response, gateway_client, +}; use defguard_version::client::ClientVersionInterceptor; use semver::Version; use sqlx::PgPool; @@ -31,28 +40,32 @@ use tonic::{ use crate::{ ClaimsType, db::models::gateway::Gateway, - grpc::{ClientMap, GrpcEvent, TEN_SECS, gateway::{GrpcRequestContext, events::GatewayEvent}}, + enterprise::firewall::try_get_location_firewall_config, + grpc::{ + ClientMap, GrpcEvent, TEN_SECS, + gateway::{GrpcRequestContext, events::GatewayEvent, get_peers}, + }, handlers::mail::send_gateway_disconnected_email, }; fn peer_stats_from_proto(stats: PeerStats, network_id: Id, device_id: Id) -> WireguardPeerStats { - let endpoint = match stats.endpoint { - endpoint if endpoint.is_empty() => None, - _ => Some(stats.endpoint), - }; - WireguardPeerStats { - id: NoId, - network: network_id, - endpoint, - device_id, - collected_at: Utc::now().naive_utc(), - upload: stats.upload as i64, - download: stats.download as i64, - latest_handshake: DateTime::from_timestamp(stats.latest_handshake as i64, 0) - .unwrap_or_default() - .naive_utc(), - allowed_ips: Some(stats.allowed_ips), - } + let endpoint = match stats.endpoint { + endpoint if endpoint.is_empty() => None, + _ => Some(stats.endpoint), + }; + WireguardPeerStats { + id: NoId, + network: network_id, + endpoint, + device_id, + collected_at: Utc::now().naive_utc(), + upload: stats.upload as i64, + download: stats.download as i64, + latest_handshake: DateTime::from_timestamp(stats.latest_handshake as i64, 0) + .unwrap_or_default() + .naive_utc(), + allowed_ips: Some(stats.allowed_ips), + } } /// One instance per connected Gateway. @@ -165,7 +178,7 @@ impl GatewayHandler { ); } - let peers = network.get_peers(&self.pool).await.map_err(|error| { + let peers = get_peers(&network, &self.pool).await.map_err(|error| { error!("Failed to fetch peers from the database for network {network_id}: {error}",); Status::new( Code::Internal, @@ -173,17 +186,15 @@ impl GatewayHandler { ) })?; - let maybe_firewall_config = - network - .try_get_firewall_config(&mut conn) - .await - .map_err(|err| { - error!("Failed to generate firewall config for network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to generate firewall config for network: {network_id}"), - ) - })?; + let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn) + .await + .map_err(|err| { + error!("Failed to generate firewall config for network {network_id}: {err}"); + Status::new( + Code::Internal, + format!("Failed to generate firewall config for network: {network_id}"), + ) + })?; let payload = Some(core_response::Payload::Config(super::gen_config( &network, peers, diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 3337774c16..4aaff653be 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,43 +1,46 @@ use std::{ collections::HashMap, net::IpAddr, - sync::{Arc, Mutex, mpsc::Receiver}, thread::JoinHandle, + sync::{Arc, Mutex, mpsc::Receiver}, + thread::JoinHandle, }; -use chrono::{DateTime, TimeDelta, Utc}; +use chrono::{DateTime, Utc}; use defguard_common::{ config::server_config, - db::{ChangeNotification, Id, NoId, TriggerOperation, -models::{ - Device, User, WireguardNetwork, wireguard::ServiceLocationMode, - wireguard_peer_stats::WireguardPeerStats, + db::{ + ChangeNotification, Id, NoId, TriggerOperation, + models::{ + Device, User, WireguardNetwork, wireguard::ServiceLocationMode, + wireguard_peer_stats::WireguardPeerStats, + }, }, - }, }; use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{Configuration, ConfigurationRequest, CoreResponse, Peer, PeerStats, Update, core_response, update}, + gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use sqlx::{PgPool, postgres::PgListener}; use defguard_version::version_info_from_metadata; use semver::Version; +use sqlx::{PgExecutor, PgPool, postgres::PgListener, query}; use thiserror::Error; use tokio::{ sync::{ broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{self, UnboundedSender}, + mpsc::{UnboundedSender, error::SendError}, }, task::{AbortHandle, JoinSet}, }; -use tonic::{Request, Response, Code, Status, metadata::MetadataMap}; +use tonic::{Code, Status, metadata::MetadataMap}; use crate::{ - db::{ - models::{ - gateway::Gateway, - }, - }, enterprise::{firewall::try_get_location_firewall_config, is_enterprise_license_active}, events::{GrpcEvent, GrpcRequestContext}, grpc::gateway::{client_state::ClientMap, events::GatewayEvent, handler::GatewayHandler, map::GatewayMap}, location_management::allowed_peers::get_location_allowed_peers + db::models::gateway::Gateway, + enterprise::is_enterprise_license_active, + events::{GrpcEvent, GrpcRequestContext}, + grpc::gateway::{ + client_state::ClientMap, events::GatewayEvent, handler::GatewayHandler, map::GatewayMap, + }, }; pub mod client_state; @@ -129,6 +132,74 @@ pub fn should_prevent_service_location_usage(location: &WireguardNetwork) -> && !is_enterprise_license_active() } +/// Get a list of all allowed peers +/// +/// Each device is marked as allowed or not allowed in a given network, +/// which enables enforcing peer disconnect in MFA-protected networks. +/// +/// If the location is a service location, only returns peers if enterprise features are enabled. +/// +/// XXX: should be implemented in defguard_core::db::models::wireguard::WireguardNetwork. +pub async fn get_peers<'e, E>( + location: &WireguardNetwork, + executor: E, +) -> Result, sqlx::Error> +where + E: PgExecutor<'e>, +{ + debug!("Fetching all peers for network {}", location.id); + + if should_prevent_service_location_usage(&location) { + warn!( + "Tried to use service location {} with disabled enterprise features. No clients \ + will be allowed to connect.", + location.name + ); + return Ok(Vec::new()); + } + + let rows = query!( + "SELECT d.wireguard_pubkey pubkey, preshared_key, \ + -- TODO possible to not use ARRAY-unnest here? + ARRAY( + SELECT host(ip) + FROM unnest(wnd.wireguard_ips) AS ip + ) \"allowed_ips!: Vec\" \ + FROM wireguard_network_device wnd \ + JOIN device d ON wnd.device_id = d.id \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ + AND d.configured = true \ + AND u.is_active = true \ + ORDER BY d.id ASC", + location.id, + location.mfa_enabled() + ) + .fetch_all(executor) + .await?; + + // keepalive has to be added manually because Postgres + // doesn't support unsigned integers + let result = rows + .into_iter() + .map(|row| Peer { + pubkey: row.pubkey, + allowed_ips: row.allowed_ips, + // Don't send preshared key if MFA is not enabled, it can't be used and may + // cause issues with clients connecting if they expect no preshared key + // e.g. when you disable MFA on a location + preshared_key: if location.mfa_enabled() { + row.preshared_key + } else { + None + }, + keepalive_interval: Some(location.keepalive_interval as u32), + }) + .collect(); + + Ok(result) +} + /// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. struct GatewayMetadata { network_id: Id, @@ -775,362 +846,362 @@ impl GatewayUpdatesStream { } } -impl Stream for GatewayUpdatesStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.rx).poll_recv(cx) - } -} - -impl Drop for GatewayUpdatesStream { - fn drop(&mut self) { - info!("Client disconnected"); - // terminate update task - self.task_handle.abort(); - // update gateway state - // TODO: possibly use a oneshot channel instead - self.gateway_state - .lock() - .unwrap() - .disconnect_gateway(self.network_id, self.gateway_hostname.clone(), &self.pool) - .expect("Unable to disconnect gateway."); - } -} - -#[tonic::async_trait] -impl gateway_service_server::GatewayService for GatewayServer { - type UpdatesStream = GatewayUpdatesStream; - - /// Retrieve stats from gateway and save it to database - async fn stats( - &self, - request: Request>, - ) -> Result, Status> { - let GatewayMetadata { - network_id, - hostname, - .. - } = Self::extract_metadata(request.metadata())?; - let mut stream = request.into_inner(); - let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - loop { - // Wait for a message or update client map at least once a mninute, if no messages are - // received. - let stats_update = tokio::select! { - message = stream.message() => { - match message? { - Some(update) => update, - None => break, // Stream ended - } - } - _ = disconnect_timer.tick() => { - debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ - Updating disconnected VPN clients"); - // fetch location to get current peer disconnect threshold - let location = self.fetch_location_from_db(network_id).await?; - - // perform client state operations in a dedicated block to drop mutex guard - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.get_client_state_guard()?; - - // disconnect inactive clients - client_map.disconnect_inactive_vpn_clients_for_location(&location - )? - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - })?; - }; - continue; - } - }; - - debug!("Received stats message: {stats_update:?}"); - let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { - debug!("Received stats message is empty, skipping."); - continue; - }; - let public_key = peer_stats.public_key.clone(); - - // fetch device from DB - // TODO: fetch only when device has changed and use client state otherwise - let device = match self.fetch_device_from_db(&public_key).await? { - Some(device) => device, - None => { - warn!( - "Received stats update for a device which does not exist: {public_key}, skipping." - ); - continue; - } - }; - - // copy device ID for easier reference later - let device_id = device.id; - - // fetch user and location from DB for activity log - // TODO: cache usernames since they don't change - let user = self.fetch_user_from_db(device.user_id, &public_key).await?; - let location = self.fetch_location_from_db(network_id).await?; - - // convert stats to DB storage format - let stats = protos_into_internal_stats(peer_stats, network_id, device_id); - - // only perform client state update if stats include an endpoint IP - // otherwise a peer was added to the gateway interface - // but has not connected yet - if let Some(endpoint) = &stats.endpoint { - // parse client endpoint IP - let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { - error!("Failed to parse VPN client endpoint: {err}"); - Status::new( - Code::Internal, - format!("Failed to parse VPN client endpoint: {err}"), - ) - })?; - - // perform client state operations in a dedicated block to drop mutex guard - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.get_client_state_guard()?; - - // update connected clients map - match client_map.get_vpn_client(network_id, &public_key) { - Some(client_state) => { - // update connected client state - client_state.update_client_state( - device, - socket_addr, - stats.latest_handshake, - stats.upload, - stats.download, - ); - } - None => { - // don't mark inactive peers as connected - if (Utc::now().naive_utc() - stats.latest_handshake) - < TimeDelta::seconds(location.peer_disconnect_threshold.into()) - { - // mark new VPN client as connected - client_map.connect_vpn_client( - network_id, - &hostname, - &public_key, - &device, - &user, - socket_addr, - &stats, - )?; - - // emit connection event - let context = GrpcRequestContext::new( - user.id, - user.username.clone(), - socket_addr.ip(), - device.id, - device.name.clone(), - location.clone(), - ); - self.emit_event(GrpcEvent::ClientConnected { - context, - location: location.clone(), - device: device.clone(), - })?; - } - } - } - - // disconnect inactive clients - client_map.disconnect_inactive_vpn_clients_for_location(&location)? - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - })?; - } - } - - // Save stats to db - let stats = match stats.save(&self.pool).await { - Ok(stats) => stats, - Err(err) => { - error!("Saving WireGuard peer stats to db failed: {err}"); - return Err(Status::new( - Code::Internal, - format!("Saving WireGuard peer stats to db failed: {err}"), - )); - } - }; - info!("Saved WireGuard peer stats to db."); - debug!("WireGuard peer stats: {stats:?}"); - } - - Ok(Response::new(())) - } - - async fn config( - &self, - request: Request, - ) -> Result, Status> { - debug!("Sending configuration to gateway client."); - let GatewayMetadata { - network_id, - hostname, - version, - .. - // info, - } = Self::extract_metadata(request.metadata())?; - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - - let mut conn = self.pool.acquire().await.map_err(|e| { - error!("Failed to acquire DB connection: {e}"); - Status::new( - Code::Internal, - "Failed to acquire DB connection".to_string(), - ) - })?; - - let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) - .await - .map_err(|e| { - error!("Network {network_id} not found"); - Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) - })? - .ok_or_else(|| { - Status::new( - Code::Internal, - format!("Network with id {network_id} not found"), - ) - })?; - - debug!("Sending configuration to gateway client, network {network}."); - - // store connected gateway in memory - { - let mut state = self.gateway_state.lock().unwrap(); - state.add_gateway( - network_id, - &network.name, - hostname, - request.into_inner().name, - self.mail_tx.clone(), - version, - ); - } - - network.connected_at = Some(Utc::now().naive_utc()); - if let Err(err) = network.save(&mut *conn).await { - error!("Failed to save updated network {network_id} in the database, status: {err}"); - } - - let peers = - get_location_allowed_peers(&network, &mut *conn) - .await - .map_err(|error| { - error!( - "Failed to fetch peers from the database for network {network_id}: {error}", - ); - Status::new( - Code::Internal, - format!( - "Failed to retrieve peers from the database for network: {network_id}" - ), - ) - })?; - let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn) - .await - .map_err(|err| { - error!("Failed to generate firewall config for network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to generate firewall config for network: {network_id}"), - ) - })?; - - info!("Configuration sent to gateway client, network {network}."); - - Ok(Response::new(gen_config( - &network, - peers, - maybe_firewall_config, - ))) - } - - async fn updates(&self, request: Request<()>) -> Result, Status> { - let GatewayMetadata { - network_id, - hostname, - .. - // info, - } = Self::extract_metadata(request.metadata())?; - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - - let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) - .await - .map_err(|_| { - error!("Failed to fetch network {network_id} from the database"); - Status::new( - Code::Internal, - format!("Failed to retrieve network {network_id} from the database"), - ) - })? - else { - return Err(Status::new( - Code::Internal, - format!("Network with id {network_id} not found"), - )); - }; - - info!("New client connected to updates stream: {hostname}, network {network}",); - - let (tx, rx) = mpsc::channel(4); - let events_rx = self.wireguard_tx.subscribe(); - let mut state = self.gateway_state.lock().unwrap(); - state - .connect_gateway(network_id, &hostname, &self.pool) - .map_err(|err| { - error!("Failed to connect gateway on network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to connect gateway on network {network_id}"), - ) - })?; - - // clone here before moving into a closure - let gateway_hostname = hostname.clone(); - let handle = tokio::spawn(async move { - let mut update_handler = - GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); - update_handler.run().await; - }); - - Ok(Response::new(GatewayUpdatesStream::new( - handle, - rx, - network_id, - hostname, - Arc::clone(&self.gateway_state), - self.pool.clone(), - ))) - } -} +// impl Stream for GatewayUpdatesStream { +// type Item = Result; + +// fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// Pin::new(&mut self.rx).poll_recv(cx) +// } +// } + +// impl Drop for GatewayUpdatesStream { +// fn drop(&mut self) { +// info!("Client disconnected"); +// // terminate update task +// self.task_handle.abort(); +// // update gateway state +// // TODO: possibly use a oneshot channel instead +// self.gateway_state +// .lock() +// .unwrap() +// .disconnect_gateway(self.network_id, self.gateway_hostname.clone(), &self.pool) +// .expect("Unable to disconnect gateway."); +// } +// } + +// #[tonic::async_trait] +// impl gateway_service_server::GatewayService for GatewayServer { +// type UpdatesStream = GatewayUpdatesStream; + +// /// Retrieve stats from gateway and save it to database +// async fn stats( +// &self, +// request: Request>, +// ) -> Result, Status> { +// let GatewayMetadata { +// network_id, +// hostname, +// .. +// } = Self::extract_metadata(request.metadata())?; +// let mut stream = request.into_inner(); +// let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); +// loop { +// // Wait for a message or update client map at least once a mninute, if no messages are +// // received. +// let stats_update = tokio::select! { +// message = stream.message() => { +// match message? { +// Some(update) => update, +// None => break, // Stream ended +// } +// } +// _ = disconnect_timer.tick() => { +// debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ +// Updating disconnected VPN clients"); +// // fetch location to get current peer disconnect threshold +// let location = self.fetch_location_from_db(network_id).await?; + +// // perform client state operations in a dedicated block to drop mutex guard +// let disconnected_clients = { +// // acquire lock on client state map +// let mut client_map = self.get_client_state_guard()?; + +// // disconnect inactive clients +// client_map.disconnect_inactive_vpn_clients_for_location(&location +// )? +// }; + +// // emit client disconnect events +// for (device, context) in disconnected_clients { +// self.emit_event(GrpcEvent::ClientDisconnected { +// context, +// location: location.clone(), +// device, +// })?; +// }; +// continue; +// } +// }; + +// debug!("Received stats message: {stats_update:?}"); +// let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { +// debug!("Received stats message is empty, skipping."); +// continue; +// }; +// let public_key = peer_stats.public_key.clone(); + +// // fetch device from DB +// // TODO: fetch only when device has changed and use client state otherwise +// let device = match self.fetch_device_from_db(&public_key).await? { +// Some(device) => device, +// None => { +// warn!( +// "Received stats update for a device which does not exist: {public_key}, skipping." +// ); +// continue; +// } +// }; + +// // copy device ID for easier reference later +// let device_id = device.id; + +// // fetch user and location from DB for activity log +// // TODO: cache usernames since they don't change +// let user = self.fetch_user_from_db(device.user_id, &public_key).await?; +// let location = self.fetch_location_from_db(network_id).await?; + +// // convert stats to DB storage format +// let stats = protos_into_internal_stats(peer_stats, network_id, device_id); + +// // only perform client state update if stats include an endpoint IP +// // otherwise a peer was added to the gateway interface +// // but has not connected yet +// if let Some(endpoint) = &stats.endpoint { +// // parse client endpoint IP +// let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { +// error!("Failed to parse VPN client endpoint: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to parse VPN client endpoint: {err}"), +// ) +// })?; + +// // perform client state operations in a dedicated block to drop mutex guard +// let disconnected_clients = { +// // acquire lock on client state map +// let mut client_map = self.get_client_state_guard()?; + +// // update connected clients map +// match client_map.get_vpn_client(network_id, &public_key) { +// Some(client_state) => { +// // update connected client state +// client_state.update_client_state( +// device, +// socket_addr, +// stats.latest_handshake, +// stats.upload, +// stats.download, +// ); +// } +// None => { +// // don't mark inactive peers as connected +// if (Utc::now().naive_utc() - stats.latest_handshake) +// < TimeDelta::seconds(location.peer_disconnect_threshold.into()) +// { +// // mark new VPN client as connected +// client_map.connect_vpn_client( +// network_id, +// &hostname, +// &public_key, +// &device, +// &user, +// socket_addr, +// &stats, +// )?; + +// // emit connection event +// let context = GrpcRequestContext::new( +// user.id, +// user.username.clone(), +// socket_addr.ip(), +// device.id, +// device.name.clone(), +// location.clone(), +// ); +// self.emit_event(GrpcEvent::ClientConnected { +// context, +// location: location.clone(), +// device: device.clone(), +// })?; +// } +// } +// } + +// // disconnect inactive clients +// client_map.disconnect_inactive_vpn_clients_for_location(&location)? +// }; + +// // emit client disconnect events +// for (device, context) in disconnected_clients { +// self.emit_event(GrpcEvent::ClientDisconnected { +// context, +// location: location.clone(), +// device, +// })?; +// } +// } + +// // Save stats to db +// let stats = match stats.save(&self.pool).await { +// Ok(stats) => stats, +// Err(err) => { +// error!("Saving WireGuard peer stats to db failed: {err}"); +// return Err(Status::new( +// Code::Internal, +// format!("Saving WireGuard peer stats to db failed: {err}"), +// )); +// } +// }; +// info!("Saved WireGuard peer stats to db."); +// debug!("WireGuard peer stats: {stats:?}"); +// } + +// Ok(Response::new(())) +// } + +// async fn config( +// &self, +// request: Request, +// ) -> Result, Status> { +// debug!("Sending configuration to gateway client."); +// let GatewayMetadata { +// network_id, +// hostname, +// version, +// .. +// // info, +// } = Self::extract_metadata(request.metadata())?; +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); + +// let mut conn = self.pool.acquire().await.map_err(|e| { +// error!("Failed to acquire DB connection: {e}"); +// Status::new( +// Code::Internal, +// "Failed to acquire DB connection".to_string(), +// ) +// })?; + +// let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) +// .await +// .map_err(|e| { +// error!("Network {network_id} not found"); +// Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) +// })? +// .ok_or_else(|| { +// Status::new( +// Code::Internal, +// format!("Network with id {network_id} not found"), +// ) +// })?; + +// debug!("Sending configuration to gateway client, network {network}."); + +// // store connected gateway in memory +// { +// let mut state = self.gateway_state.lock().unwrap(); +// state.add_gateway( +// network_id, +// &network.name, +// hostname, +// request.into_inner().name, +// self.mail_tx.clone(), +// version, +// ); +// } + +// network.connected_at = Some(Utc::now().naive_utc()); +// if let Err(err) = network.save(&mut *conn).await { +// error!("Failed to save updated network {network_id} in the database, status: {err}"); +// } + +// let peers = +// get_location_allowed_peers(&network, &mut *conn) +// .await +// .map_err(|error| { +// error!( +// "Failed to fetch peers from the database for network {network_id}: {error}", +// ); +// Status::new( +// Code::Internal, +// format!( +// "Failed to retrieve peers from the database for network: {network_id}" +// ), +// ) +// })?; +// let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn) +// .await +// .map_err(|err| { +// error!("Failed to generate firewall config for network {network_id}: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to generate firewall config for network: {network_id}"), +// ) +// })?; + +// info!("Configuration sent to gateway client, network {network}."); + +// Ok(Response::new(gen_config( +// &network, +// peers, +// maybe_firewall_config, +// ))) +// } + +// async fn updates(&self, request: Request<()>) -> Result, Status> { +// let GatewayMetadata { +// network_id, +// hostname, +// .. +// // info, +// } = Self::extract_metadata(request.metadata())?; +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); + +// let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) +// .await +// .map_err(|_| { +// error!("Failed to fetch network {network_id} from the database"); +// Status::new( +// Code::Internal, +// format!("Failed to retrieve network {network_id} from the database"), +// ) +// })? +// else { +// return Err(Status::new( +// Code::Internal, +// format!("Network with id {network_id} not found"), +// )); +// }; + +// info!("New client connected to updates stream: {hostname}, network {network}",); + +// let (tx, rx) = mpsc::channel(4); +// let events_rx = self.wireguard_tx.subscribe(); +// let mut state = self.gateway_state.lock().unwrap(); +// state +// .connect_gateway(network_id, &hostname, &self.pool) +// .map_err(|err| { +// error!("Failed to connect gateway on network {network_id}: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to connect gateway on network {network_id}"), +// ) +// })?; + +// // clone here before moving into a closure +// let gateway_hostname = hostname.clone(); +// let handle = tokio::spawn(async move { +// let mut update_handler = +// GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); +// update_handler.run().await; +// }); + +// Ok(Response::new(GatewayUpdatesStream::new( +// handle, +// rx, +// network_id, +// hostname, +// Arc::clone(&self.gateway_state), +// self.pool.clone(), +// ))) +// } +// } diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs index b00b1f004f..c3c209c9a0 100644 --- a/crates/defguard_core/src/grpc/gateway/tests.rs +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -4,9 +4,6 @@ use std::{ sync::{Arc, Mutex}, }; -use defguard_common::db::setup_pool; -use defguard_mail::Mail; -use defguard_proto::gateway::{CoreRequest, CoreResponse, gateway_server}; use ipnetwork::IpNetwork; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::{ @@ -16,13 +13,17 @@ use tokio::{ use tokio_stream::wrappers::{UnboundedReceiverStream, UnixListenerStream}; use tonic::{Request, Response, Status, Streaming, transport::Server}; +use defguard_common::db::{ + models::wireguard::{LocationMfaMode, ServiceLocationMode, WireguardNetwork}, + setup_pool, +}; +use defguard_mail::Mail; +use defguard_proto::gateway::{CoreRequest, CoreResponse, gateway_server}; + use super::{TONIC_SOCKET, handler::GatewayHandler}; use crate::{ - db::models::{ - gateway::Gateway, - wireguard::{GatewayEvent, LocationMfaMode, ServiceLocationMode, WireguardNetwork}, - }, - grpc::{ClientMap, GrpcEvent}, + db::models::gateway::Gateway, + grpc::{ClientMap, GrpcEvent, gateway::events::GatewayEvent}, }; // TODO: move to "gateway" repo. diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 946e39d835..3a6ce742e8 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -8,33 +8,15 @@ use std::{ use reqwest::Url; use serde::Serialize; use sqlx::PgPool; -use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; +use tokio::sync::mpsc::UnboundedSender; use tonic::transport::{Identity, Server, ServerTlsConfig, server::Router}; -use tower::ServiceBuilder; use defguard_common::{ - VERSION, auth::claims::ClaimsType, db::{Id, models::Settings}, }; -use defguard_mail::Mail; -use defguard_version::{ - ComponentInfo, DefguardComponent, Version, client::ClientVersionInterceptor, - get_tracing_variables, -}; -use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tonic::{ - Code, Streaming, - transport::{ - Certificate, ClientTlsConfig, Endpoint, - }, -}; -use self::{ - auth::AuthServer, - interceptor::JwtInterceptor, worker::WorkerServer, -}; +use self::{auth::AuthServer, interceptor::JwtInterceptor, worker::WorkerServer}; use crate::{ auth::failed_login::FailedLoginMap, db::AppEvent, @@ -45,13 +27,10 @@ use crate::{ }, is_business_license_active, }, - events::{BidiStreamEvent, GrpcEvent}, + events::GrpcEvent, grpc::gateway::client_state::ClientMap, server_config, - version::{ - IncompatibleComponents, IncompatibleProxyData, MIN_GATEWAY_VERSION, - is_proxy_version_supported, - }, + version::MIN_GATEWAY_VERSION, }; mod auth; @@ -72,11 +51,6 @@ pub mod proto { use defguard_proto::{ auth::auth_service_server::AuthServiceServer, - proxy::{ - AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, core_request, - core_response, proxy_client::ProxyClient, - }, - gateway::gateway_service_server::GatewayServiceServer, worker::worker_service_server::WorkerServiceServer, }; From e4f83acdfb628b2f441333fb165da24c37cac434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 19 Dec 2025 12:25:34 +0100 Subject: [PATCH 14/17] Removed commented-out section --- crates/defguard_core/src/grpc/mod.rs | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 3a6ce742e8..4072df66d0 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -118,33 +118,6 @@ pub async fn build_grpc_service_router( .tcp_keepalive(Some(TEN_SECS)) .add_service(health_service) .add_service(auth_service); - - // let router = { - // use crate::version::GatewayVersionInterceptor; - - // let gateway_service = GatewayServiceServer::new(GatewayServer::new( - // pool, - // gateway_state, - // client_state, - // wireguard_tx, - // mail_tx, - // grpc_event_tx, - // )); - - // let own_version = Version::parse(VERSION)?; - // router.add_service( - // ServiceBuilder::new() - // .layer(tonic::service::InterceptorLayer::new(JwtInterceptor::new( - // ClaimsType::Gateway, - // ))) - // .layer(tonic::service::InterceptorLayer::new( - // GatewayVersionInterceptor::new(MIN_GATEWAY_VERSION, incompatible_components), - // )) - // .layer(DefguardVersionLayer::new(own_version)) - // .service(gateway_service), - // ) - // }; - let router = router.add_service(worker_service); Ok(router) From 1c3ef813130385360716ed205f081219b8643962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 19 Dec 2025 12:35:08 +0100 Subject: [PATCH 15/17] cargo fmt --- crates/defguard_core/src/handlers/wireguard.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 73d34c306e..d5dfa01a5c 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -13,7 +13,9 @@ use defguard_common::{ Device, DeviceConfig, DeviceNetworkInfo, DeviceType, WireguardNetwork, device::{AddDevice, DeviceInfo, ModifyDevice, WireguardNetworkDevice}, wireguard::{ - DateTimeAggregation, LocationMfaMode, MappedDevice, ServiceLocationMode, WireguardDeviceStatsRow, WireguardNetworkInfo, WireguardNetworkStats, WireguardUserStatsRow, networks_stats + DateTimeAggregation, LocationMfaMode, MappedDevice, ServiceLocationMode, + WireguardDeviceStatsRow, WireguardNetworkInfo, WireguardNetworkStats, + WireguardUserStatsRow, networks_stats, }, }, }, From 3a35e2d6bd1ea5b6b48b94b138071d9d68766fe1 Mon Sep 17 00:00:00 2001 From: Jacek Chmielewski Date: Fri, 19 Dec 2025 12:36:00 +0100 Subject: [PATCH 16/17] don't quit when no proxy url is provided --- crates/defguard_proxy_manager/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index f950214de9..7e4f34e9c1 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -168,6 +168,7 @@ impl ProxyOrchestrator { pub async fn run(self, url: &Option) -> Result<(), ProxyError> { // TODO retrieve proxies from db let Some(url) = url else { + tokio::time::sleep(Duration::MAX).await; return Ok(()); }; let proxies = vec![Proxy::new( From acd6a4378426d37e9b39b549d59e89300e431620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 19 Dec 2025 12:44:48 +0100 Subject: [PATCH 17/17] .sqlx --- ...06ff43f8ff063a6772586b79f1c524a848554.json | 52 +++++++++++++++++++ ...719811a7c80e0954d6024f00b5ef883aeab3d.json | 50 ++++++++++++++++++ ...21950b9cb440f20904802a56cdebf2339d481.json | 15 ++++++ ...4948da760fa885ececf72a624f859f73fb38c.json | 14 +++++ ...6e24f24239df0ba778666ca1751c9436aec8e.json | 26 ++++++++++ ...611ebf0d827fcbba365130308998a1b691c50.json | 35 +++++++++++++ ...29532ee95c54bc1343706bc61d9e4bceed4d4.json | 16 ++++++ ...2db17b3b9754bb28b66bf001f0a3aa2655839.json | 15 ++++++ ...f997807b66e0b532da747b146513c34e15c5c.json | 52 +++++++++++++++++++ ...cf3da8e12f7179da8cb42fa0446b224f9dec4.json | 19 +++++++ 10 files changed, 294 insertions(+) create mode 100644 .sqlx/query-0648b5c0e4d4a4cd922ccaef80e06ff43f8ff063a6772586b79f1c524a848554.json create mode 100644 .sqlx/query-0cd356bb88839bcc76d1fe3ed26719811a7c80e0954d6024f00b5ef883aeab3d.json create mode 100644 .sqlx/query-1a4f2c99da8a0db6208abdfaf7e21950b9cb440f20904802a56cdebf2339d481.json create mode 100644 .sqlx/query-5c0fd685bdb1c165d74a62f64824948da760fa885ececf72a624f859f73fb38c.json create mode 100644 .sqlx/query-5d00c03eccbe17efc023bd0c6006e24f24239df0ba778666ca1751c9436aec8e.json create mode 100644 .sqlx/query-6750be49c1eb6546a217512ea61611ebf0d827fcbba365130308998a1b691c50.json create mode 100644 .sqlx/query-b3cae8e4b2468e9b0ae88ef135729532ee95c54bc1343706bc61d9e4bceed4d4.json create mode 100644 .sqlx/query-c0b17ec0369d1286c82515cef2e2db17b3b9754bb28b66bf001f0a3aa2655839.json create mode 100644 .sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json create mode 100644 .sqlx/query-dbcf7be91c6a4c92e865e35d565cf3da8e12f7179da8cb42fa0446b224f9dec4.json diff --git a/.sqlx/query-0648b5c0e4d4a4cd922ccaef80e06ff43f8ff063a6772586b79f1c524a848554.json b/.sqlx/query-0648b5c0e4d4a4cd922ccaef80e06ff43f8ff063a6772586b79f1c524a848554.json new file mode 100644 index 0000000000..8afd59ce08 --- /dev/null +++ b/.sqlx/query-0648b5c0e4d4a4cd922ccaef80e06ff43f8ff063a6772586b79f1c524a848554.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, \"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\" FROM \"gateway\" WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "network_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "url", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "hostname", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "connected_at", + "type_info": "Timestamp" + }, + { + "ordinal": 5, + "name": "disconnected_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + true, + true, + true + ] + }, + "hash": "0648b5c0e4d4a4cd922ccaef80e06ff43f8ff063a6772586b79f1c524a848554" +} diff --git a/.sqlx/query-0cd356bb88839bcc76d1fe3ed26719811a7c80e0954d6024f00b5ef883aeab3d.json b/.sqlx/query-0cd356bb88839bcc76d1fe3ed26719811a7c80e0954d6024f00b5ef883aeab3d.json new file mode 100644 index 0000000000..9476d09d17 --- /dev/null +++ b/.sqlx/query-0cd356bb88839bcc76d1fe3ed26719811a7c80e0954d6024f00b5ef883aeab3d.json @@ -0,0 +1,50 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, \"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\" FROM \"gateway\"", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "network_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "url", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "hostname", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "connected_at", + "type_info": "Timestamp" + }, + { + "ordinal": 5, + "name": "disconnected_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + true, + true, + true + ] + }, + "hash": "0cd356bb88839bcc76d1fe3ed26719811a7c80e0954d6024f00b5ef883aeab3d" +} diff --git a/.sqlx/query-1a4f2c99da8a0db6208abdfaf7e21950b9cb440f20904802a56cdebf2339d481.json b/.sqlx/query-1a4f2c99da8a0db6208abdfaf7e21950b9cb440f20904802a56cdebf2339d481.json new file mode 100644 index 0000000000..e02b54b479 --- /dev/null +++ b/.sqlx/query-1a4f2c99da8a0db6208abdfaf7e21950b9cb440f20904802a56cdebf2339d481.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE wireguard_network SET connected_at = $2 WHERE name = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Timestamp" + ] + }, + "nullable": [] + }, + "hash": "1a4f2c99da8a0db6208abdfaf7e21950b9cb440f20904802a56cdebf2339d481" +} diff --git a/.sqlx/query-5c0fd685bdb1c165d74a62f64824948da760fa885ececf72a624f859f73fb38c.json b/.sqlx/query-5c0fd685bdb1c165d74a62f64824948da760fa885ececf72a624f859f73fb38c.json new file mode 100644 index 0000000000..522e152b81 --- /dev/null +++ b/.sqlx/query-5c0fd685bdb1c165d74a62f64824948da760fa885ececf72a624f859f73fb38c.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM \"gateway\" WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "5c0fd685bdb1c165d74a62f64824948da760fa885ececf72a624f859f73fb38c" +} diff --git a/.sqlx/query-5d00c03eccbe17efc023bd0c6006e24f24239df0ba778666ca1751c9436aec8e.json b/.sqlx/query-5d00c03eccbe17efc023bd0c6006e24f24239df0ba778666ca1751c9436aec8e.json new file mode 100644 index 0000000000..58a9bd507f --- /dev/null +++ b/.sqlx/query-5d00c03eccbe17efc023bd0c6006e24f24239df0ba778666ca1751c9436aec8e.json @@ -0,0 +1,26 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO \"gateway\" (\"network_id\",\"url\",\"hostname\",\"connected_at\",\"disconnected_at\") VALUES ($1,$2,$3,$4,$5) RETURNING id", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8", + "Text", + "Text", + "Timestamp", + "Timestamp" + ] + }, + "nullable": [ + false + ] + }, + "hash": "5d00c03eccbe17efc023bd0c6006e24f24239df0ba778666ca1751c9436aec8e" +} diff --git a/.sqlx/query-6750be49c1eb6546a217512ea61611ebf0d827fcbba365130308998a1b691c50.json b/.sqlx/query-6750be49c1eb6546a217512ea61611ebf0d827fcbba365130308998a1b691c50.json new file mode 100644 index 0000000000..6c39c928d6 --- /dev/null +++ b/.sqlx/query-6750be49c1eb6546a217512ea61611ebf0d827fcbba365130308998a1b691c50.json @@ -0,0 +1,35 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT d.wireguard_pubkey pubkey, preshared_key, -- TODO possible to not use ARRAY-unnest here?\n ARRAY(\n SELECT host(ip)\n FROM unnest(wnd.wireguard_ips) AS ip\n ) \"allowed_ips!: Vec\" FROM wireguard_network_device wnd JOIN device d ON wnd.device_id = d.id JOIN \"user\" u ON d.user_id = u.id WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) AND d.configured = true AND u.is_active = true ORDER BY d.id ASC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "pubkey", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "preshared_key", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "allowed_ips!: Vec", + "type_info": "TextArray" + } + ], + "parameters": { + "Left": [ + "Int8", + "Bool" + ] + }, + "nullable": [ + false, + true, + null + ] + }, + "hash": "6750be49c1eb6546a217512ea61611ebf0d827fcbba365130308998a1b691c50" +} diff --git a/.sqlx/query-b3cae8e4b2468e9b0ae88ef135729532ee95c54bc1343706bc61d9e4bceed4d4.json b/.sqlx/query-b3cae8e4b2468e9b0ae88ef135729532ee95c54bc1343706bc61d9e4bceed4d4.json new file mode 100644 index 0000000000..85b60b0fb8 --- /dev/null +++ b/.sqlx/query-b3cae8e4b2468e9b0ae88ef135729532ee95c54bc1343706bc61d9e4bceed4d4.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE gateway SET hostname = $2, connected_at = $3 WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Text", + "Timestamp" + ] + }, + "nullable": [] + }, + "hash": "b3cae8e4b2468e9b0ae88ef135729532ee95c54bc1343706bc61d9e4bceed4d4" +} diff --git a/.sqlx/query-c0b17ec0369d1286c82515cef2e2db17b3b9754bb28b66bf001f0a3aa2655839.json b/.sqlx/query-c0b17ec0369d1286c82515cef2e2db17b3b9754bb28b66bf001f0a3aa2655839.json new file mode 100644 index 0000000000..1c9ce00049 --- /dev/null +++ b/.sqlx/query-c0b17ec0369d1286c82515cef2e2db17b3b9754bb28b66bf001f0a3aa2655839.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE gateway SET disconnected_at = $2 WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Timestamp" + ] + }, + "nullable": [] + }, + "hash": "c0b17ec0369d1286c82515cef2e2db17b3b9754bb28b66bf001f0a3aa2655839" +} diff --git a/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json b/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json new file mode 100644 index 0000000000..6fa0952987 --- /dev/null +++ b/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT * FROM gateway WHERE network_id = $1 ORDER BY id", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "network_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "url", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "hostname", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "connected_at", + "type_info": "Timestamp" + }, + { + "ordinal": 5, + "name": "disconnected_at", + "type_info": "Timestamp" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + true, + true, + true + ] + }, + "hash": "d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c" +} diff --git a/.sqlx/query-dbcf7be91c6a4c92e865e35d565cf3da8e12f7179da8cb42fa0446b224f9dec4.json b/.sqlx/query-dbcf7be91c6a4c92e865e35d565cf3da8e12f7179da8cb42fa0446b224f9dec4.json new file mode 100644 index 0000000000..f5f9307d7d --- /dev/null +++ b/.sqlx/query-dbcf7be91c6a4c92e865e35d565cf3da8e12f7179da8cb42fa0446b224f9dec4.json @@ -0,0 +1,19 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE \"gateway\" SET \"network_id\" = $2,\"url\" = $3,\"hostname\" = $4,\"connected_at\" = $5,\"disconnected_at\" = $6 WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Text", + "Text", + "Timestamp", + "Timestamp" + ] + }, + "nullable": [] + }, + "hash": "dbcf7be91c6a4c92e865e35d565cf3da8e12f7179da8cb42fa0446b224f9dec4" +}