diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index 8f2167e2ef..3ca957d93b 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -7,7 +7,7 @@ use crate::db::{ models::{WireguardNetwork, vpn_session_stats::VpnSessionStats, wireguard::LocationMfaMode}, }; -#[derive(Default, Type)] +#[derive(Debug, Default, Type)] #[sqlx(type_name = "vpn_client_session_state", rename_all = "lowercase")] pub enum VpnClientSessionState { #[default] @@ -17,7 +17,7 @@ pub enum VpnClientSessionState { } /// Represents a single VPN client session from creation to eventual disconnection -#[derive(Model)] +#[derive(Debug, Model)] #[table(vpn_client_session)] pub struct VpnClientSession { pub id: I, diff --git a/tools/defguard_generator/README.md b/tools/defguard_generator/README.md index c044980add..562f45d238 100644 --- a/tools/defguard_generator/README.md +++ b/tools/defguard_generator/README.md @@ -2,6 +2,18 @@ This crate contains a simple generator for creating users, devices, stats etc during development. +### Database connection + +The generator uses the same environment variables (or CLI options) for DB connection setup as the core binary: + +- DEFGUARD_DB_HOST +- DEFGUARD_DB_PORT +- DEFGUARD_DB_NAME +- DEFGUARD_DB_USER +- DEFGUARD_DB_PASSWORD + +This means that if you have a development environment set up already it should just work. + ### Usage ```bash @@ -12,3 +24,13 @@ cargo run -p defguard_generator -- vpn-session-stats \ --sessions-per-device 5 ``` +### Session generation logic + +For each device the generator always starts with creating an active (not disconnected) session. +If there are more sessions per device to be generated it goes backwards in time and creates +additional disconnected sessions. +Session duration and gaps between sessions are randomized but there is no logic to verify if +sessions are overlapping so by default the generator runs a `TRUNCATE` query at the start. +To disable this behavior (for example when running it multiple times for separate locations) +use the `--no-truncate` CLI flag. + diff --git a/tools/defguard_generator/src/main.rs b/tools/defguard_generator/src/main.rs index e9b5b814c3..3c6feea0e9 100644 --- a/tools/defguard_generator/src/main.rs +++ b/tools/defguard_generator/src/main.rs @@ -4,7 +4,7 @@ use defguard_common::db::{Id, init_db}; use defguard_generator::vpn_session_stats::{ VpnSessionGeneratorConfig, generate_vpn_session_stats, }; -use tracing::Level; +use tracing_subscriber::EnvFilter; #[derive(Parser)] #[command(about, long_about = None)] @@ -30,7 +30,7 @@ struct Cli { #[derive(Subcommand)] enum Commands { - /// generates VPN session stats + /// Generates fake VPN session statistics. VpnSessionStats { #[arg(long)] location_id: Id, @@ -40,13 +40,23 @@ enum Commands { devices_per_user: u8, #[arg(long)] sessions_per_device: u8, + /// don't truncate sessions & stats tables before generating stats + #[arg(long)] + no_truncate: bool, + /// insert stats records in batches of specified size + #[arg(long, default_value_t = 1000)] + stats_batch_size: u16, }, } #[tokio::main] async fn main() -> Result<()> { // Initialize logging - tracing_subscriber::fmt().with_max_level(Level::INFO).init(); + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); // parse CLI options let cli = Cli::parse(); @@ -68,12 +78,16 @@ async fn main() -> Result<()> { num_users, devices_per_user, sessions_per_device, + no_truncate, + stats_batch_size, } => { let config = VpnSessionGeneratorConfig { location_id, num_users, devices_per_user, sessions_per_device, + no_truncate, + stats_batch_size, }; generate_vpn_session_stats(pool, config).await?; diff --git a/tools/defguard_generator/src/vpn_session_stats.rs b/tools/defguard_generator/src/vpn_session_stats.rs index f2835fa2f1..7e62eb98c9 100644 --- a/tools/defguard_generator/src/vpn_session_stats.rs +++ b/tools/defguard_generator/src/vpn_session_stats.rs @@ -13,30 +13,36 @@ use defguard_common::db::{ }, }; use rand::{Rng, rngs::ThreadRng}; -use sqlx::{PgConnection, PgPool}; -use tracing::info; +use sqlx::{PgConnection, PgPool, QueryBuilder}; +use tracing::{debug, info}; use crate::{user_devices::prepare_user_devices, users::prepare_users}; const STATS_COLLECTION_INTERVAL: Duration = Duration::seconds(30); const HANDSHAKE_INTERVAL: Duration = Duration::minutes(2); +#[derive(Debug)] pub struct VpnSessionGeneratorConfig { pub location_id: Id, pub num_users: u16, pub devices_per_user: u8, pub sessions_per_device: u8, + pub no_truncate: bool, + pub stats_batch_size: u16, } pub async fn generate_vpn_session_stats( pool: PgPool, config: VpnSessionGeneratorConfig, ) -> Result<()> { + info!("Running VPN stats generator with config: {config:#?}"); let mut rng = rand::thread_rng(); - // clear sessions & stats tables - info!("Clearing existing sessions & stats"); - truncate_with_restart(&pool).await?; + // clear sessions & stats tables unless disabled + if !config.no_truncate { + info!("Clearing existing sessions & stats"); + truncate_with_restart(&pool).await?; + } // fetch specified location let location = WireguardNetwork::find_by_id(&pool, config.location_id) @@ -52,7 +58,10 @@ pub async fn generate_vpn_session_stats( // generate sessions for each user for (i, user) in users.into_iter().enumerate() { - info!("[{i}/{user_count}] Generating VPN sessions for user {user}"); + info!( + "[{}/{user_count}] Generating VPN sessions for user {user}", + i + 1 + ); // begin DB transaction let mut transaction = pool.begin().await?; @@ -62,6 +71,7 @@ pub async fn generate_vpn_session_stats( prepare_user_devices(&pool, &mut rng, &user, config.devices_per_user as usize).await?; for device in devices { + info!("Generating sessions for device {device}"); // generate requested number of sessions for a device // we always start with a session that's currently active // and generate past ones as needed @@ -90,6 +100,8 @@ pub async fn generate_vpn_session_stats( session.save(&mut *transaction).await?; } + debug!("Created session {session:?}"); + generate_mock_session_stats( &mut transaction, &mut rng, @@ -97,9 +109,12 @@ pub async fn generate_vpn_session_stats( gateway.id, session_start, session_end, + config.stats_batch_size, ) .await?; + debug!("Finished generating mock stats for session {session:?}"); + // update end timestamp for next session session_end -= Duration::minutes(rng.gen_range(30..120)); } @@ -141,6 +156,7 @@ async fn generate_mock_session_stats( gateway_id: Id, session_start: NaiveDateTime, session_end: NaiveDateTime, + batch_size: u16, ) -> Result<()> { let mut latest_handshake = session_start; let mut next_handshake = latest_handshake + HANDSHAKE_INTERVAL; @@ -151,6 +167,9 @@ async fn generate_mock_session_stats( // assume the IP remains static within a single session let endpoint = random_socket_addr(rng).to_string(); + // Vector to accumulate stats before batch insertion + let mut stats_batch: Vec = Vec::new(); + while collected_at <= session_end { // generate traffic let upload_diff = rng.gen_range(100..100_000); @@ -158,7 +177,7 @@ async fn generate_mock_session_stats( let download_diff = rng.gen_range(100..100_000); total_download += download_diff; - VpnSessionStats::new( + let stats = VpnSessionStats::new( session_id, gateway_id, collected_at, @@ -168,9 +187,15 @@ async fn generate_mock_session_stats( total_download, download_diff, download_diff, - ) - .save(&mut *transaction) - .await?; + ); + + stats_batch.push(stats); + + // If batch is full, insert all at once + if stats_batch.len() >= batch_size.into() { + insert_stats_batch(&mut *transaction, &stats_batch).await?; + stats_batch.clear(); + } // update variables for next sample collected_at += STATS_COLLECTION_INTERVAL; @@ -182,6 +207,42 @@ async fn generate_mock_session_stats( } } + // Insert any remaining stats in the batch + if !stats_batch.is_empty() { + insert_stats_batch(&mut *transaction, &stats_batch).await?; + } + + Ok(()) +} + +/// Insert multiple VpnSessionStats records in a single query +async fn insert_stats_batch( + transaction: &mut PgConnection, + stats_batch: &[VpnSessionStats], +) -> Result<()> { + if stats_batch.is_empty() { + return Ok(()); + } + + let mut query_builder = QueryBuilder::new( + "INSERT INTO vpn_session_stats (session_id, gateway_id, collected_at, latest_handshake, endpoint, total_upload, total_download, upload_diff, download_diff) ", + ); + + query_builder.push_values(stats_batch, |mut b, stats| { + b.push_bind(stats.session_id) + .push_bind(stats.gateway_id) + .push_bind(stats.collected_at) + .push_bind(stats.latest_handshake) + .push_bind(&stats.endpoint) + .push_bind(stats.total_upload) + .push_bind(stats.total_download) + .push_bind(stats.upload_diff) + .push_bind(stats.download_diff); + }); + + let query = query_builder.build(); + query.execute(&mut *transaction).await?; + Ok(()) }