Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 4 additions & 5 deletions tools/defguard_generator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ repository = "https://github.com/DefGuard/defguard"
rust-version = "1.87.0"

[dependencies]
defguard_common = { workspace = true }
anyhow = { workspace = true }
chrono = { workspace = true }
rand = { workspace = true }
clap = { workspace = true, features = ["derive"] }
defguard_common = { workspace = true }
defguard_core = { workspace = true }
rand = { workspace = true }
sqlx = { workspace = true, features = ["postgres", "runtime-tokio-native-tls", "chrono"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
anyhow = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }

[dev-dependencies]
30 changes: 30 additions & 0 deletions tools/defguard_generator/src/acl_rules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use anyhow::Result;
use defguard_core::enterprise::db::models::acl::{AclRule, RuleState};
use sqlx::{PgPool, query};

pub async fn generate_acl_rules(pool: PgPool, num_rules: u32) -> Result<()> {
truncate_with_restart(&pool).await?;

for index in 0..num_rules {
let mut acl_rule = AclRule::default();
acl_rule.name = format!("Generated {index}");
acl_rule.state = RuleState::Applied;
acl_rule.all_locations = true;
acl_rule.allow_all_users = true;
acl_rule.allow_all_groups = true;
acl_rule.allow_all_network_devices = true;
acl_rule.save(&pool).await?;
}

Ok(())
}

/// Remove all records from sessions and stats tables.
/// This also resets the auto-incrementing sequences.
async fn truncate_with_restart(pool: &PgPool) -> Result<()> {
query("TRUNCATE aclrule RESTART IDENTITY CASCADE")
.execute(pool)
.await?;

Ok(())
}
5 changes: 3 additions & 2 deletions tools/defguard_generator/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod user_devices;
pub mod users;
pub mod acl_rules;
mod user_devices;
mod users;
pub mod vpn_session_stats;
13 changes: 10 additions & 3 deletions tools/defguard_generator/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use anyhow::Result;
use clap::{Parser, Subcommand};
use defguard_common::db::{Id, init_db};
use defguard_generator::vpn_session_stats::{
VpnSessionGeneratorConfig, generate_vpn_session_stats,
use defguard_generator::{
acl_rules::generate_acl_rules,
vpn_session_stats::{VpnSessionGeneratorConfig, generate_vpn_session_stats},
};
use tracing_subscriber::EnvFilter;

Expand Down Expand Up @@ -35,7 +36,7 @@ enum Commands {
#[arg(long)]
location_id: Id,
#[arg(long)]
num_users: u16,
num_users: usize,
#[arg(long)]
devices_per_user: u8,
#[arg(long)]
Expand All @@ -47,6 +48,11 @@ enum Commands {
#[arg(long, default_value_t = 1000)]
stats_batch_size: u16,
},
/// Generates ACL rules
AclRules {
#[arg(long)]
num_rules: u32,
},
}

#[tokio::main]
Expand Down Expand Up @@ -92,6 +98,7 @@ async fn main() -> Result<()> {

generate_vpn_session_stats(pool, config).await?;
}
Commands::AclRules { num_rules } => generate_acl_rules(pool, num_rules).await?,
};

Ok(())
Expand Down
15 changes: 8 additions & 7 deletions tools/defguard_generator/src/vpn_session_stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use defguard_common::db::{
},
};
use rand::{Rng, rngs::ThreadRng};
use sqlx::{PgConnection, PgPool, QueryBuilder};
use sqlx::{PgConnection, PgPool, QueryBuilder, query};
use tracing::{debug, info};

use crate::{user_devices::prepare_user_devices, users::prepare_users};
Expand All @@ -24,7 +24,7 @@ const HANDSHAKE_INTERVAL: Duration = Duration::minutes(2);
#[derive(Debug)]
pub struct VpnSessionGeneratorConfig {
pub location_id: Id,
pub num_users: u16,
pub num_users: usize,
pub devices_per_user: u8,
pub sessions_per_device: u8,
pub no_truncate: bool,
Expand Down Expand Up @@ -53,7 +53,7 @@ pub async fn generate_vpn_session_stats(
let gateway = prepare_gateway(&pool, location.id).await?;

// prepare requested number of users
let user_count = config.num_users as usize;
let user_count = config.num_users;
let users = prepare_users(&pool, &mut rng, user_count).await?;

// generate sessions for each user
Expand Down Expand Up @@ -148,10 +148,10 @@ pub async fn generate_vpn_session_stats(
Ok(())
}

/// Remove all records from sessions & stats tables.
/// This also resets the auto-incrementing sequences
/// Remove all records from sessions and stats tables.
/// This also resets the auto-incrementing sequences.
async fn truncate_with_restart(pool: &PgPool) -> Result<()> {
sqlx::query("TRUNCATE TABLE vpn_client_session RESTART IDENTITY CASCADE")
query("TRUNCATE vpn_client_session RESTART IDENTITY CASCADE")
.execute(pool)
.await?;

Expand Down Expand Up @@ -248,7 +248,8 @@ async fn insert_stats_batch(
}

let mut query_builder = QueryBuilder::new(
"INSERT INTO vpn_session_stats (session_id, gateway_id, collected_at, latest_handshake, endpoint, total_upload, total_download, upload_diff, download_diff) ",
"INSERT INTO vpn_session_stats (session_id, gateway_id, collected_at, latest_handshake, \
endpoint, total_upload, total_download, upload_diff, download_diff) ",
);

query_builder.push_values(stats_batch, |mut b, stats| {
Expand Down
Loading