From 52c309a40af7b1df637fe5133901217fb71c48d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 20 Mar 2026 13:07:37 +0100 Subject: [PATCH 1/3] Limited pagination --- .../defguard_core/src/handlers/pagination.rs | 128 +++++++++++++++--- 1 file changed, 112 insertions(+), 16 deletions(-) diff --git a/crates/defguard_core/src/handlers/pagination.rs b/crates/defguard_core/src/handlers/pagination.rs index cb4428aece..91bbe5e313 100644 --- a/crates/defguard_core/src/handlers/pagination.rs +++ b/crates/defguard_core/src/handlers/pagination.rs @@ -5,19 +5,86 @@ use axum::{ response::{IntoResponse, Response}, }; use reqwest::StatusCode; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, de}; use crate::error::WebError; +const DEFAULT_PER_PAGE: u32 = 50; +const MIN_PAGE: u32 = 1; +const MIN_PER_PAGE: u32 = 1; +const MAX_PER_PAGE: u32 = 100; + /// Query params for paginated endpoints -#[derive(Deserialize)] -#[serde(default)] pub(crate) struct PaginationParams { page: u32, per_page: u32, } +/// Implement custom deserializer to control default values and limits. +impl<'de> Deserialize<'de> for PaginationParams { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum Field { + Page, + PerPage, + } + + struct PaginationParamsVisitor; + + impl<'de> de::Visitor<'de> for PaginationParamsVisitor { + type Value = PaginationParams; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("struct PaginationParams") + } + + fn visit_map(self, mut map: V) -> Result + where + V: de::MapAccess<'de>, + { + let mut page = None; + let mut per_page = None; + while let Some(key) = map.next_key()? { + match key { + Field::Page => { + if page.is_some() { + return Err(de::Error::duplicate_field("page")); + } + page = Some(map.next_value()?); + } + Field::PerPage => { + if per_page.is_some() { + return Err(de::Error::duplicate_field("per_page")); + } + per_page = Some(map.next_value()?); + } + } + } + let page = page.unwrap_or(MIN_PAGE); + let per_page = per_page.unwrap_or(DEFAULT_PER_PAGE); + Ok(PaginationParams::new(page, per_page)) + } + } + + const FIELDS: &[&str] = &["page", "per_page"]; + deserializer.deserialize_struct("PaginationParams", FIELDS, PaginationParamsVisitor) + } +} + impl PaginationParams { + /// Constructor. + #[must_use] + pub fn new(page: u32, per_page: u32) -> Self { + Self { + page: page.max(MIN_PAGE), + per_page: per_page.max(MIN_PER_PAGE).min(MAX_PER_PAGE), + } + } + /// Page getter. #[must_use] pub fn page(&self) -> u32 { @@ -33,19 +100,15 @@ impl PaginationParams { /// Calculate offset. #[must_use] pub fn offset(&self) -> u32 { - if self.page == 0 { - self.per_page - } else { - (self.page - 1) * self.per_page - } + (self.page - 1) * self.per_page } } impl Default for PaginationParams { fn default() -> Self { Self { - page: 1, - per_page: 50, + page: MIN_PAGE, + per_page: DEFAULT_PER_PAGE, } } } @@ -71,12 +134,7 @@ impl PaginationMeta { #[must_use] fn from_pagination(pagination: PaginationParams, total_items: u32) -> Self { let PaginationParams { page, per_page } = pagination; - let total_pages = if per_page <= 1 { - // For 0 and 1, assume per_page is 1. - total_items - } else { - total_items.div_ceil(per_page) - }; + let total_pages = total_items.div_ceil(per_page); let next_page = if page < total_pages { Some(page + 1) } else { @@ -126,3 +184,41 @@ where } } } + +#[cfg(test)] +mod tests { + use super::PaginationParams; + + #[test] + fn deserialize_pagination_params_defaults() { + let params = serde_urlencoded::from_str::("").unwrap(); + assert_eq!(params.page(), 1); + assert_eq!(params.per_page(), 50); + assert_eq!(params.offset(), 0); + } + + #[test] + fn deserialize_pagination_params_zero_values() { + let params = serde_urlencoded::from_str::("page=0&per_page=0").unwrap(); + assert_eq!(params.page(), 1); + assert_eq!(params.per_page(), 1); + assert_eq!(params.offset(), 0); + } + + #[test] + fn deserialize_pagination_params_large_values() { + let params = + serde_urlencoded::from_str::("page=1000&per_page=1000").unwrap(); + assert_eq!(params.page(), 1000); + assert_eq!(params.per_page(), 100); + assert_eq!(params.offset(), 99900); + } + + #[test] + fn deserialize_pagination_params_valid_values() { + let params = serde_urlencoded::from_str::("page=3&per_page=25").unwrap(); + assert_eq!(params.page(), 3); + assert_eq!(params.per_page(), 25); + assert_eq!(params.offset(), 50); + } +} From 468ff8f69ae1286e38ca6aa072dd9632eaf0667d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 20 Mar 2026 13:17:17 +0100 Subject: [PATCH 2/3] Make clippy happy --- crates/defguard_core/src/handlers/pagination.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/defguard_core/src/handlers/pagination.rs b/crates/defguard_core/src/handlers/pagination.rs index 91bbe5e313..d9e7a74332 100644 --- a/crates/defguard_core/src/handlers/pagination.rs +++ b/crates/defguard_core/src/handlers/pagination.rs @@ -81,7 +81,7 @@ impl PaginationParams { pub fn new(page: u32, per_page: u32) -> Self { Self { page: page.max(MIN_PAGE), - per_page: per_page.max(MIN_PER_PAGE).min(MAX_PER_PAGE), + per_page: per_page.clamp(MIN_PER_PAGE, MAX_PER_PAGE), } } From f2893c8607f9a82175b95c9d6475c54af5860d0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 20 Mar 2026 14:13:21 +0100 Subject: [PATCH 3/3] Ignore other fields when deserialising --- .../handlers/activity_log_stream.rs | 2 +- .../defguard_core/src/handlers/pagination.rs | 44 ++++++++++++++++++- .../tests/integration/api/activity_log.rs | 43 +++++++++++------- .../defguard_event_logger/src/description.rs | 6 +-- 4 files changed, 74 insertions(+), 21 deletions(-) diff --git a/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs b/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs index 96cf905970..3c4cb3f509 100644 --- a/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs +++ b/crates/defguard_core/src/enterprise/handlers/activity_log_stream.rs @@ -53,7 +53,7 @@ pub async fn create_activity_log_stream( debug!("User {session_username} creates activity log stream"); // validate config let _ = ActivityLogStreamConfig::from_serde_value(&data.stream_type, &data.stream_config)?; - let stream_model: ActivityLogStream = ActivityLogStream { + let stream_model = ActivityLogStream { id: NoId, name: data.name, stream_type: data.stream_type, diff --git a/crates/defguard_core/src/handlers/pagination.rs b/crates/defguard_core/src/handlers/pagination.rs index d9e7a74332..7cd4ce8fe2 100644 --- a/crates/defguard_core/src/handlers/pagination.rs +++ b/crates/defguard_core/src/handlers/pagination.rs @@ -26,11 +26,40 @@ impl<'de> Deserialize<'de> for PaginationParams { where D: Deserializer<'de>, { - #[derive(Deserialize)] - #[serde(field_identifier, rename_all = "snake_case")] enum Field { Page, PerPage, + Other, // ignore other fields + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl de::Visitor<'_> for FieldVisitor { + type Value = Field; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("`page` or `per_page`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "page" => Ok(Field::Page), + "per_page" => Ok(Field::PerPage), + _ => Ok(Field::Other), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } } struct PaginationParamsVisitor; @@ -62,6 +91,9 @@ impl<'de> Deserialize<'de> for PaginationParams { } per_page = Some(map.next_value()?); } + Field::Other => { + let _ = map.next_value::()?; + } } } let page = page.unwrap_or(MIN_PAGE); @@ -197,6 +229,14 @@ mod tests { assert_eq!(params.offset(), 0); } + #[test] + fn deserialize_pagination_foreign_params() { + let params = serde_urlencoded::from_str::("search=term").unwrap(); + assert_eq!(params.page(), 1); + assert_eq!(params.per_page(), 50); + assert_eq!(params.offset(), 0); + } + #[test] fn deserialize_pagination_params_zero_values() { let params = serde_urlencoded::from_str::("page=0&per_page=0").unwrap(); diff --git a/crates/defguard_core/tests/integration/api/activity_log.rs b/crates/defguard_core/tests/integration/api/activity_log.rs index 9e837bb021..707a5d3911 100644 --- a/crates/defguard_core/tests/integration/api/activity_log.rs +++ b/crates/defguard_core/tests/integration/api/activity_log.rs @@ -5,7 +5,7 @@ use defguard_common::db::{Id, NoId, models::User, setup_pool}; use defguard_core::db::models::activity_log::{ActivityLogEvent, ActivityLogModule, EventType}; use reqwest::StatusCode; use serde::Deserialize; -use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions}; use super::common::{client::TestClient, get_db_user, make_client_with_db}; @@ -23,7 +23,7 @@ struct PaginationMeta { #[derive(Clone, Deserialize)] struct ApiActivityLogEvent { - id: i64, + id: Id, timestamp: NaiveDateTime, username: String, ip: Option, @@ -54,12 +54,17 @@ async fn fetch_activity_log( .get(activity_log_url(marker, extra_query)) .send() .await; - assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.status(), + StatusCode::OK, + "{}", + response.text().await + ); response.json().await } async fn save_activity_log_event( - db: &sqlx::PgPool, + db: &PgPool, user: &User, marker: &str, description_suffix: &str, @@ -154,7 +159,11 @@ async fn test_activity_log_timestamp_desc_uses_id_desc_for_equal_timestamps( save_activity_log_event(&db, &admin, &marker, "third", shared_timestamp).await; let payload = fetch_activity_log(&client, &marker, "sort_by=timestamp&sort_order=desc").await; - let ids: Vec = payload.data.into_iter().map(|event| event.id).collect(); + let ids = payload + .data + .into_iter() + .map(|event| event.id) + .collect::>(); assert_eq!( ids, @@ -196,11 +205,11 @@ async fn test_activity_log_timestamp_desc_orders_by_timestamp_then_id( .await; let payload = fetch_activity_log(&client, &marker, "sort_by=timestamp&sort_order=desc").await; - let ordered_events: Vec<(i64, NaiveDateTime)> = payload + let ordered_events = payload .data .into_iter() .map(|event| (event.id, event.timestamp)) - .collect(); + .collect::>(); assert_eq!( ordered_events, @@ -239,11 +248,15 @@ async fn test_activity_log_timestamp_asc_uses_id_asc_for_equal_timestamps( .await; let payload = fetch_activity_log(&client, &marker, "sort_by=timestamp&sort_order=asc").await; - let ids: Vec = payload.data.into_iter().map(|event| event.id).collect(); + let ids = payload + .data + .into_iter() + .map(|event| event.id) + .collect::>(); assert_eq!( ids, - vec![first_event.id, second_event.id, later_event.id], + [first_event.id, second_event.id, later_event.id], "ascending timestamp sort should use ascending ids for equal timestamps", ); } @@ -270,11 +283,11 @@ async fn test_activity_log_non_timestamp_sort_uses_id_as_stable_tiebreaker( save_activity_log_event(&db, &admin, &marker, "admin-second", shared_timestamp).await; let payload = fetch_activity_log(&client, &marker, "sort_by=username&sort_order=asc").await; - let ordered_events: Vec<(String, i64)> = payload + let ordered_events = payload .data .into_iter() .map(|event| (event.username, event.id)) - .collect(); + .collect::>(); assert_eq!( ordered_events, @@ -333,14 +346,14 @@ async fn test_activity_log_pagination_is_stable_across_pages_for_equal_timestamp assert_eq!(page_one.pagination.next_page, Some(2)); assert_eq!(page_two.pagination.next_page, None); - let combined_ids: Vec = page_one + let combined_ids = page_one .data .iter() .chain(page_two.data.iter()) .map(|event| event.id) - .collect(); - let unique_ids: HashSet = combined_ids.iter().copied().collect(); - let expected_ids: Vec = inserted_ids.into_iter().rev().collect(); + .collect::>(); + let unique_ids = combined_ids.iter().copied().collect::>(); + let expected_ids = inserted_ids.into_iter().rev().collect::>(); assert_eq!( combined_ids, expected_ids, diff --git a/crates/defguard_event_logger/src/description.rs b/crates/defguard_event_logger/src/description.rs index 69b08bebfc..fdbf215703 100644 --- a/crates/defguard_event_logger/src/description.rs +++ b/crates/defguard_event_logger/src/description.rs @@ -311,9 +311,9 @@ pub fn get_enrollment_event_description(event: &EnrollmentEvent) -> Option { Some("User completed enrollment process".to_string()) } - EnrollmentEvent::PasswordResetRequested => None, - EnrollmentEvent::PasswordResetStarted => None, - EnrollmentEvent::PasswordResetCompleted => None, + EnrollmentEvent::PasswordResetRequested + | EnrollmentEvent::PasswordResetStarted + | EnrollmentEvent::PasswordResetCompleted => None, EnrollmentEvent::TokenAdded { user } => { Some(format!("Added enrollment token for user {user}")) }