diff --git a/crates/defguard_core/src/db/models/user.rs b/crates/defguard_core/src/db/models/user.rs index 9c4061d0d4..d087712971 100644 --- a/crates/defguard_core/src/db/models/user.rs +++ b/crates/defguard_core/src/db/models/user.rs @@ -435,18 +435,19 @@ impl User { let gateway_events = network .sync_allowed_devices_for_user(&mut *conn, self, None) .await?; + // check if any peers were updated if !gateway_events.is_empty() { // send peer update events send_multiple_wireguard_events(gateway_events, wg_tx); + } - // send firewall config update if ACLs & enterprise features are enabled - if let Some(firewall_config) = network.try_get_firewall_config(&mut *conn).await? { - send_wireguard_event( - GatewayEvent::FirewallConfigChanged(network.id, firewall_config), - wg_tx, - ); - } + // send firewall config update if ACLs & enterprise features are enabled + if let Some(firewall_config) = network.try_get_firewall_config(&mut *conn).await? { + send_wireguard_event( + GatewayEvent::FirewallConfigChanged(network.id, firewall_config), + wg_tx, + ); } } info!("Allowed devices of user {} synced", self.username); diff --git a/crates/defguard_core/src/grpc/enrollment.rs b/crates/defguard_core/src/grpc/enrollment.rs index 6e8a53ed38..9cb362bb06 100644 --- a/crates/defguard_core/src/grpc/enrollment.rs +++ b/crates/defguard_core/src/grpc/enrollment.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use sqlx::{PgPool, Transaction}; use tokio::sync::{ broadcast::Sender, @@ -20,7 +22,7 @@ use crate::{ enrollment::{Token, TokenError, ENROLLMENT_TOKEN_TYPE}, polling_token::PollingToken, }, - Device, GatewayEvent, Id, Settings, User, + Device, GatewayEvent, Id, Settings, User, WireguardNetwork, }, enterprise::{ db::models::enterprise_settings::EnterpriseSettings, ldap::utils::ldap_add_user, @@ -274,7 +276,7 @@ impl EnrollmentServer { ip_address = String::new(); device_info = None; } - debug!("IP address {}, device info {device_info:?}", ip_address); + debug!("IP address {ip_address}, device info {device_info:?}"); // check if password is strong enough debug!("Verifying password strength for user activation process."); @@ -588,6 +590,39 @@ impl EnrollmentServer { (device, network_info, configs) }; + // get all locations affected by device being added + let mut affected_location_ids = HashSet::new(); + for network_info_item in network_info.clone() { + affected_location_ids.insert(network_info_item.network_id); + } + + // send firewall config updates to affected locations + // if they have ACL enabled & enterprise features are active + for location_id in affected_location_ids { + if let Some(location) = WireguardNetwork::find_by_id(&mut *transaction, location_id) + .await + .map_err(|err| { + error!("Failed to fetch WireguardNetwork with ID {location_id}: {err}",); + Status::internal("unexpected error") + })? + { + if let Some(firewall_config) = location + .try_get_firewall_config(&mut transaction) + .await + .map_err(|err| { + error!("Failed to get firewall config for location {location}: {err}",); + Status::internal("unexpected error") + })? + { + debug!("Sending firewall config update for location {location} affected by adding new device {}, user {}({})", device.wireguard_pubkey, user.username, user.id); + self.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + location_id, + firewall_config, + )); + } + } + } + debug!( "Sending DeviceCreated event to gateway for device {}, user {}({:?})", device.wireguard_pubkey, user.username, user.id,