Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion crates/defguard_core/src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl From<ServiceLocationMode> for ProtoServiceLocationMode {
}

/// Stores configuration required to setup a WireGuard network
#[derive(Clone, Debug, Deserialize, Eq, Hash, Model, PartialEq, Serialize, ToSchema)]
#[derive(Clone, Deserialize, Eq, Hash, Model, PartialEq, Serialize, ToSchema)]
#[table(wireguard_network)]
pub struct WireguardNetwork<I = NoId> {
pub id: I,
Expand Down Expand Up @@ -207,6 +207,29 @@ impl fmt::Display for WireguardNetwork<Id> {
}
}

impl fmt::Debug for WireguardNetwork<Id> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WireguardNetwork")
.field("id", &self.id)
.field("name", &self.name)
.field("address", &self.address)
.field("port", &self.port)
.field("pubkey", &self.pubkey)
.field("prvkey", &"***")
.field("endpoint", &self.endpoint)
.field("dns", &self.dns)
.field("allowed_ips", &self.allowed_ips)
.field("connected_at", &self.connected_at)
.field("acl_enabled", &self.acl_enabled)
.field("acl_default_allow", &self.acl_default_allow)
.field("keepalive_interval", &self.keepalive_interval)
.field("peer_disconnect_threshold", &self.peer_disconnect_threshold)
.field("location_mfa_mode", &self.location_mfa_mode)
.field("service_location_mode", &self.service_location_mode)
.finish()
}
}

#[cfg(test)]
impl Default for WireguardNetwork<Id> {
fn default() -> Self {
Expand Down
38 changes: 20 additions & 18 deletions crates/defguard_core/src/grpc/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,25 +239,18 @@ impl GatewayServer {
Ok(self.grpc_event_tx.send(event)?)
}

/// Helper method to fetch `Device` info from DB and return appropriate errors
async fn fetch_device_from_db(&self, public_key: &str) -> Result<Device<Id>, Status> {
let device = match Device::find_by_pubkey(&self.pool, public_key).await {
Ok(Some(device)) => device,
Ok(None) => {
error!("Device with public key {public_key} not found");
return Err(Status::new(
Code::Internal,
format!("Device with public key {public_key} not found"),
));
}
Err(err) => {
/// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors
async fn fetch_device_from_db(&self, public_key: &str) -> Result<Option<Device<Id>>, Status> {
let device = Device::find_by_pubkey(&self.pool, public_key)
.await
.map_err(|err| {
error!("Failed to retrieve device with public key {public_key}: {err}",);
return Err(Status::new(
Status::new(
Code::Internal,
format!("Failed to retrieve device with public key {public_key}: {err}",),
));
}
};
)
})?;

Ok(device)
}

Expand Down Expand Up @@ -825,8 +818,17 @@ impl gateway_service_server::GatewayService for GatewayServer {

// fetch device from DB
// TODO: fetch only when device has changed and use client state otherwise
let device = self.fetch_device_from_db(&public_key).await?;
// copy for easier reference later
let device = match self.fetch_device_from_db(&public_key).await? {
Some(device) => device,
None => {
warn!(
"Received stats update for a device which does not exist: {public_key}, skipping."
);
continue;
}
};

// copy device ID for easier reference later
let device_id = device.id;

// fetch user and location from DB for activity log
Expand Down
114 changes: 114 additions & 0 deletions crates/defguard_core/tests/integration/grpc/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,117 @@ async fn test_gateway_version_validation(_: PgPoolOptions, options: PgConnectOpt
let status = response.err().unwrap();
assert_eq!(status.code(), Code::FailedPrecondition);
}

// https://github.com/DefGuard/defguard/issues/1671
#[sqlx::test]
async fn test_device_pubkey_change(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;
let (mut test_server, mut gateway, test_location, test_user) =
setup_test_server(pool.clone()).await;

// initial client map is empty
{
let client_map = test_server.get_client_map();
assert!(client_map.is_empty())
}

// connect stats stream
let stats_tx = gateway.setup_stats_update_stream().await;
let mut update_id = 1;

// add user device
let device_pubkey = "wYOt6ImBaQ3BEMQ3Xf5P5fTnbqwOvjcqYkkSBt+1xOg=";
let mut test_device = Device::new(
"test device".into(),
device_pubkey.into(),
test_user.id,
DeviceType::User,
None,
true,
)
.save(&pool)
.await
.unwrap();

// send stats update for existing device
stats_tx
.send(StatsUpdate {
id: update_id,
payload: Some(Payload::PeerStats(PeerStats {
public_key: device_pubkey.into(),
endpoint: "1.2.3.4:1234".into(),
latest_handshake: Utc::now().timestamp() as u64,
..Default::default()
})),
})
.expect("failed to send stats update");

// wait for event to be emitted
sleep(Duration::from_millis(100)).await;
let grpc_event = test_server
.grpc_event_rx
.try_recv()
.expect("failed to receive gRPC event");
assert_matches!(
grpc_event,
GrpcEvent::ClientConnected {
context: _,
location,
device
} if ((location.id == test_location.id) & (device.id == test_device.id))
);

// change device pubkey
let new_device_pubkey = "TJG2T6rhndZtk06KnIIOlD6hhd7wpVkBss8sfyvMCAA=";
test_device.wireguard_pubkey = new_device_pubkey.to_owned();
test_device.save(&pool).await.unwrap();

// send stats update with old pubkey
update_id += 1;
stats_tx
.send(StatsUpdate {
id: update_id,
payload: Some(Payload::PeerStats(PeerStats {
public_key: device_pubkey.into(),
endpoint: "1.2.3.4:1234".into(),
latest_handshake: Utc::now().timestamp() as u64,
..Default::default()
})),
})
.expect("failed to send stats update");

// no event should be emitted
sleep(Duration::from_millis(100)).await;
assert_err_eq!(test_server.grpc_event_rx.try_recv(), TryRecvError::Empty);

// send stats update with new pubkey
update_id += 1;
stats_tx
.send(StatsUpdate {
id: update_id,
payload: Some(Payload::PeerStats(PeerStats {
public_key: new_device_pubkey.into(),
endpoint: "1.2.3.4:1234".into(),
latest_handshake: Utc::now().timestamp() as u64,
..Default::default()
})),
})
.expect("failed to send stats update");

// wait for event
// FIXME: ideally this should not be emitted; we'll fix it once we implement a more robust VPN session logic
sleep(Duration::from_millis(100)).await;
let grpc_event = test_server
.grpc_event_rx
.try_recv()
.expect("failed to receive gRPC event");

assert_matches!(
grpc_event,
GrpcEvent::ClientConnected {
context: _,
location,
device
} if ((location.id == test_location.id) & (device.id == test_device.id))
);
}