Skip to content
4 changes: 2 additions & 2 deletions crates/defguard_common/src/db/models/authentication_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use sqlx::{Error as SqlxError, PgExecutor, Type, query_as};

use crate::db::{Id, NoId};

#[derive(Clone, Debug, Deserialize, Serialize, Type)]
#[derive(Clone, Debug, Deserialize, Serialize, Type, PartialEq)]
#[sqlx(type_name = "authentication_key_type", rename_all = "lowercase")]
#[serde(rename_all = "lowercase")]
pub enum AuthenticationKeyType {
Expand All @@ -23,7 +23,7 @@ impl Display for AuthenticationKeyType {
}
}

#[derive(Clone, Debug, Deserialize, Model, Serialize)]
#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)]
#[table(authentication_key)]
pub struct AuthenticationKey<I = NoId> {
pub id: I,
Expand Down
12 changes: 10 additions & 2 deletions crates/defguard_core/src/db/models/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{fmt, net::IpAddr};
use base64::{Engine, prelude::BASE64_STANDARD};
#[cfg(test)]
use chrono::NaiveDate;
use chrono::{NaiveDateTime, Utc};
use chrono::{NaiveDateTime, Timelike, Utc};
use defguard_common::{
csv::AsCsv,
db::{Id, NoId, models::ModelError},
Expand Down Expand Up @@ -532,12 +532,20 @@ impl Device {
description: Option<String>,
configured: bool,
) -> Self {
// FIXME: this is a workaround for reducing timestamp precision.
// `chrono` has nanosecond precision by default, while Postgres only does microseconds.
// It avoids issues when comparing to objects fetched from DB.
let created = Utc::now().naive_utc();
let created = created
.with_nanosecond((created.nanosecond() / 1_000) * 1_000)
.expect("failed to truncate timestamp precision");

Self {
id: NoId,
name,
wireguard_pubkey,
user_id,
created: Utc::now().naive_utc(),
created,
device_type,
description,
configured,
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_core/src/db/models/oauth2client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::NewOpenIDClient;
use defguard_common::db::{Id, NoId};
use defguard_common::random::gen_alphanumeric;

#[derive(Clone, Debug, Deserialize, Model, Serialize)]
#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)]
pub struct OAuth2Client<I = NoId> {
pub id: I,
pub client_id: String, // unique
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_core/src/db/models/webauthn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use webauthn_rs::prelude::Passkey;

use defguard_common::db::{Id, NoId, models::ModelError};

#[derive(Model, Clone, Debug)]
#[derive(Model, Clone, Debug, PartialEq)]
pub struct WebAuthn<I = NoId> {
pub id: I,
pub user_id: Id,
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_core/src/db/models/webhook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl AppEvent {
}
}

#[derive(Clone, Debug, Deserialize, FromRow, Model, Serialize)]
#[derive(Clone, Debug, Deserialize, FromRow, Model, Serialize, PartialEq)]
pub struct WebHook<I = NoId> {
pub id: I,
pub url: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use defguard_common::{
secret::SecretStringWrapper,
};

#[derive(Debug, Serialize, Deserialize, Type, EnumString, Display, Clone)]
#[derive(Debug, Serialize, Deserialize, Type, EnumString, Display, Clone, PartialEq)]
#[sqlx(type_name = "text", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ActivityLogStreamType {
Expand All @@ -19,7 +19,7 @@ pub enum ActivityLogStreamType {
LogstashHttp,
}

#[derive(Clone, Debug, Serialize, Model, FromRow)]
#[derive(Clone, Debug, Serialize, Model, FromRow, PartialEq)]
#[table(activity_log_stream)]
pub struct ActivityLogStream<I = NoId> {
pub id: I,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use sqlx::{Error as SqlxError, PgExecutor, query_as};

use defguard_common::db::{Id, NoId};

#[derive(Clone, Debug, Deserialize, Model, Serialize)]
#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)]
#[table(api_token)]
pub struct ApiToken<I = NoId> {
pub id: I,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl From<String> for DirectorySyncTarget {
}
}

#[derive(Clone, Debug, Deserialize, Model, Serialize)]
#[derive(Clone, Debug, Deserialize, Model, Serialize, PartialEq)]
pub struct OpenIdProvider<I = NoId> {
pub id: I,
pub name: String,
Expand Down
2 changes: 1 addition & 1 deletion crates/defguard_core/src/enterprise/db/models/snat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use utoipa::ToSchema;
use crate::enterprise::snat::error::UserSnatBindingError;
use defguard_common::db::{Id, NoId};

#[derive(Clone, Debug, Deserialize, Model, Serialize, ToSchema)]
#[derive(Clone, Debug, Deserialize, Model, Serialize, ToSchema, PartialEq)]
#[table(user_snat_binding)]
pub struct UserSnatBinding<I = NoId> {
pub id: I,
Expand Down
4 changes: 2 additions & 2 deletions crates/defguard_core/src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use defguard_proto::proxy::MfaMethod;
/// Mainly meant to be stored in the activity log.
/// By design this is a duplicate of a similar struct in the `event_logger` module.
/// This is done in order to avoid circular imports once we split the project into multiple crates.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct ApiRequestContext {
pub timestamp: NaiveDateTime,
pub user_id: Id,
Expand Down Expand Up @@ -83,7 +83,7 @@ impl GrpcRequestContext {
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum ApiEventType {
UserLogin,
UserLoginFailed {
Expand Down
76 changes: 38 additions & 38 deletions crates/defguard_core/tests/integration/api/acl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ fn edit_alias_data_into_api_response(
async fn test_rule_crud(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, _) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, _) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

let rule = make_rule();

Expand Down Expand Up @@ -189,8 +189,8 @@ async fn test_rule_crud(_: PgPoolOptions, options: PgConnectOptions) {
async fn test_rule_enterprise(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, _) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, _) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

exceed_enterprise_limits(&client).await;

Expand Down Expand Up @@ -230,8 +230,8 @@ async fn test_rule_enterprise(_: PgPoolOptions, options: PgConnectOptions) {
async fn test_alias_crud(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, _) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, _) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

let alias = make_alias();

Expand Down Expand Up @@ -284,8 +284,8 @@ async fn test_alias_crud(_: PgPoolOptions, options: PgConnectOptions) {
async fn test_alias_enterprise(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, _) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, _) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

exceed_enterprise_limits(&client).await;

Expand Down Expand Up @@ -325,8 +325,8 @@ async fn test_alias_enterprise(_: PgPoolOptions, options: PgConnectOptions) {
async fn test_empty_strings(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, _) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, _) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

// rule
let mut rule = make_rule();
Expand Down Expand Up @@ -409,8 +409,8 @@ async fn test_related_objects(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

// create related objects
// networks
Expand Down Expand Up @@ -541,8 +541,8 @@ async fn test_related_objects(_: PgPoolOptions, options: PgConnectOptions) {
async fn test_invalid_related_objects(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, state) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, state) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

let rule = make_rule();
let response = client.post("/api/v1/acl/rule").json(&rule).send().await;
Expand Down Expand Up @@ -644,8 +644,8 @@ async fn test_invalid_related_objects(_: PgPoolOptions, options: PgConnectOption
async fn test_invalid_data(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let (client, _) = make_test_client(pool).await;
authenticate_admin(&client).await;
let (mut client, _) = make_test_client(pool).await;
authenticate_admin(&mut client).await;

// invalid port
let mut rule = make_rule();
Expand Down Expand Up @@ -677,8 +677,8 @@ async fn test_rule_create_modify_state(_: PgPoolOptions, options: PgConnectOptio
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let rule = make_rule();

Expand Down Expand Up @@ -732,8 +732,8 @@ async fn test_rule_delete_state_new(_: PgPoolOptions, options: PgConnectOptions)
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

// test NEW rule deletion
let rule = make_rule();
Expand All @@ -751,8 +751,8 @@ async fn test_rule_delete_state_applied(_: PgPoolOptions, options: PgConnectOpti
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

// create a location
WireguardNetwork::new(
Expand Down Expand Up @@ -812,8 +812,8 @@ async fn test_rule_duplication(_: PgPoolOptions, options: PgConnectOptions) {

// each modification / deletion of parent rule should remove the child and create a new one
let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let rule = make_rule();
let response = client.post("/api/v1/acl/rule").json(&rule).send().await;
Expand Down Expand Up @@ -842,8 +842,8 @@ async fn test_rule_application(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let rule = make_rule();

Expand Down Expand Up @@ -934,8 +934,8 @@ async fn test_multiple_rules_application(_: PgPoolOptions, options: PgConnectOpt
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let rule_1 = make_rule();
let rule_2 = make_rule();
Expand Down Expand Up @@ -972,8 +972,8 @@ async fn test_alias_create_modify_state(_: PgPoolOptions, options: PgConnectOpti
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let alias = make_alias();

Expand Down Expand Up @@ -1012,8 +1012,8 @@ async fn test_alias_delete(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

// create alias
let alias = make_alias();
Expand Down Expand Up @@ -1078,8 +1078,8 @@ async fn test_alias_duplication(_: PgPoolOptions, options: PgConnectOptions) {

// each modification of parent alias should remove the child and create a new one
let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let alias = make_alias();
let response = client.post("/api/v1/acl/alias").json(&alias).send().await;
Expand All @@ -1104,8 +1104,8 @@ async fn test_alias_application(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

// create new alias
let alias = make_alias();
Expand Down Expand Up @@ -1165,8 +1165,8 @@ async fn test_multiple_aliases_application(_: PgPoolOptions, options: PgConnectO
let pool = setup_pool(options).await;

let config = init_config(None);
let client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&client).await;
let mut client = make_client_v2(pool.clone(), config).await;
authenticate_admin(&mut client).await;

let alias_1 = make_alias();
let alias_2 = make_alias();
Expand Down
11 changes: 10 additions & 1 deletion crates/defguard_core/tests/integration/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use defguard_common::db::models::{MFAMethod, Settings, settings::update_current_
use defguard_core::{
auth::{TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD},
db::{MFAInfo, User, UserDetails},
events::ApiEventType,
handlers::{Auth, AuthCode, AuthResponse, AuthTotp},
};
use reqwest::{StatusCode, header::USER_AGENT};
Expand Down Expand Up @@ -59,6 +60,8 @@ async fn test_logout(_: PgPoolOptions, options: PgConnectOptions) {
let response = client.post("/api/v1/auth").json(&auth).send().await;
assert_eq!(response.status(), StatusCode::OK);

client.verify_api_events_with_user(&[(ApiEventType::UserLogin, 2, "hpotter")]);

// store auth cookie for later use
let auth_cookie = response
.cookies()
Expand All @@ -74,6 +77,8 @@ async fn test_logout(_: PgPoolOptions, options: PgConnectOptions) {
let response = client.get("/api/v1/me").send().await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

client.verify_api_events_with_user(&[(ApiEventType::UserLogout, 2, "hpotter")]);

// try reusing auth cookie
client.set_cookie(&auth_cookie);
let response = client.get("/api/v1/me").send().await;
Expand All @@ -84,7 +89,7 @@ async fn test_logout(_: PgPoolOptions, options: PgConnectOptions) {
async fn test_login_bruteforce(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;

let client = make_client(pool).await;
let mut client = make_client(pool).await;

let invalid_auth = Auth::new("hpotter", "invalid");

Expand All @@ -93,8 +98,12 @@ async fn test_login_bruteforce(_: PgPoolOptions, options: PgConnectOptions) {
let response = client.post("/api/v1/auth").json(&invalid_auth).send().await;
if i == 5 {
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
client.assert_event_queue_is_empty();
} else {
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
client.verify_api_events(&[ApiEventType::UserLoginFailed {
message: "Authentication for hpotter failed: invalid password".into(),
}]);
}
}
}
Expand Down
Loading