diff --git a/src/enterprise/firewall/packetfilter/calls.rs b/src/enterprise/firewall/packetfilter/calls.rs index b664950f..b69c9b27 100644 --- a/src/enterprise/firewall/packetfilter/calls.rs +++ b/src/enterprise/firewall/packetfilter/calls.rs @@ -329,13 +329,13 @@ impl Pool { } /// Insert `PoolAddr` at the end of the list. Take ownership of the given `PoolAddr`. - pub(super) fn insert_pool_addr(&mut self, mut pool_addr: PoolAddr) { + pub(super) fn insert_pool_addr(&mut self, pool_addr: &mut PoolAddr) { // TODO: Traverse tail queue; for now assume empty tail queue. assert!( self.list.tqh_first.is_null(), "Expected one entry in PoolAddr TailQueue." ); - self.list.tqh_first = &raw mut pool_addr; + self.list.tqh_first = &raw mut *pool_addr; self.list.tqh_last = &raw mut pool_addr.entries.tqe_next; pool_addr.entries.tqe_next = ptr::null_mut(); pool_addr.entries.tqe_prev = &raw mut self.list.tqh_first; diff --git a/src/gateway.rs b/src/gateway.rs index 1f338e53..eee594f1 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -65,7 +65,6 @@ impl From for InterfaceConfiguration { } } -type ClientMap = HashMap>>; type PubKey = String; pub struct Gateway { @@ -76,8 +75,8 @@ pub struct Gateway { firewall_api: FirewallApi, firewall_config: Option, pub connected: Arc, - // TODO: allow only one client. - pub(super) clients: ClientMap, + // Transmission channel. Important: allows only one connected client. + client_tx: Option>>, } impl Gateway { @@ -94,7 +93,7 @@ impl Gateway { firewall_api, firewall_config: None, connected: Arc::new(AtomicBool::new(false)), - clients: ClientMap::new(), + client_tx: None, }) } @@ -337,12 +336,12 @@ impl Gateway { Ok(()) } - /// Send message to all connected clients. - fn broadcast_to_clients(&self, message: &CoreRequest) { - for (addr, tx) in &self.clients { - if tx.send(Ok(message.clone())).is_err() { - debug!("Failed to send message to {addr}"); - } + /// Send message to the connected client. + fn send_to_client(&self, message: &CoreRequest) { + if let Some(tx) = &self.client_tx + && tx.send(Ok(message.clone())).is_err() + { + debug!("Failed to send message to Core."); } } @@ -568,6 +567,18 @@ impl gateway_server::Gateway for GatewayServer { return Err(Status::internal("Unsupported Defguard Core version")); } + // Drop new connections if another Core has already been connected. + if self + .gateway + .lock() + .expect("Gateway lock poison") + .client_tx + .is_some() + { + error!("Only one client connection is allowed."); + return Err(Status::internal("Client already connected")); + } + let (tx, rx) = mpsc::unbounded_channel(); let Ok(hostname) = gethostname().into_string() else { error!("Unable to get hostname"); @@ -594,7 +605,7 @@ impl gateway_server::Gateway for GatewayServer { } } - self.gateway.lock().unwrap().clients.insert(address, tx); + self.gateway.lock().expect("Gateway lock poison").client_tx = Some(tx); let gateway = Arc::clone(&self.gateway); let mut stream = request.into_inner(); @@ -636,12 +647,10 @@ impl gateway_server::Gateway for GatewayServer { } } info!("Defguard Core gRPC stream has been disconnected: {address}"); - gateway - .lock() - .unwrap() - .connected - .store(false, Ordering::Relaxed); - gateway.lock().unwrap().clients.remove(&address); + if let Ok(mut gateway) = gateway.lock() { + gateway.connected.store(false, Ordering::Relaxed); + gateway.client_tx = None; + } }); Ok(Response::new(UnboundedReceiverStream::new(rx))) @@ -693,7 +702,7 @@ pub async fn run_stats(gateway: Arc>, period: Duration) -> Result gateway .lock() .expect("gateway mutex poison") - .broadcast_to_clients(&message); + .send_to_client(&message); debug!("Sent statistics for peer {}", peer.public_key); } else { debug!( @@ -767,7 +776,7 @@ mod tests { firewall_api, firewall_config: None, connected: Arc::new(AtomicBool::new(false)), - clients: ClientMap::new(), + client_tx: None, }; // new config is the same @@ -952,7 +961,7 @@ mod tests { firewall_api: FirewallApi::new("test_interface").unwrap(), firewall_config: None, connected: Arc::new(AtomicBool::new(false)), - clients: ClientMap::new(), + client_tx: None, }; // Gateway has no firewall config, new rules are empty @@ -1018,7 +1027,7 @@ mod tests { firewall_api: FirewallApi::new("test_interface").unwrap(), firewall_config: None, connected: Arc::new(AtomicBool::new(false)), - clients: ClientMap::new(), + client_tx: None, }; // Gateway has no config gateway.firewall_config = None;