Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/enterprise/firewall/packetfilter/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,13 @@ impl Pool {
}

/// Insert `PoolAddr` at the end of the list. Take ownership of the given `PoolAddr`.
pub(super) fn insert_pool_addr(&mut self, mut pool_addr: PoolAddr) {
pub(super) fn insert_pool_addr(&mut self, pool_addr: &mut PoolAddr) {
// TODO: Traverse tail queue; for now assume empty tail queue.
assert!(
self.list.tqh_first.is_null(),
"Expected one entry in PoolAddr TailQueue."
);
self.list.tqh_first = &raw mut pool_addr;
self.list.tqh_first = &raw mut *pool_addr;
self.list.tqh_last = &raw mut pool_addr.entries.tqe_next;
pool_addr.entries.tqe_next = ptr::null_mut();
pool_addr.entries.tqe_prev = &raw mut self.list.tqh_first;
Expand Down
51 changes: 30 additions & 21 deletions src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ impl From<Configuration> for InterfaceConfiguration {
}
}

type ClientMap = HashMap<SocketAddr, mpsc::UnboundedSender<Result<CoreRequest, Status>>>;
type PubKey = String;

pub struct Gateway {
Expand All @@ -76,8 +75,8 @@ pub struct Gateway {
firewall_api: FirewallApi,
firewall_config: Option<FirewallConfig>,
pub connected: Arc<AtomicBool>,
// TODO: allow only one client.
pub(super) clients: ClientMap,
// Transmission channel. Important: allows only one connected client.
client_tx: Option<mpsc::UnboundedSender<Result<CoreRequest, Status>>>,
}

impl Gateway {
Expand All @@ -94,7 +93,7 @@ impl Gateway {
firewall_api,
firewall_config: None,
connected: Arc::new(AtomicBool::new(false)),
clients: ClientMap::new(),
client_tx: None,
})
}

Expand Down Expand Up @@ -337,12 +336,12 @@ impl Gateway {
Ok(())
}

/// Send message to all connected clients.
fn broadcast_to_clients(&self, message: &CoreRequest) {
for (addr, tx) in &self.clients {
if tx.send(Ok(message.clone())).is_err() {
debug!("Failed to send message to {addr}");
}
/// Send message to the connected client.
fn send_to_client(&self, message: &CoreRequest) {
if let Some(tx) = &self.client_tx
&& tx.send(Ok(message.clone())).is_err()
{
debug!("Failed to send message to Core.");
}
}

Expand Down Expand Up @@ -568,6 +567,18 @@ impl gateway_server::Gateway for GatewayServer {
return Err(Status::internal("Unsupported Defguard Core version"));
}

// Drop new connections if another Core has already been connected.
if self
.gateway
.lock()
.expect("Gateway lock poison")
.client_tx
.is_some()
{
error!("Only one client connection is allowed.");
return Err(Status::internal("Client already connected"));
}

let (tx, rx) = mpsc::unbounded_channel();
let Ok(hostname) = gethostname().into_string() else {
error!("Unable to get hostname");
Expand All @@ -594,7 +605,7 @@ impl gateway_server::Gateway for GatewayServer {
}
}

self.gateway.lock().unwrap().clients.insert(address, tx);
self.gateway.lock().expect("Gateway lock poison").client_tx = Some(tx);

let gateway = Arc::clone(&self.gateway);
let mut stream = request.into_inner();
Expand Down Expand Up @@ -636,12 +647,10 @@ impl gateway_server::Gateway for GatewayServer {
}
}
info!("Defguard Core gRPC stream has been disconnected: {address}");
gateway
.lock()
.unwrap()
.connected
.store(false, Ordering::Relaxed);
gateway.lock().unwrap().clients.remove(&address);
if let Ok(mut gateway) = gateway.lock() {
gateway.connected.store(false, Ordering::Relaxed);
gateway.client_tx = None;
}
});

Ok(Response::new(UnboundedReceiverStream::new(rx)))
Expand Down Expand Up @@ -693,7 +702,7 @@ pub async fn run_stats(gateway: Arc<Mutex<Gateway>>, period: Duration) -> Result
gateway
.lock()
.expect("gateway mutex poison")
.broadcast_to_clients(&message);
.send_to_client(&message);
debug!("Sent statistics for peer {}", peer.public_key);
} else {
debug!(
Expand Down Expand Up @@ -767,7 +776,7 @@ mod tests {
firewall_api,
firewall_config: None,
connected: Arc::new(AtomicBool::new(false)),
clients: ClientMap::new(),
client_tx: None,
};

// new config is the same
Expand Down Expand Up @@ -952,7 +961,7 @@ mod tests {
firewall_api: FirewallApi::new("test_interface").unwrap(),
firewall_config: None,
connected: Arc::new(AtomicBool::new(false)),
clients: ClientMap::new(),
client_tx: None,
};

// Gateway has no firewall config, new rules are empty
Expand Down Expand Up @@ -1018,7 +1027,7 @@ mod tests {
firewall_api: FirewallApi::new("test_interface").unwrap(),
firewall_config: None,
connected: Arc::new(AtomicBool::new(false)),
clients: ClientMap::new(),
client_tx: None,
};
// Gateway has no config
gateway.firewall_config = None;
Expand Down