diff --git a/Cargo.lock b/Cargo.lock index e9797f0ff9..5a43cbdc1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1297,6 +1297,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "defguard_generator" +version = "0.0.0" +dependencies = [ + "anyhow", + "chrono", + "clap", + "defguard_common", + "rand 0.8.5", + "sqlx", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "defguard_mail" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 5d2ad055ea..bfa9367073 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,8 @@ repository = "https://github.com/DefGuard/defguard" rust-version = "1.85.1" [workspace] -members = ["crates/*"] +members = ["crates/*", "tools/*"] +default-members = ["crates/*"] resolver = "2" [workspace.dependencies] diff --git a/crates/defguard_common/src/db/models/device.rs b/crates/defguard_common/src/db/models/device.rs index ed8cffb478..dfbf670997 100644 --- a/crates/defguard_common/src/db/models/device.rs +++ b/crates/defguard_common/src/db/models/device.rs @@ -106,10 +106,10 @@ impl fmt::Display for Device { } } -impl Distribution> for Standard { - fn sample(&self, rng: &mut R) -> Device { +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> Device { Device { - id: rng.r#gen(), + id: NoId, name: Alphanumeric.sample_string(rng, 8), wireguard_pubkey: Alphanumeric.sample_string(rng, 32), user_id: rng.r#gen(), diff --git a/crates/defguard_core/src/enterprise/firewall/tests.rs b/crates/defguard_core/src/enterprise/firewall/tests.rs index 316bb4df68..a39c26c4a4 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests.rs @@ -51,8 +51,8 @@ fn random_user_with_id(rng: &mut R, id: Id) -> User { } fn random_network_device_with_id(rng: &mut R, id: Id) -> Device { - let mut device: Device = rng.r#gen(); - device.id = id; + let device: Device = rng.r#gen(); + let mut device = device.with_id(id); device.device_type = DeviceType::Network; device } diff --git a/deny.toml b/deny.toml index 4097d67ed2..6cee259570 100644 --- a/deny.toml +++ b/deny.toml @@ -161,6 +161,10 @@ exceptions = [ "AGPL-3.0-only", "AGPL-3.0-or-later", ], crate = "defguard_certs" }, + { allow = [ + "AGPL-3.0-only", + "AGPL-3.0-or-later", + ], crate = "defguard_generator" }, ] # Some crates don't have (easily) machine readable licensing information, diff --git a/tools/defguard_generator/Cargo.toml b/tools/defguard_generator/Cargo.toml new file mode 100644 index 0000000000..6c37f70ba9 --- /dev/null +++ b/tools/defguard_generator/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "defguard_generator" +version = "0.0.0" +edition = "2024" +license-file = "../../LICENSE.md" +homepage = "https://defguard.net/" +repository = "https://github.com/DefGuard/defguard" +rust-version = "1.85.1" + +[dependencies] +defguard_common = { workspace = true } +chrono = { workspace = true } +rand = { workspace = true } +clap = { workspace = true, features = ["derive"] } +sqlx = { workspace = true, features = ["postgres", "runtime-tokio-native-tls", "chrono"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +anyhow = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } + +[dev-dependencies] diff --git a/tools/defguard_generator/README.md b/tools/defguard_generator/README.md new file mode 100644 index 0000000000..c044980add --- /dev/null +++ b/tools/defguard_generator/README.md @@ -0,0 +1,14 @@ +# Defguard Object Generator + +This crate contains a simple generator for creating users, devices, stats etc during development. + +### Usage + +```bash +cargo run -p defguard_generator -- vpn-session-stats \ + --location-id 1 \ + --num-users 10 \ + --devices-per-user 2 \ + --sessions-per-device 5 +``` + diff --git a/tools/defguard_generator/src/lib.rs b/tools/defguard_generator/src/lib.rs new file mode 100644 index 0000000000..38c1a426e9 --- /dev/null +++ b/tools/defguard_generator/src/lib.rs @@ -0,0 +1,3 @@ +pub mod user_devices; +pub mod users; +pub mod vpn_session_stats; diff --git a/tools/defguard_generator/src/main.rs b/tools/defguard_generator/src/main.rs new file mode 100644 index 0000000000..e9b5b814c3 --- /dev/null +++ b/tools/defguard_generator/src/main.rs @@ -0,0 +1,84 @@ +use anyhow::Result; +use clap::{Parser, Subcommand}; +use defguard_common::db::{Id, init_db}; +use defguard_generator::vpn_session_stats::{ + VpnSessionGeneratorConfig, generate_vpn_session_stats, +}; +use tracing::Level; + +#[derive(Parser)] +#[command(about, long_about = None)] +struct Cli { + #[arg(long, env = "DEFGUARD_DB_HOST", default_value = "localhost")] + pub database_host: String, + + #[arg(long, env = "DEFGUARD_DB_PORT", default_value_t = 5432)] + pub database_port: u16, + + #[arg(long, env = "DEFGUARD_DB_NAME", default_value = "defguard")] + pub database_name: String, + + #[arg(long, env = "DEFGUARD_DB_USER", default_value = "defguard")] + pub database_user: String, + + #[arg(long, env = "DEFGUARD_DB_PASSWORD", default_value = "")] + pub database_password: String, + + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + /// generates VPN session stats + VpnSessionStats { + #[arg(long)] + location_id: Id, + #[arg(long)] + num_users: u16, + #[arg(long)] + devices_per_user: u8, + #[arg(long)] + sessions_per_device: u8, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); + + // parse CLI options + let cli = Cli::parse(); + + // setup DB pool + let pool = init_db( + &cli.database_host, + cli.database_port, + &cli.database_name, + &cli.database_user, + &cli.database_password, + ) + .await; + + // execute based on the selected subcommand + match cli.command { + Commands::VpnSessionStats { + location_id, + num_users, + devices_per_user, + sessions_per_device, + } => { + let config = VpnSessionGeneratorConfig { + location_id, + num_users, + devices_per_user, + sessions_per_device, + }; + + generate_vpn_session_stats(pool, config).await?; + } + }; + + Ok(()) +} diff --git a/tools/defguard_generator/src/user_devices.rs b/tools/defguard_generator/src/user_devices.rs new file mode 100644 index 0000000000..19b726a2ad --- /dev/null +++ b/tools/defguard_generator/src/user_devices.rs @@ -0,0 +1,37 @@ +use anyhow::Result; +use defguard_common::db::{ + Id, + models::{Device, User}, +}; +use rand::{Rng, rngs::ThreadRng}; +use sqlx::PgPool; +use tracing::info; + +pub async fn prepare_user_devices( + pool: &PgPool, + rng: &mut ThreadRng, + user: &User, + devices_per_user: usize, +) -> Result>> { + // fetch all existing devices for a given user + let mut user_devices = Device::all_for_username(pool, &user.username).await?; + + // if there are enough users just return the required number + if user_devices.len() >= devices_per_user { + info!( + "Found {} existing devices for user {user} in the database. Using the required number.", + user_devices.len() + ); + return Ok(user_devices[..devices_per_user].to_vec()); + } + + // if there are not enough users create new ones + for _ in 0..(devices_per_user - user_devices.len()) { + let mut device: Device = rng.r#gen(); + device.user_id = user.id; + let device = device.save(pool).await?; + user_devices.push(device); + } + + Ok(user_devices) +} diff --git a/tools/defguard_generator/src/users.rs b/tools/defguard_generator/src/users.rs new file mode 100644 index 0000000000..b223b6d120 --- /dev/null +++ b/tools/defguard_generator/src/users.rs @@ -0,0 +1,34 @@ +use anyhow::Result; +use defguard_common::db::{Id, models::User}; +use rand::{Rng, rngs::ThreadRng}; +use sqlx::PgPool; +use tracing::info; + +pub async fn prepare_users( + pool: &PgPool, + rng: &mut ThreadRng, + num_users: usize, +) -> Result>> { + info!("Preparing {num_users} random users for generating VPN session stats"); + + // fetch all existing users + let mut all_users = User::all(pool).await?; + + // if there are enough users just return the required number + if all_users.len() >= num_users { + info!( + "Found {} existing users in the database. Using the required number.", + all_users.len() + ); + return Ok(all_users[..num_users].to_vec()); + } + + // if there are not enough users create new ones + for _ in 0..(num_users - all_users.len()) { + let user: User = rng.r#gen(); + let user = user.save(pool).await?; + all_users.push(user); + } + + Ok(all_users) +} diff --git a/tools/defguard_generator/src/vpn_session_stats.rs b/tools/defguard_generator/src/vpn_session_stats.rs new file mode 100644 index 0000000000..f2835fa2f1 --- /dev/null +++ b/tools/defguard_generator/src/vpn_session_stats.rs @@ -0,0 +1,192 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use anyhow::Result; +use chrono::{Duration, NaiveDateTime, Utc}; +use defguard_common::db::{ + Id, + models::{ + WireguardNetwork, + gateway::Gateway, + vpn_client_session::{VpnClientSession, VpnClientSessionState}, + vpn_session_stats::VpnSessionStats, + wireguard::LocationMfaMode, + }, +}; +use rand::{Rng, rngs::ThreadRng}; +use sqlx::{PgConnection, PgPool}; +use tracing::info; + +use crate::{user_devices::prepare_user_devices, users::prepare_users}; + +const STATS_COLLECTION_INTERVAL: Duration = Duration::seconds(30); +const HANDSHAKE_INTERVAL: Duration = Duration::minutes(2); + +pub struct VpnSessionGeneratorConfig { + pub location_id: Id, + pub num_users: u16, + pub devices_per_user: u8, + pub sessions_per_device: u8, +} + +pub async fn generate_vpn_session_stats( + pool: PgPool, + config: VpnSessionGeneratorConfig, +) -> Result<()> { + let mut rng = rand::thread_rng(); + + // clear sessions & stats tables + info!("Clearing existing sessions & stats"); + truncate_with_restart(&pool).await?; + + // fetch specified location + let location = WireguardNetwork::find_by_id(&pool, config.location_id) + .await? + .expect("Location not found"); + + // prepare a gateway + let gateway = prepare_gateway(&pool, location.id).await?; + + // prepare requested number of users + let user_count = config.num_users as usize; + let users = prepare_users(&pool, &mut rng, user_count).await?; + + // generate sessions for each user + for (i, user) in users.into_iter().enumerate() { + info!("[{i}/{user_count}] Generating VPN sessions for user {user}"); + + // begin DB transaction + let mut transaction = pool.begin().await?; + + // prepare requested number of devices + let devices = + prepare_user_devices(&pool, &mut rng, &user, config.devices_per_user as usize).await?; + + for device in devices { + // generate requested number of sessions for a device + // we always start with a session that's currently active + // and generate past ones as needed + + // start with the active session + let mut session_end = Utc::now().naive_utc(); + + for i in 0..config.sessions_per_device { + let session_duration = Duration::minutes(rng.gen_range(10..120)); + let session_start = session_end - session_duration; + + let mut session = VpnClientSession::new( + location.id, + device.user_id, + device.id, + Some(session_start), + LocationMfaMode::Disabled, + ) + .save(&mut *transaction) + .await?; + + // mark all but the first session as disconnected + if i > 0 { + session.state = VpnClientSessionState::Disconnected; + session.disconnected_at = Some(session_end); + session.save(&mut *transaction).await?; + } + + generate_mock_session_stats( + &mut transaction, + &mut rng, + session.id, + gateway.id, + session_start, + session_end, + ) + .await?; + + // update end timestamp for next session + session_end -= Duration::minutes(rng.gen_range(30..120)); + } + } + transaction.commit().await?; + } + + Ok(()) +} + +/// Remove all records from sessions & stats tables. +/// This also resets the auto-incrementing sequences +async fn truncate_with_restart(pool: &PgPool) -> Result<()> { + sqlx::query("TRUNCATE TABLE vpn_client_session RESTART IDENTITY CASCADE") + .execute(pool) + .await?; + + Ok(()) +} + +async fn prepare_gateway(pool: &PgPool, location_id: Id) -> Result> { + // check if a gateway exists already + let existing_gateways = Gateway::find_by_network_id(pool, location_id).await?; + match existing_gateways.into_iter().next() { + Some(gateway) => Ok(gateway), + None => { + let gateway = Gateway::new(location_id, "http://localhost:50055") + .save(pool) + .await?; + Ok(gateway) + } + } +} + +async fn generate_mock_session_stats( + transaction: &mut PgConnection, + rng: &mut ThreadRng, + session_id: Id, + gateway_id: Id, + session_start: NaiveDateTime, + session_end: NaiveDateTime, +) -> Result<()> { + let mut latest_handshake = session_start; + let mut next_handshake = latest_handshake + HANDSHAKE_INTERVAL; + let mut collected_at = session_start; + let mut total_upload = 0; + let mut total_download = 0; + + // assume the IP remains static within a single session + let endpoint = random_socket_addr(rng).to_string(); + + while collected_at <= session_end { + // generate traffic + let upload_diff = rng.gen_range(100..100_000); + total_upload += upload_diff; + let download_diff = rng.gen_range(100..100_000); + total_download += download_diff; + + VpnSessionStats::new( + session_id, + gateway_id, + collected_at, + latest_handshake, + endpoint.clone(), + total_upload, + total_download, + download_diff, + download_diff, + ) + .save(&mut *transaction) + .await?; + + // update variables for next sample + collected_at += STATS_COLLECTION_INTERVAL; + + // update handshake if necessary + if collected_at > next_handshake { + latest_handshake = next_handshake; + next_handshake = latest_handshake + HANDSHAKE_INTERVAL; + } + } + + Ok(()) +} + +fn random_socket_addr(rng: &mut ThreadRng) -> SocketAddr { + let ip = Ipv4Addr::new(rng.r#gen(), rng.r#gen(), rng.r#gen(), rng.r#gen()); + let port = rng.r#gen(); + SocketAddr::new(IpAddr::V4(ip), port) +}