diff --git a/crates/defguard_common/src/db/models/authentication_key.rs b/crates/defguard_common/src/db/models/authentication_key.rs index a1bc189586..78b44b3c58 100644 --- a/crates/defguard_common/src/db/models/authentication_key.rs +++ b/crates/defguard_common/src/db/models/authentication_key.rs @@ -6,7 +6,7 @@ use sqlx::{Error as SqlxError, PgExecutor, Type, query_as}; use crate::db::{Id, NoId}; -#[derive(Clone, Debug, Deserialize, Serialize, Type)] +#[derive(Clone, Debug, Deserialize, Serialize, Type, PartialEq)] #[sqlx(type_name = "authentication_key_type", rename_all = "lowercase")] #[serde(rename_all = "lowercase")] pub enum AuthenticationKeyType { @@ -23,7 +23,7 @@ impl Display for AuthenticationKeyType { } } -#[derive(Clone, Debug, Deserialize, Model, Serialize)] +#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] #[table(authentication_key)] pub struct AuthenticationKey { pub id: I, diff --git a/crates/defguard_core/src/db/models/device.rs b/crates/defguard_core/src/db/models/device.rs index 00c895ff15..2c6433ca2e 100644 --- a/crates/defguard_core/src/db/models/device.rs +++ b/crates/defguard_core/src/db/models/device.rs @@ -3,7 +3,7 @@ use std::{fmt, net::IpAddr}; use base64::{Engine, prelude::BASE64_STANDARD}; #[cfg(test)] use chrono::NaiveDate; -use chrono::{NaiveDateTime, Utc}; +use chrono::{NaiveDateTime, Timelike, Utc}; use defguard_common::{ csv::AsCsv, db::{Id, NoId, models::ModelError}, @@ -532,12 +532,20 @@ impl Device { description: Option, configured: bool, ) -> Self { + // FIXME: this is a workaround for reducing timestamp precision. + // `chrono` has nanosecond precision by default, while Postgres only does microseconds. + // It avoids issues when comparing to objects fetched from DB. + let created = Utc::now().naive_utc(); + let created = created + .with_nanosecond((created.nanosecond() / 1_000) * 1_000) + .expect("failed to truncate timestamp precision"); + Self { id: NoId, name, wireguard_pubkey, user_id, - created: Utc::now().naive_utc(), + created, device_type, description, configured, diff --git a/crates/defguard_core/src/db/models/oauth2client.rs b/crates/defguard_core/src/db/models/oauth2client.rs index 5f695618b5..6e8fae8428 100644 --- a/crates/defguard_core/src/db/models/oauth2client.rs +++ b/crates/defguard_core/src/db/models/oauth2client.rs @@ -7,7 +7,7 @@ use super::NewOpenIDClient; use defguard_common::db::{Id, NoId}; use defguard_common::random::gen_alphanumeric; -#[derive(Clone, Debug, Deserialize, Model, Serialize)] +#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] pub struct OAuth2Client { pub id: I, pub client_id: String, // unique diff --git a/crates/defguard_core/src/db/models/webauthn.rs b/crates/defguard_core/src/db/models/webauthn.rs index 9b4407d67f..8cf61dba92 100644 --- a/crates/defguard_core/src/db/models/webauthn.rs +++ b/crates/defguard_core/src/db/models/webauthn.rs @@ -4,7 +4,7 @@ use webauthn_rs::prelude::Passkey; use defguard_common::db::{Id, NoId, models::ModelError}; -#[derive(Model, Clone, Debug)] +#[derive(Model, Clone, Debug, PartialEq)] pub struct WebAuthn { pub id: I, pub user_id: Id, diff --git a/crates/defguard_core/src/db/models/webhook.rs b/crates/defguard_core/src/db/models/webhook.rs index f45d9831d7..be3e2d26a4 100644 --- a/crates/defguard_core/src/db/models/webhook.rs +++ b/crates/defguard_core/src/db/models/webhook.rs @@ -47,7 +47,7 @@ impl AppEvent { } } -#[derive(Clone, Debug, Deserialize, FromRow, Model, Serialize)] +#[derive(Clone, Debug, Deserialize, FromRow, Model, Serialize, PartialEq)] pub struct WebHook { pub id: I, pub url: String, diff --git a/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs b/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs index 2d472bec2c..d6adb11fc4 100644 --- a/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs +++ b/crates/defguard_core/src/enterprise/db/models/activity_log_stream.rs @@ -9,7 +9,7 @@ use defguard_common::{ secret::SecretStringWrapper, }; -#[derive(Debug, Serialize, Deserialize, Type, EnumString, Display, Clone)] +#[derive(Debug, Serialize, Deserialize, Type, EnumString, Display, Clone, PartialEq)] #[sqlx(type_name = "text", rename_all = "snake_case")] #[serde(rename_all = "snake_case")] pub enum ActivityLogStreamType { @@ -19,7 +19,7 @@ pub enum ActivityLogStreamType { LogstashHttp, } -#[derive(Clone, Debug, Serialize, Model, FromRow)] +#[derive(Clone, Debug, Serialize, Model, FromRow, PartialEq)] #[table(activity_log_stream)] pub struct ActivityLogStream { pub id: I, diff --git a/crates/defguard_core/src/enterprise/db/models/api_tokens.rs b/crates/defguard_core/src/enterprise/db/models/api_tokens.rs index 9e01c7248c..5b44881fc8 100644 --- a/crates/defguard_core/src/enterprise/db/models/api_tokens.rs +++ b/crates/defguard_core/src/enterprise/db/models/api_tokens.rs @@ -4,7 +4,7 @@ use sqlx::{Error as SqlxError, PgExecutor, query_as}; use defguard_common::db::{Id, NoId}; -#[derive(Clone, Debug, Deserialize, Model, Serialize)] +#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] #[table(api_token)] pub struct ApiToken { pub id: I, diff --git a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs index 1d6a563268..1ffdde6830 100644 --- a/crates/defguard_core/src/enterprise/db/models/openid_provider.rs +++ b/crates/defguard_core/src/enterprise/db/models/openid_provider.rs @@ -87,7 +87,7 @@ impl From for DirectorySyncTarget { } } -#[derive(Clone, Debug, Deserialize, Model, Serialize)] +#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)] pub struct OpenIdProvider { pub id: I, pub name: String, diff --git a/crates/defguard_core/src/enterprise/db/models/snat.rs b/crates/defguard_core/src/enterprise/db/models/snat.rs index 68b7b71c56..e8ed98df8e 100644 --- a/crates/defguard_core/src/enterprise/db/models/snat.rs +++ b/crates/defguard_core/src/enterprise/db/models/snat.rs @@ -8,7 +8,7 @@ use utoipa::ToSchema; use crate::enterprise::snat::error::UserSnatBindingError; use defguard_common::db::{Id, NoId}; -#[derive(Clone, Debug, Deserialize, Model, Serialize, ToSchema)] +#[derive(Clone, Debug, Deserialize, Model, Serialize, ToSchema, PartialEq)] #[table(user_snat_binding)] pub struct UserSnatBinding { pub id: I, diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index 68bc260a72..eb763a6c84 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -23,7 +23,7 @@ use defguard_proto::proxy::MfaMethod; /// Mainly meant to be stored in the activity log. /// By design this is a duplicate of a similar struct in the `event_logger` module. /// This is done in order to avoid circular imports once we split the project into multiple crates. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ApiRequestContext { pub timestamp: NaiveDateTime, pub user_id: Id, @@ -83,7 +83,7 @@ impl GrpcRequestContext { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum ApiEventType { UserLogin, UserLoginFailed { diff --git a/crates/defguard_core/tests/integration/api/acl.rs b/crates/defguard_core/tests/integration/api/acl.rs index 237b573fdd..352c92b54b 100644 --- a/crates/defguard_core/tests/integration/api/acl.rs +++ b/crates/defguard_core/tests/integration/api/acl.rs @@ -137,8 +137,8 @@ fn edit_alias_data_into_api_response( async fn test_rule_crud(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _) = make_test_client(pool).await; + authenticate_admin(&mut client).await; let rule = make_rule(); @@ -189,8 +189,8 @@ async fn test_rule_crud(_: PgPoolOptions, options: PgConnectOptions) { async fn test_rule_enterprise(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _) = make_test_client(pool).await; + authenticate_admin(&mut client).await; exceed_enterprise_limits(&client).await; @@ -230,8 +230,8 @@ async fn test_rule_enterprise(_: PgPoolOptions, options: PgConnectOptions) { async fn test_alias_crud(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _) = make_test_client(pool).await; + authenticate_admin(&mut client).await; let alias = make_alias(); @@ -284,8 +284,8 @@ async fn test_alias_crud(_: PgPoolOptions, options: PgConnectOptions) { async fn test_alias_enterprise(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _) = make_test_client(pool).await; + authenticate_admin(&mut client).await; exceed_enterprise_limits(&client).await; @@ -325,8 +325,8 @@ async fn test_alias_enterprise(_: PgPoolOptions, options: PgConnectOptions) { async fn test_empty_strings(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _) = make_test_client(pool).await; + authenticate_admin(&mut client).await; // rule let mut rule = make_rule(); @@ -409,8 +409,8 @@ async fn test_related_objects(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; // create related objects // networks @@ -541,8 +541,8 @@ async fn test_related_objects(_: PgPoolOptions, options: PgConnectOptions) { async fn test_invalid_related_objects(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, state) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, state) = make_test_client(pool).await; + authenticate_admin(&mut client).await; let rule = make_rule(); let response = client.post("/api/v1/acl/rule").json(&rule).send().await; @@ -644,8 +644,8 @@ async fn test_invalid_related_objects(_: PgPoolOptions, options: PgConnectOption async fn test_invalid_data(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _) = make_test_client(pool).await; + authenticate_admin(&mut client).await; // invalid port let mut rule = make_rule(); @@ -677,8 +677,8 @@ async fn test_rule_create_modify_state(_: PgPoolOptions, options: PgConnectOptio let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let rule = make_rule(); @@ -732,8 +732,8 @@ async fn test_rule_delete_state_new(_: PgPoolOptions, options: PgConnectOptions) let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; // test NEW rule deletion let rule = make_rule(); @@ -751,8 +751,8 @@ async fn test_rule_delete_state_applied(_: PgPoolOptions, options: PgConnectOpti let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; // create a location WireguardNetwork::new( @@ -812,8 +812,8 @@ async fn test_rule_duplication(_: PgPoolOptions, options: PgConnectOptions) { // each modification / deletion of parent rule should remove the child and create a new one let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let rule = make_rule(); let response = client.post("/api/v1/acl/rule").json(&rule).send().await; @@ -842,8 +842,8 @@ async fn test_rule_application(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let rule = make_rule(); @@ -934,8 +934,8 @@ async fn test_multiple_rules_application(_: PgPoolOptions, options: PgConnectOpt let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let rule_1 = make_rule(); let rule_2 = make_rule(); @@ -972,8 +972,8 @@ async fn test_alias_create_modify_state(_: PgPoolOptions, options: PgConnectOpti let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let alias = make_alias(); @@ -1012,8 +1012,8 @@ async fn test_alias_delete(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; // create alias let alias = make_alias(); @@ -1078,8 +1078,8 @@ async fn test_alias_duplication(_: PgPoolOptions, options: PgConnectOptions) { // each modification of parent alias should remove the child and create a new one let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let alias = make_alias(); let response = client.post("/api/v1/acl/alias").json(&alias).send().await; @@ -1104,8 +1104,8 @@ async fn test_alias_application(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; // create new alias let alias = make_alias(); @@ -1165,8 +1165,8 @@ async fn test_multiple_aliases_application(_: PgPoolOptions, options: PgConnectO let pool = setup_pool(options).await; let config = init_config(None); - let client = make_client_v2(pool.clone(), config).await; - authenticate_admin(&client).await; + let mut client = make_client_v2(pool.clone(), config).await; + authenticate_admin(&mut client).await; let alias_1 = make_alias(); let alias_2 = make_alias(); diff --git a/crates/defguard_core/tests/integration/api/auth.rs b/crates/defguard_core/tests/integration/api/auth.rs index ad43507125..272eeadc49 100644 --- a/crates/defguard_core/tests/integration/api/auth.rs +++ b/crates/defguard_core/tests/integration/api/auth.rs @@ -6,6 +6,7 @@ use defguard_common::db::models::{MFAMethod, Settings, settings::update_current_ use defguard_core::{ auth::{TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD}, db::{MFAInfo, User, UserDetails}, + events::ApiEventType, handlers::{Auth, AuthCode, AuthResponse, AuthTotp}, }; use reqwest::{StatusCode, header::USER_AGENT}; @@ -59,6 +60,8 @@ async fn test_logout(_: PgPoolOptions, options: PgConnectOptions) { let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); + client.verify_api_events_with_user(&[(ApiEventType::UserLogin, 2, "hpotter")]); + // store auth cookie for later use let auth_cookie = response .cookies() @@ -74,6 +77,8 @@ async fn test_logout(_: PgPoolOptions, options: PgConnectOptions) { let response = client.get("/api/v1/me").send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + client.verify_api_events_with_user(&[(ApiEventType::UserLogout, 2, "hpotter")]); + // try reusing auth cookie client.set_cookie(&auth_cookie); let response = client.get("/api/v1/me").send().await; @@ -84,7 +89,7 @@ async fn test_logout(_: PgPoolOptions, options: PgConnectOptions) { async fn test_login_bruteforce(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let mut client = make_client(pool).await; let invalid_auth = Auth::new("hpotter", "invalid"); @@ -93,8 +98,12 @@ async fn test_login_bruteforce(_: PgPoolOptions, options: PgConnectOptions) { let response = client.post("/api/v1/auth").json(&invalid_auth).send().await; if i == 5 { assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS); + client.assert_event_queue_is_empty(); } else { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + client.verify_api_events(&[ApiEventType::UserLoginFailed { + message: "Authentication for hpotter failed: invalid password".into(), + }]); } } } diff --git a/crates/defguard_core/tests/integration/api/common/client.rs b/crates/defguard_core/tests/integration/api/common/client.rs index 9292b31a3e..2203c0c899 100644 --- a/crates/defguard_core/tests/integration/api/common/client.rs +++ b/crates/defguard_core/tests/integration/api/common/client.rs @@ -2,22 +2,29 @@ use std::{net::SocketAddr, sync::Arc}; use axum::{Router, serve}; use bytes::Bytes; -use defguard_core::events::ApiEvent; +use defguard_common::db::Id; +use defguard_core::{ + events::{ApiEvent, ApiEventType}, + handlers::Auth, +}; use reqwest::{ Body, Client, StatusCode, Url, cookie::{Cookie, Jar}, header::{HeaderMap, HeaderName, HeaderValue, USER_AGENT}, redirect::Policy, }; -use tokio::{net::TcpListener, sync::mpsc::UnboundedReceiver, task::JoinHandle}; +use tokio::{ + net::TcpListener, + sync::mpsc::{UnboundedReceiver, error::TryRecvError}, + task::JoinHandle, +}; pub struct TestClient { client: Client, jar: Arc, port: u16, - // Has to live during whole test - #[allow(dead_code)] api_event_rx: UnboundedReceiver, + // Has to live during whole test api_task_handle: JoinHandle<()>, } @@ -65,6 +72,15 @@ impl TestClient { .add_cookie_str(&format!("{}={}", cookie.name(), cookie.value()), &url); } + // Helper to perform API login + pub async fn login_user(&mut self, username: &str, password: &str) { + let auth = Auth::new(username, password); + let response = self.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + + self.verify_api_events(&[ApiEventType::UserLogin]); + } + /// returns the base URL (http://ip:port) for this TestClient /// /// this is useful when trying to check if Location headers in responses @@ -122,6 +138,101 @@ impl TestClient { builder: self.client.delete(full_url), } } + + /// Assert that expected API events have been emitted + /// + /// `expected_events` should include all events that are currently in the queue + /// for the assertions to pass. + /// If there are too many or not enough events in the queue this should panic. + pub fn verify_api_events(&mut self, expected_events: &[ApiEventType]) { + // take all the events from the queue + let events = self.drain_all_events(); + + // verify number of events + assert_eq!( + events.len(), + expected_events.len(), + "Event number different than expected" + ); + + // compare events in order + for (index, (expected_event, (event, _user_id, _username))) in + expected_events.iter().zip(events.iter()).enumerate() + { + assert_eq!( + expected_event, event, + "Mismatch at index {index}: expected {expected_event:?}, got {event:?}", + ); + } + } + + /// A variant of `verify_api_events` which also compares user context + /// + /// Other parts of event context that would be hard and not that useful to test (timestamp, device) are omitted. + pub fn verify_api_events_with_user(&mut self, expected_events: &[(ApiEventType, Id, &str)]) { + // take all the events from the queue + let events = self.drain_all_events(); + + // verify number of events + assert_eq!( + events.len(), + expected_events.len(), + "Event number different than expected" + ); + + // compare events in order + for ( + index, + ((expected_event, expected_user_id, expected_username), (event, user_id, username)), + ) in expected_events.iter().zip(events.iter()).enumerate() + { + assert_eq!( + expected_event, event, + "Event type mismatch at index {index}: expected {expected_event:?}, got {event:?}", + ); + assert_eq!( + expected_user_id, user_id, + "User ID mismatch at index {index}: expected {expected_user_id:?}, got {user_id:?}", + ); + assert_eq!( + expected_username, username, + "Username mismatch at index {index}: expected {expected_username:?}, got {username:?}", + ); + } + } + + /// Receive all messages currently present in API event queue + /// + /// Can also be used to clear the queue. + pub fn drain_all_events(&mut self) -> Vec<(ApiEventType, Id, String)> { + let mut all_events = Vec::new(); + + loop { + match self.api_event_rx.try_recv() { + Ok(msg) => all_events.push((*msg.event, msg.context.user_id, msg.context.username)), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { + // No more messages available right now + break; + } + Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { + // Channel is closed + break; + } + } + } + all_events + } + + /// Assert there are no events queued + pub fn assert_event_queue_is_empty(&mut self) { + match self.api_event_rx.try_recv() { + Err(TryRecvError::Empty) => { + // Queue is empty, test passes + } + Ok(msg) => panic!("Expected empty queue, but got event: {msg:?}"), + Err(TryRecvError::Disconnected) => panic!("Channel is disconnected"), + } + } } impl Drop for TestClient { diff --git a/crates/defguard_core/tests/integration/api/common/mod.rs b/crates/defguard_core/tests/integration/api/common/mod.rs index 669ecc9649..233f51f056 100644 --- a/crates/defguard_core/tests/integration/api/common/mod.rs +++ b/crates/defguard_core/tests/integration/api/common/mod.rs @@ -14,7 +14,7 @@ use defguard_common::{ use defguard_core::{ auth::failed_login::FailedLoginMap, build_webapp, - db::{AppEvent, GatewayEvent, User, UserDetails}, + db::{AppEvent, Device, GatewayEvent, User, UserDetails, WireguardNetwork}, enterprise::license::{License, set_cached_license}, events::ApiEvent, grpc::{WorkerState, gateway::map::GatewayMap}, @@ -241,8 +241,27 @@ pub(crate) async fn make_client_with_db(pool: PgPool) -> (TestClient, PgPool) { (client, client_state.pool) } -pub(crate) async fn authenticate_admin(client: &TestClient) { - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); +pub(crate) async fn authenticate_admin(client: &mut TestClient) { + client.login_user("admin", "pass123").await; +} + +// Helper to fetch current user state from DB by username +pub(crate) async fn get_db_user(pool: &PgPool, username: &str) -> User { + User::find_by_username(pool, username) + .await + .unwrap() + .unwrap() +} + +// Helper to fetch current location state from DB by ID +pub(crate) async fn get_db_location(pool: &PgPool, location_id: Id) -> WireguardNetwork { + WireguardNetwork::find_by_id(pool, location_id) + .await + .unwrap() + .unwrap() +} + +// Helper to fetch current user device state from DB by device ID +pub(crate) async fn get_db_device(pool: &PgPool, device_id: Id) -> Device { + Device::find_by_id(pool, device_id).await.unwrap().unwrap() } diff --git a/crates/defguard_core/tests/integration/api/snat.rs b/crates/defguard_core/tests/integration/api/snat.rs index 01655d9b56..3e0ad6428a 100644 --- a/crates/defguard_core/tests/integration/api/snat.rs +++ b/crates/defguard_core/tests/integration/api/snat.rs @@ -20,10 +20,10 @@ use super::common::{ async fn test_snat_crud(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; + let (mut client, _) = make_test_client(pool).await; // admin login - authenticate_admin(&client).await; + authenticate_admin(&mut client).await; // create location let response = client @@ -110,10 +110,10 @@ async fn test_snat_crud(_: PgPoolOptions, options: PgConnectOptions) { async fn test_snat_enterprise_required(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _) = make_test_client(pool).await; + let (mut client, _) = make_test_client(pool).await; // admin login - authenticate_admin(&client).await; + authenticate_admin(&mut client).await; exceed_enterprise_limits(&client).await; diff --git a/crates/defguard_core/tests/integration/api/user.rs b/crates/defguard_core/tests/integration/api/user.rs index d7605a5348..936bcb69a5 100644 --- a/crates/defguard_core/tests/integration/api/user.rs +++ b/crates/defguard_core/tests/integration/api/user.rs @@ -4,12 +4,15 @@ use defguard_core::{ AddDevice, UserInfo, models::{NewOpenIDClient, oauth2client::OAuth2Client}, }, + events::ApiEventType, handlers::{AddUserData, Auth, PasswordChange, PasswordChangeSelf, Username}, }; use reqwest::{StatusCode, header::USER_AGENT}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio_stream::{self as stream, StreamExt}; +use crate::api::common::{get_db_device, get_db_location, get_db_user, make_client_with_db}; + use super::{ TEST_SERVER_URL, common::{fetch_user_details, make_client, make_network, make_test_client, setup_pool}, @@ -19,7 +22,7 @@ use super::{ async fn test_authenticate(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let mut client = make_client(pool).await; let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; @@ -32,35 +35,44 @@ async fn test_authenticate(_: PgPoolOptions, options: PgConnectOptions) { let auth = Auth::new("adumbledore", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + + // second user does not exist so we are unable to emit audit log event + client.verify_api_events_with_user(&[ + (ApiEventType::UserLogin, 2, "hpotter"), + ( + ApiEventType::UserLoginFailed { + message: "Authentication for hpotter failed: invalid password".into(), + }, + 2, + "hpotter", + ), + ]); } #[sqlx::test] async fn test_me(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let mut client = make_client(pool).await; - let auth = Auth::new("hpotter", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("hpotter", "pass123").await; let response = client.get("/api/v1/me").send().await; assert_eq!(response.status(), StatusCode::OK); let user_info: UserInfo = response.json().await; assert_eq!(user_info.first_name, "Harry"); assert_eq!(user_info.last_name, "Potter"); + + client.assert_event_queue_is_empty(); } #[sqlx::test] async fn test_change_self_password(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; - - let auth = Auth::new("hpotter", "pass123"); + let mut client = make_client(pool).await; - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("hpotter", "pass123").await; let bad_old = "notCurrentPassword123!$"; @@ -103,6 +115,7 @@ async fn test_change_self_password(_: PgPoolOptions, options: PgConnectOptions) assert_eq!(response.status(), StatusCode::OK); // old pass login + let auth = Auth::new("hpotter", "pass123"); let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); @@ -110,18 +123,27 @@ async fn test_change_self_password(_: PgPoolOptions, options: PgConnectOptions) let response = client.post("/api/v1/auth").json(&new_auth).send().await; assert_eq!(response.status(), StatusCode::OK); + + client.verify_api_events_with_user(&[ + (ApiEventType::PasswordChanged, 2, "hpotter"), + ( + ApiEventType::UserLoginFailed { + message: "Authentication for hpotter failed: invalid password".into(), + }, + 2, + "hpotter", + ), + (ApiEventType::UserLogin, 2, "hpotter"), + ]); } #[sqlx::test] async fn test_change_password(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; - - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; + let (mut client, pool) = make_client_with_db(pool).await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; let new_password = "newPassword43$!"; @@ -159,62 +181,69 @@ async fn test_change_password(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let test_user = get_db_user(&pool, "hpotter").await; + + client.verify_api_events_with_user(&[ + ( + ApiEventType::PasswordChangedByAdmin { user: test_user }, + 1, + "admin", + ), + (ApiEventType::UserLogin, 2, "hpotter"), + ]); } #[sqlx::test] async fn test_list_users(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let mut client = make_client(pool).await; let response = client.get("/api/v1/user").send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); // normal user cannot list users - let auth = Auth::new("hpotter", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("hpotter", "pass123").await; let response = client.get("/api/v1/user").send().await; assert_eq!(response.status(), StatusCode::FORBIDDEN); // admin can list users - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; let response = client.get("/api/v1/user").send().await; assert_eq!(response.status(), StatusCode::OK); + + client.assert_event_queue_is_empty(); } #[sqlx::test] async fn test_get_user(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let mut client = make_client(pool).await; let response = client.get("/api/v1/user/hpotter").send().await; assert_eq!(response.status(), StatusCode::UNAUTHORIZED); - let auth = Auth::new("hpotter", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("hpotter", "pass123").await; let user_info = fetch_user_details(&client, "hpotter").await; assert_eq!(user_info.user.first_name, "Harry"); assert_eq!(user_info.user.last_name, "Potter"); + + client.assert_event_queue_is_empty(); } #[sqlx::test] async fn test_username_available(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let mut client = make_client(pool).await; // standard user cannot check username availability - let auth = Auth::new("hpotter", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("hpotter", "pass123").await; let avail = Username { username: "hpotter".into(), @@ -227,9 +256,7 @@ async fn test_username_available(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::FORBIDDEN); // log in as admin - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; let avail = Username { username: "_CrashTestDummy".into(), @@ -260,17 +287,17 @@ async fn test_username_available(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + client.assert_event_queue_is_empty(); } #[sqlx::test] async fn test_crud_user(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let (mut client, pool) = make_client_with_db(pool).await; - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; // create user let new_user = AddUserData { @@ -288,6 +315,8 @@ async fn test_crud_user(_: PgPoolOptions, options: PgConnectOptions) { let mut user_details = fetch_user_details(&client, "adumbledore").await; assert_eq!(user_details.user.first_name, "Albus"); + let old_test_user = get_db_user(&pool, "adumbledore").await; + // edit user user_details.user.phone = Some("5678".into()); let response = client @@ -297,20 +326,33 @@ async fn test_crud_user(_: PgPoolOptions, options: PgConnectOptions) { .await; assert_eq!(response.status(), StatusCode::OK); + let new_test_user = get_db_user(&pool, "adumbledore").await; + // delete user let response = client.delete("/api/v1/user/adumbledore").send().await; assert_eq!(response.status(), StatusCode::OK); + + client.verify_api_events(&[ + ApiEventType::UserAdded { + user: old_test_user.clone(), + }, + ApiEventType::UserModified { + before: old_test_user, + after: new_test_user.clone(), + }, + ApiEventType::UserRemoved { + user: new_test_user, + }, + ]); } #[sqlx::test] async fn test_check_username(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let (mut client, pool) = make_client_with_db(pool).await; - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; let invalid_usernames = ["ADumble dore", ".1user"]; let valid_usernames = ["user1", "use2r3", "not_wrong"]; @@ -328,6 +370,7 @@ async fn test_check_username(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(response.status(), StatusCode::BAD_REQUEST); } + let mut expected_events = Vec::new(); for (i, username) in valid_usernames.into_iter().enumerate() { let new_user = AddUserData { username: username.into(), @@ -339,19 +382,22 @@ async fn test_check_username(_: PgPoolOptions, options: PgConnectOptions) { }; let response = client.post("/api/v1/user").json(&new_user).send().await; assert_eq!(response.status(), StatusCode::CREATED); + + let test_user = get_db_user(&pool, username).await; + expected_events.push(ApiEventType::UserAdded { user: test_user }) } + + client.verify_api_events(&expected_events); } #[sqlx::test] async fn test_check_password_strength(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let (mut client, pool) = make_client_with_db(pool).await; // auth session with admin - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; // test let strong_password = "strongPass1234$!"; @@ -391,16 +437,20 @@ async fn test_check_password_strength(_: PgPoolOptions, options: PgConnectOption .send() .await; assert_eq!(response.status(), StatusCode::CREATED); + + let test_user = get_db_user(&pool, "strongpass").await; + + client.verify_api_events(&[ApiEventType::UserAdded { user: test_user }]); } #[sqlx::test] async fn test_user_unregister_authorized_app(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + let (mut client, pool) = make_client_with_db(pool).await; + client.login_user("admin", "pass123").await; + + // add OpenID app let openid_client = NewOpenIDClient { name: "Test".into(), redirect_uri: vec![TEST_SERVER_URL.into()], @@ -415,6 +465,13 @@ async fn test_user_unregister_authorized_app(_: PgPoolOptions, options: PgConnec assert_eq!(response.status(), StatusCode::CREATED); let openid_client: OAuth2Client = response.json().await; assert_eq!(openid_client.name, "Test"); + + // verify app is not authorized yet + let response = client.get("/api/v1/me").send().await; + let user_info: UserInfo = response.json().await; + assert_eq!(user_info.authorized_apps.len(), 0); + + // authorize app let response = client .post(format!( "/api/v1/oauth/authorize?\ @@ -433,6 +490,10 @@ async fn test_user_unregister_authorized_app(_: PgPoolOptions, options: PgConnec let response = client.get("/api/v1/me").send().await; let mut user_info: UserInfo = response.json().await; assert_eq!(user_info.authorized_apps.len(), 1); + + let old_test_user = get_db_user(&pool, "admin").await; + + // unregister app user_info.authorized_apps = [].into(); let response = client .put("/api/v1/user/admin") @@ -443,16 +504,28 @@ async fn test_user_unregister_authorized_app(_: PgPoolOptions, options: PgConnec let response = client.get("/api/v1/me").send().await; let user_info: UserInfo = response.json().await; assert_eq!(user_info.authorized_apps.len(), 0); + + let new_test_user = get_db_user(&pool, "admin").await; + + client.verify_api_events(&[ + ApiEventType::OpenIdAppAdded { app: openid_client }, + ApiEventType::UserModified { + before: old_test_user, + after: new_test_user.clone(), + }, + ]); } #[sqlx::test] async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, state) = make_test_client(pool).await; + let (mut client, state) = make_test_client(pool).await; let mut mail_rx = state.mail_rx; let user_agent_header = "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1"; + let mut expected_events = Vec::new(); + // log in as admin let auth = Auth::new("admin", "pass123"); let response = client @@ -462,6 +535,7 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); + expected_events.push(ApiEventType::UserLogin); // first email received is regarding admin login let mail = mail_rx.try_recv().unwrap(); @@ -478,6 +552,9 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::CREATED); + expected_events.push(ApiEventType::VpnLocationAdded { + location: get_db_location(&state.pool, 1).await, + }); // add device for user let device_data = AddDevice { @@ -491,6 +568,10 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::CREATED); + expected_events.push(ApiEventType::UserDeviceAdded { + owner: get_db_user(&state.pool, "hpotter").await, + device: get_db_device(&state.pool, 1).await, + }); // send email regarding new device being added // it does not contain session info @@ -512,6 +593,10 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::CREATED); + expected_events.push(ApiEventType::UserDeviceAdded { + owner: get_db_user(&state.pool, "admin").await, + device: get_db_device(&state.pool, 2).await, + }); // send email regarding new device being added // it should contain session info @@ -533,6 +618,7 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::OK); + expected_events.push(ApiEventType::UserLogin); let response = client.get("/api/v1/me").send().await; assert_eq!(response.status(), StatusCode::OK); @@ -580,6 +666,10 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { .send() .await; assert_eq!(response.status(), StatusCode::CREATED); + expected_events.push(ApiEventType::UserDeviceAdded { + owner: get_db_user(&state.pool, "hpotter").await, + device: get_db_device(&state.pool, 3).await, + }); // send email regarding new device being added let mail = mail_rx.try_recv().unwrap(); @@ -590,17 +680,17 @@ async fn test_user_add_device(_: PgPoolOptions, options: PgConnectOptions) { mail.content .contains("Device type: iPhone, OS: iOS 17.1, Mobile Safari") ); + + client.verify_api_events(&expected_events); } #[sqlx::test] async fn test_disable(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let (mut client, pool) = make_client_with_db(pool).await; - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; // get yourself let mut user_details = fetch_user_details(&client, "admin").await; @@ -632,6 +722,8 @@ async fn test_disable(_: PgPoolOptions, options: PgConnectOptions) { assert_eq!(user_details.user.first_name, "Albus"); assert!(user_details.user.is_active); + let old_test_user = get_db_user(&pool, "adumbledore").await; + // disable user user_details.user.is_active = false; let response = client @@ -644,17 +736,27 @@ async fn test_disable(_: PgPoolOptions, options: PgConnectOptions) { let user_details = fetch_user_details(&client, "adumbledore").await; assert_eq!(user_details.user.first_name, "Albus"); assert!(!user_details.user.is_active); + + let new_test_user = get_db_user(&pool, "adumbledore").await; + + client.verify_api_events(&[ + ApiEventType::UserAdded { + user: old_test_user.clone(), + }, + ApiEventType::UserModified { + before: old_test_user, + after: new_test_user.clone(), + }, + ]); } #[sqlx::test] async fn test_unique_email(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let client = make_client(pool).await; + let (mut client, pool) = make_client_with_db(pool).await; - let auth = Auth::new("admin", "pass123"); - let response = client.post("/api/v1/auth").json(&auth).send().await; - assert_eq!(response.status(), StatusCode::OK); + client.login_user("admin", "pass123").await; // create user let new_user = AddUserData { @@ -679,4 +781,8 @@ async fn test_unique_email(_: PgPoolOptions, options: PgConnectOptions) { }; let response = client.post("/api/v1/user").json(&new_user).send().await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let test_user = get_db_user(&pool, "adumbledore").await; + + client.verify_api_events(&[ApiEventType::UserAdded { user: test_user }]); } diff --git a/crates/defguard_core/tests/integration/api/wireguard.rs b/crates/defguard_core/tests/integration/api/wireguard.rs index 79625944c0..468e2a8cbd 100644 --- a/crates/defguard_core/tests/integration/api/wireguard.rs +++ b/crates/defguard_core/tests/integration/api/wireguard.rs @@ -128,8 +128,8 @@ async fn test_network(_: PgPoolOptions, options: PgConnectOptions) { async fn test_location_mfa_mode_validation_create(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _client_state) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _client_state) = make_test_client(pool).await; + authenticate_admin(&mut client).await; exceed_enterprise_limits(&client).await; @@ -213,8 +213,8 @@ async fn test_location_mfa_mode_validation_create(_: PgPoolOptions, options: PgC async fn test_location_mfa_mode_validation_modify(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; - let (client, _client_state) = make_test_client(pool).await; - authenticate_admin(&client).await; + let (mut client, _client_state) = make_test_client(pool).await; + authenticate_admin(&mut client).await; let mut location_data = WireguardNetworkData { name: "test_location".into(),