Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub async fn create_activity_log_stream(
debug!("User {session_username} creates activity log stream");
// validate config
let _ = ActivityLogStreamConfig::from_serde_value(&data.stream_type, &data.stream_config)?;
let stream_model: ActivityLogStream<NoId> = ActivityLogStream {
let stream_model = ActivityLogStream {
id: NoId,
name: data.name,
stream_type: data.stream_type,
Expand Down
168 changes: 152 additions & 16 deletions crates/defguard_core/src/handlers/pagination.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,118 @@ use axum::{
response::{IntoResponse, Response},
};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, de};

use crate::error::WebError;

const DEFAULT_PER_PAGE: u32 = 50;
const MIN_PAGE: u32 = 1;
const MIN_PER_PAGE: u32 = 1;
const MAX_PER_PAGE: u32 = 100;

/// Query params for paginated endpoints
#[derive(Deserialize)]
#[serde(default)]
pub(crate) struct PaginationParams {
page: u32,
per_page: u32,
}

/// Implement custom deserializer to control default values and limits.
impl<'de> Deserialize<'de> for PaginationParams {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
Page,
PerPage,
Other, // ignore other fields
}

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;

impl de::Visitor<'_> for FieldVisitor {
type Value = Field;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("`page` or `per_page`")
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: de::Error,
{
match value {
"page" => Ok(Field::Page),
"per_page" => Ok(Field::PerPage),
_ => Ok(Field::Other),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct PaginationParamsVisitor;

impl<'de> de::Visitor<'de> for PaginationParamsVisitor {
type Value = PaginationParams;

fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("struct PaginationParams")
}

fn visit_map<V>(self, mut map: V) -> Result<PaginationParams, V::Error>
where
V: de::MapAccess<'de>,
{
let mut page = None;
let mut per_page = None;
while let Some(key) = map.next_key()? {
match key {
Field::Page => {
if page.is_some() {
return Err(de::Error::duplicate_field("page"));
}
page = Some(map.next_value()?);
}
Field::PerPage => {
if per_page.is_some() {
return Err(de::Error::duplicate_field("per_page"));
}
per_page = Some(map.next_value()?);
}
Field::Other => {
let _ = map.next_value::<de::IgnoredAny>()?;
}
}
}
let page = page.unwrap_or(MIN_PAGE);
let per_page = per_page.unwrap_or(DEFAULT_PER_PAGE);
Ok(PaginationParams::new(page, per_page))
}
}

const FIELDS: &[&str] = &["page", "per_page"];
deserializer.deserialize_struct("PaginationParams", FIELDS, PaginationParamsVisitor)
}
}

impl PaginationParams {
/// Constructor.
#[must_use]
pub fn new(page: u32, per_page: u32) -> Self {
Self {
page: page.max(MIN_PAGE),
per_page: per_page.clamp(MIN_PER_PAGE, MAX_PER_PAGE),
}
}

/// Page getter.
#[must_use]
pub fn page(&self) -> u32 {
Expand All @@ -33,19 +132,15 @@ impl PaginationParams {
/// Calculate offset.
#[must_use]
pub fn offset(&self) -> u32 {
if self.page == 0 {
self.per_page
} else {
(self.page - 1) * self.per_page
}
(self.page - 1) * self.per_page
}
}

impl Default for PaginationParams {
fn default() -> Self {
Self {
page: 1,
per_page: 50,
page: MIN_PAGE,
per_page: DEFAULT_PER_PAGE,
}
}
}
Expand All @@ -71,12 +166,7 @@ impl PaginationMeta {
#[must_use]
fn from_pagination(pagination: PaginationParams, total_items: u32) -> Self {
let PaginationParams { page, per_page } = pagination;
let total_pages = if per_page <= 1 {
// For 0 and 1, assume per_page is 1.
total_items
} else {
total_items.div_ceil(per_page)
};
let total_pages = total_items.div_ceil(per_page);
let next_page = if page < total_pages {
Some(page + 1)
} else {
Expand Down Expand Up @@ -126,3 +216,49 @@ where
}
}
}

#[cfg(test)]
mod tests {
use super::PaginationParams;

#[test]
fn deserialize_pagination_params_defaults() {
let params = serde_urlencoded::from_str::<PaginationParams>("").unwrap();
assert_eq!(params.page(), 1);
assert_eq!(params.per_page(), 50);
assert_eq!(params.offset(), 0);
}

#[test]
fn deserialize_pagination_foreign_params() {
let params = serde_urlencoded::from_str::<PaginationParams>("search=term").unwrap();
assert_eq!(params.page(), 1);
assert_eq!(params.per_page(), 50);
assert_eq!(params.offset(), 0);
}

#[test]
fn deserialize_pagination_params_zero_values() {
let params = serde_urlencoded::from_str::<PaginationParams>("page=0&per_page=0").unwrap();
assert_eq!(params.page(), 1);
assert_eq!(params.per_page(), 1);
assert_eq!(params.offset(), 0);
}

#[test]
fn deserialize_pagination_params_large_values() {
let params =
serde_urlencoded::from_str::<PaginationParams>("page=1000&per_page=1000").unwrap();
assert_eq!(params.page(), 1000);
assert_eq!(params.per_page(), 100);
assert_eq!(params.offset(), 99900);
}

#[test]
fn deserialize_pagination_params_valid_values() {
let params = serde_urlencoded::from_str::<PaginationParams>("page=3&per_page=25").unwrap();
assert_eq!(params.page(), 3);
assert_eq!(params.per_page(), 25);
assert_eq!(params.offset(), 50);
}
}
43 changes: 28 additions & 15 deletions crates/defguard_core/tests/integration/api/activity_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use defguard_common::db::{Id, NoId, models::User, setup_pool};
use defguard_core::db::models::activity_log::{ActivityLogEvent, ActivityLogModule, EventType};
use reqwest::StatusCode;
use serde::Deserialize;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};

use super::common::{client::TestClient, get_db_user, make_client_with_db};

Expand All @@ -23,7 +23,7 @@ struct PaginationMeta {

#[derive(Clone, Deserialize)]
struct ApiActivityLogEvent {
id: i64,
id: Id,
timestamp: NaiveDateTime,
username: String,
ip: Option<String>,
Expand Down Expand Up @@ -54,12 +54,17 @@ async fn fetch_activity_log(
.get(activity_log_url(marker, extra_query))
.send()
.await;
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(
response.status(),
StatusCode::OK,
"{}",
response.text().await
);
response.json().await
}

async fn save_activity_log_event(
db: &sqlx::PgPool,
db: &PgPool,
user: &User<Id>,
marker: &str,
description_suffix: &str,
Expand Down Expand Up @@ -154,7 +159,11 @@ async fn test_activity_log_timestamp_desc_uses_id_desc_for_equal_timestamps(
save_activity_log_event(&db, &admin, &marker, "third", shared_timestamp).await;

let payload = fetch_activity_log(&client, &marker, "sort_by=timestamp&sort_order=desc").await;
let ids: Vec<i64> = payload.data.into_iter().map(|event| event.id).collect();
let ids = payload
.data
.into_iter()
.map(|event| event.id)
.collect::<Vec<_>>();

assert_eq!(
ids,
Expand Down Expand Up @@ -196,11 +205,11 @@ async fn test_activity_log_timestamp_desc_orders_by_timestamp_then_id(
.await;

let payload = fetch_activity_log(&client, &marker, "sort_by=timestamp&sort_order=desc").await;
let ordered_events: Vec<(i64, NaiveDateTime)> = payload
let ordered_events = payload
.data
.into_iter()
.map(|event| (event.id, event.timestamp))
.collect();
.collect::<Vec<_>>();

assert_eq!(
ordered_events,
Expand Down Expand Up @@ -239,11 +248,15 @@ async fn test_activity_log_timestamp_asc_uses_id_asc_for_equal_timestamps(
.await;

let payload = fetch_activity_log(&client, &marker, "sort_by=timestamp&sort_order=asc").await;
let ids: Vec<i64> = payload.data.into_iter().map(|event| event.id).collect();
let ids = payload
.data
.into_iter()
.map(|event| event.id)
.collect::<Vec<_>>();

assert_eq!(
ids,
vec![first_event.id, second_event.id, later_event.id],
[first_event.id, second_event.id, later_event.id],
"ascending timestamp sort should use ascending ids for equal timestamps",
);
}
Expand All @@ -270,11 +283,11 @@ async fn test_activity_log_non_timestamp_sort_uses_id_as_stable_tiebreaker(
save_activity_log_event(&db, &admin, &marker, "admin-second", shared_timestamp).await;

let payload = fetch_activity_log(&client, &marker, "sort_by=username&sort_order=asc").await;
let ordered_events: Vec<(String, i64)> = payload
let ordered_events = payload
.data
.into_iter()
.map(|event| (event.username, event.id))
.collect();
.collect::<Vec<_>>();

assert_eq!(
ordered_events,
Expand Down Expand Up @@ -333,14 +346,14 @@ async fn test_activity_log_pagination_is_stable_across_pages_for_equal_timestamp
assert_eq!(page_one.pagination.next_page, Some(2));
assert_eq!(page_two.pagination.next_page, None);

let combined_ids: Vec<i64> = page_one
let combined_ids = page_one
.data
.iter()
.chain(page_two.data.iter())
.map(|event| event.id)
.collect();
let unique_ids: HashSet<i64> = combined_ids.iter().copied().collect();
let expected_ids: Vec<i64> = inserted_ids.into_iter().rev().collect();
.collect::<Vec<_>>();
let unique_ids = combined_ids.iter().copied().collect::<HashSet<_>>();
let expected_ids = inserted_ids.into_iter().rev().collect::<Vec<_>>();

assert_eq!(
combined_ids, expected_ids,
Expand Down
6 changes: 3 additions & 3 deletions crates/defguard_event_logger/src/description.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,9 @@ pub fn get_enrollment_event_description(event: &EnrollmentEvent) -> Option<Strin
EnrollmentEvent::EnrollmentCompleted => {
Some("User completed enrollment process".to_string())
}
EnrollmentEvent::PasswordResetRequested => None,
EnrollmentEvent::PasswordResetStarted => None,
EnrollmentEvent::PasswordResetCompleted => None,
EnrollmentEvent::PasswordResetRequested
| EnrollmentEvent::PasswordResetStarted
| EnrollmentEvent::PasswordResetCompleted => None,
EnrollmentEvent::TokenAdded { user } => {
Some(format!("Added enrollment token for user {user}"))
}
Expand Down
Loading