Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/defguard_common/src/db/models/vpn_client_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::db::{
models::{WireguardNetwork, vpn_session_stats::VpnSessionStats, wireguard::LocationMfaMode},
};

#[derive(Default, Type)]
#[derive(Debug, Default, Type)]
#[sqlx(type_name = "vpn_client_session_state", rename_all = "lowercase")]
pub enum VpnClientSessionState {
#[default]
Expand All @@ -17,7 +17,7 @@ pub enum VpnClientSessionState {
}

/// Represents a single VPN client session from creation to eventual disconnection
#[derive(Model)]
#[derive(Debug, Model)]
#[table(vpn_client_session)]
pub struct VpnClientSession<I = NoId> {
pub id: I,
Expand Down
22 changes: 22 additions & 0 deletions tools/defguard_generator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

This crate contains a simple generator for creating users, devices, stats etc during development.

### Database connection

The generator uses the same environment variables (or CLI options) for DB connection setup as the core binary:

- DEFGUARD_DB_HOST
- DEFGUARD_DB_PORT
- DEFGUARD_DB_NAME
- DEFGUARD_DB_USER
- DEFGUARD_DB_PASSWORD

This means that if you have a development environment set up already it should just work.

### Usage

```bash
Expand All @@ -12,3 +24,13 @@ cargo run -p defguard_generator -- vpn-session-stats \
--sessions-per-device 5
```

### Session generation logic

For each device the generator always starts with creating an active (not disconnected) session.
If there are more sessions per device to be generated it goes backwards in time and creates
additional disconnected sessions.
Session duration and gaps between sessions are randomized but there is no logic to verify if
sessions are overlapping so by default the generator runs a `TRUNCATE` query at the start.
To disable this behavior (for example when running it multiple times for separate locations)
use the `--no-truncate` CLI flag.

20 changes: 17 additions & 3 deletions tools/defguard_generator/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use defguard_common::db::{Id, init_db};
use defguard_generator::vpn_session_stats::{
VpnSessionGeneratorConfig, generate_vpn_session_stats,
};
use tracing::Level;
use tracing_subscriber::EnvFilter;

#[derive(Parser)]
#[command(about, long_about = None)]
Expand All @@ -30,7 +30,7 @@ struct Cli {

#[derive(Subcommand)]
enum Commands {
/// generates VPN session stats
/// Generates fake VPN session statistics.
VpnSessionStats {
#[arg(long)]
location_id: Id,
Expand All @@ -40,13 +40,23 @@ enum Commands {
devices_per_user: u8,
#[arg(long)]
sessions_per_device: u8,
/// don't truncate sessions & stats tables before generating stats
#[arg(long)]
no_truncate: bool,
/// insert stats records in batches of specified size
#[arg(long, default_value_t = 1000)]
stats_batch_size: u16,
},
}

#[tokio::main]
async fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt().with_max_level(Level::INFO).init();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.init();

// parse CLI options
let cli = Cli::parse();
Expand All @@ -68,12 +78,16 @@ async fn main() -> Result<()> {
num_users,
devices_per_user,
sessions_per_device,
no_truncate,
stats_batch_size,
} => {
let config = VpnSessionGeneratorConfig {
location_id,
num_users,
devices_per_user,
sessions_per_device,
no_truncate,
stats_batch_size,
};

generate_vpn_session_stats(pool, config).await?;
Expand Down
81 changes: 71 additions & 10 deletions tools/defguard_generator/src/vpn_session_stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,36 @@ use defguard_common::db::{
},
};
use rand::{Rng, rngs::ThreadRng};
use sqlx::{PgConnection, PgPool};
use tracing::info;
use sqlx::{PgConnection, PgPool, QueryBuilder};
use tracing::{debug, info};

use crate::{user_devices::prepare_user_devices, users::prepare_users};

const STATS_COLLECTION_INTERVAL: Duration = Duration::seconds(30);
const HANDSHAKE_INTERVAL: Duration = Duration::minutes(2);

#[derive(Debug)]
pub struct VpnSessionGeneratorConfig {
pub location_id: Id,
pub num_users: u16,
pub devices_per_user: u8,
pub sessions_per_device: u8,
pub no_truncate: bool,
pub stats_batch_size: u16,
}

pub async fn generate_vpn_session_stats(
pool: PgPool,
config: VpnSessionGeneratorConfig,
) -> Result<()> {
info!("Running VPN stats generator with config: {config:#?}");
let mut rng = rand::thread_rng();

// clear sessions & stats tables
info!("Clearing existing sessions & stats");
truncate_with_restart(&pool).await?;
// clear sessions & stats tables unless disabled
if !config.no_truncate {
info!("Clearing existing sessions & stats");
truncate_with_restart(&pool).await?;
}

// fetch specified location
let location = WireguardNetwork::find_by_id(&pool, config.location_id)
Expand All @@ -52,7 +58,10 @@ pub async fn generate_vpn_session_stats(

// generate sessions for each user
for (i, user) in users.into_iter().enumerate() {
info!("[{i}/{user_count}] Generating VPN sessions for user {user}");
info!(
"[{}/{user_count}] Generating VPN sessions for user {user}",
i + 1
);

// begin DB transaction
let mut transaction = pool.begin().await?;
Expand All @@ -62,6 +71,7 @@ pub async fn generate_vpn_session_stats(
prepare_user_devices(&pool, &mut rng, &user, config.devices_per_user as usize).await?;

for device in devices {
info!("Generating sessions for device {device}");
// generate requested number of sessions for a device
// we always start with a session that's currently active
// and generate past ones as needed
Expand Down Expand Up @@ -90,16 +100,21 @@ pub async fn generate_vpn_session_stats(
session.save(&mut *transaction).await?;
}

debug!("Created session {session:?}");

generate_mock_session_stats(
&mut transaction,
&mut rng,
session.id,
gateway.id,
session_start,
session_end,
config.stats_batch_size,
)
.await?;

debug!("Finished generating mock stats for session {session:?}");

// update end timestamp for next session
session_end -= Duration::minutes(rng.gen_range(30..120));
}
Expand Down Expand Up @@ -141,6 +156,7 @@ async fn generate_mock_session_stats(
gateway_id: Id,
session_start: NaiveDateTime,
session_end: NaiveDateTime,
batch_size: u16,
) -> Result<()> {
let mut latest_handshake = session_start;
let mut next_handshake = latest_handshake + HANDSHAKE_INTERVAL;
Expand All @@ -151,14 +167,17 @@ async fn generate_mock_session_stats(
// assume the IP remains static within a single session
let endpoint = random_socket_addr(rng).to_string();

// Vector to accumulate stats before batch insertion
let mut stats_batch: Vec<VpnSessionStats> = Vec::new();

while collected_at <= session_end {
// generate traffic
let upload_diff = rng.gen_range(100..100_000);
total_upload += upload_diff;
let download_diff = rng.gen_range(100..100_000);
total_download += download_diff;

VpnSessionStats::new(
let stats = VpnSessionStats::new(
session_id,
gateway_id,
collected_at,
Expand All @@ -168,9 +187,15 @@ async fn generate_mock_session_stats(
total_download,
download_diff,
download_diff,
)
.save(&mut *transaction)
.await?;
);

stats_batch.push(stats);

// If batch is full, insert all at once
if stats_batch.len() >= batch_size.into() {
insert_stats_batch(&mut *transaction, &stats_batch).await?;
stats_batch.clear();
}

// update variables for next sample
collected_at += STATS_COLLECTION_INTERVAL;
Expand All @@ -182,6 +207,42 @@ async fn generate_mock_session_stats(
}
}

// Insert any remaining stats in the batch
if !stats_batch.is_empty() {
insert_stats_batch(&mut *transaction, &stats_batch).await?;
}

Ok(())
}

/// Insert multiple VpnSessionStats records in a single query
async fn insert_stats_batch(
transaction: &mut PgConnection,
stats_batch: &[VpnSessionStats],
) -> Result<()> {
if stats_batch.is_empty() {
return Ok(());
}

let mut query_builder = QueryBuilder::new(
"INSERT INTO vpn_session_stats (session_id, gateway_id, collected_at, latest_handshake, endpoint, total_upload, total_download, upload_diff, download_diff) ",
);

query_builder.push_values(stats_batch, |mut b, stats| {
b.push_bind(stats.session_id)
.push_bind(stats.gateway_id)
.push_bind(stats.collected_at)
.push_bind(stats.latest_handshake)
.push_bind(&stats.endpoint)
.push_bind(stats.total_upload)
.push_bind(stats.total_download)
.push_bind(stats.upload_diff)
.push_bind(stats.download_diff);
});

let query = query_builder.build();
query.execute(&mut *transaction).await?;

Ok(())
}

Expand Down
Loading