diff --git a/.sqlx/query-0c18e9d0f192e36ed65569c0ef124d6ab73bee88929ad223c46bb9b2892150f3.json b/.sqlx/query-0c18e9d0f192e36ed65569c0ef124d6ab73bee88929ad223c46bb9b2892150f3.json new file mode 100644 index 0000000000..f0693f9a64 --- /dev/null +++ b/.sqlx/query-0c18e9d0f192e36ed65569c0ef124d6ab73bee88929ad223c46bb9b2892150f3.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, user_id, location_id, \"public_ip\" \"public_ip: IpAddr\" FROM user_snat_binding WHERE location_id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "location_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "public_ip: IpAddr", + "type_info": "Inet" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "0c18e9d0f192e36ed65569c0ef124d6ab73bee88929ad223c46bb9b2892150f3" +} diff --git a/.sqlx/query-0effde2f87ca6a7d9ce34daedb6462deb43863778ea17a47696912f412714741.json b/.sqlx/query-0effde2f87ca6a7d9ce34daedb6462deb43863778ea17a47696912f412714741.json new file mode 100644 index 0000000000..f9349763ad --- /dev/null +++ b/.sqlx/query-0effde2f87ca6a7d9ce34daedb6462deb43863778ea17a47696912f412714741.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, \"user_id\",\"location_id\",\"public_ip\" \"public_ip: IpAddr\" FROM \"user_snat_binding\" WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "location_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "public_ip: IpAddr", + "type_info": "Inet" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "0effde2f87ca6a7d9ce34daedb6462deb43863778ea17a47696912f412714741" +} diff --git a/.sqlx/query-6175d2d008ccf7860ebe9f1b2e12d7dd30dac5aa72e015c19931c84010ebfcf5.json b/.sqlx/query-6175d2d008ccf7860ebe9f1b2e12d7dd30dac5aa72e015c19931c84010ebfcf5.json new file mode 100644 index 0000000000..c95aa77bf7 --- /dev/null +++ b/.sqlx/query-6175d2d008ccf7860ebe9f1b2e12d7dd30dac5aa72e015c19931c84010ebfcf5.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE \"user_snat_binding\" SET \"user_id\" = $2,\"location_id\" = $3,\"public_ip\" = $4 WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Int8", + "Inet" + ] + }, + "nullable": [] + }, + "hash": "6175d2d008ccf7860ebe9f1b2e12d7dd30dac5aa72e015c19931c84010ebfcf5" +} diff --git a/.sqlx/query-876e1659850a050155f3938231e801b381b112d54971c86beaad5fd679fbd5ac.json b/.sqlx/query-876e1659850a050155f3938231e801b381b112d54971c86beaad5fd679fbd5ac.json new file mode 100644 index 0000000000..51f0e042d1 --- /dev/null +++ b/.sqlx/query-876e1659850a050155f3938231e801b381b112d54971c86beaad5fd679fbd5ac.json @@ -0,0 +1,41 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, user_id, location_id, \"public_ip\" \"public_ip: IpAddr\" FROM user_snat_binding WHERE location_id = $1 AND user_id = $2", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "location_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "public_ip: IpAddr", + "type_info": "Inet" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "876e1659850a050155f3938231e801b381b112d54971c86beaad5fd679fbd5ac" +} diff --git a/.sqlx/query-87868c21e47dd3b55ba1aefb690ce4ea9b463d7e71533e9c6312939a3d77b49e.json b/.sqlx/query-87868c21e47dd3b55ba1aefb690ce4ea9b463d7e71533e9c6312939a3d77b49e.json new file mode 100644 index 0000000000..0db9a2aa7c --- /dev/null +++ b/.sqlx/query-87868c21e47dd3b55ba1aefb690ce4ea9b463d7e71533e9c6312939a3d77b49e.json @@ -0,0 +1,38 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id, \"user_id\",\"location_id\",\"public_ip\" \"public_ip: IpAddr\" FROM \"user_snat_binding\"", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "location_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "public_ip: IpAddr", + "type_info": "Inet" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "87868c21e47dd3b55ba1aefb690ce4ea9b463d7e71533e9c6312939a3d77b49e" +} diff --git a/.sqlx/query-8d8416a6cc1f0bae02e126ca398e87000f305499668455d7e4d949f7d0a7be9a.json b/.sqlx/query-8d8416a6cc1f0bae02e126ca398e87000f305499668455d7e4d949f7d0a7be9a.json new file mode 100644 index 0000000000..01f2bb1f99 --- /dev/null +++ b/.sqlx/query-8d8416a6cc1f0bae02e126ca398e87000f305499668455d7e4d949f7d0a7be9a.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM \"user_snat_binding\" WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [] + }, + "hash": "8d8416a6cc1f0bae02e126ca398e87000f305499668455d7e4d949f7d0a7be9a" +} diff --git a/.sqlx/query-cda76abbdfed425c3b06c5b64115529ffe33c7290a63adfbaaa416efeffefac3.json b/.sqlx/query-cda76abbdfed425c3b06c5b64115529ffe33c7290a63adfbaaa416efeffefac3.json new file mode 100644 index 0000000000..b606660ad7 --- /dev/null +++ b/.sqlx/query-cda76abbdfed425c3b06c5b64115529ffe33c7290a63adfbaaa416efeffefac3.json @@ -0,0 +1,24 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO \"user_snat_binding\" (\"user_id\",\"location_id\",\"public_ip\") VALUES ($1,$2,$3) RETURNING id", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [ + "Int8", + "Int8", + "Inet" + ] + }, + "nullable": [ + false + ] + }, + "hash": "cda76abbdfed425c3b06c5b64115529ffe33c7290a63adfbaaa416efeffefac3" +} diff --git a/Cargo.lock b/Cargo.lock index b388c5bd1e..9dffd6c33f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3836,9 +3836,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.21" +version = "0.12.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8cea6b35bcceb099f30173754403d2eba0a5dc18cea3630fccd88251909288" +checksum = "cbc931937e6ca3a06e3b6c0aa7841849b160a90351d6ab467a8b9b9959767531" dependencies = [ "base64 0.22.1", "bytes", diff --git a/crates/defguard_core/migrations/20250616071627_add_user_snat.down.sql b/crates/defguard_core/migrations/20250616071627_add_user_snat.down.sql new file mode 100644 index 0000000000..c98ffc6336 --- /dev/null +++ b/crates/defguard_core/migrations/20250616071627_add_user_snat.down.sql @@ -0,0 +1 @@ +DROP TABLE user_snat_binding; diff --git a/crates/defguard_core/migrations/20250616071627_add_user_snat.up.sql b/crates/defguard_core/migrations/20250616071627_add_user_snat.up.sql new file mode 100644 index 0000000000..8e0974e7d7 --- /dev/null +++ b/crates/defguard_core/migrations/20250616071627_add_user_snat.up.sql @@ -0,0 +1,9 @@ +CREATE TABLE user_snat_binding ( + id bigserial PRIMARY KEY, + user_id bigint NOT NULL, + location_id bigint NOT NULL, + public_ip inet NOT NULL, + FOREIGN KEY(user_id) REFERENCES "user"(id) ON DELETE CASCADE, + FOREIGN KEY(location_id) REFERENCES "wireguard_network"(id) ON DELETE CASCADE, + CONSTRAINT user_location UNIQUE (user_id, location_id) +); diff --git a/crates/defguard_core/src/enterprise/db/models/mod.rs b/crates/defguard_core/src/enterprise/db/models/mod.rs index 36ae9de03e..18af2791d2 100644 --- a/crates/defguard_core/src/enterprise/db/models/mod.rs +++ b/crates/defguard_core/src/enterprise/db/models/mod.rs @@ -3,3 +3,4 @@ pub mod activity_log_stream; pub mod api_tokens; pub mod enterprise_settings; pub mod openid_provider; +pub mod snat; diff --git a/crates/defguard_core/src/enterprise/db/models/snat.rs b/crates/defguard_core/src/enterprise/db/models/snat.rs new file mode 100644 index 0000000000..08fcd0db92 --- /dev/null +++ b/crates/defguard_core/src/enterprise/db/models/snat.rs @@ -0,0 +1,69 @@ +use std::net::IpAddr; + +use crate::{ + db::{Id, NoId}, + enterprise::snat::error::UserSnatBindingError, +}; +use model_derive::Model; +use serde::Serialize; +use sqlx::{query_as, PgExecutor}; +use utoipa::ToSchema; + +#[derive(Debug, Model, Serialize, ToSchema)] +#[table(user_snat_binding)] +pub struct UserSnatBinding { + pub id: I, + pub user_id: Id, + pub location_id: Id, + #[model(ip)] + #[schema(value_type = String)] + pub public_ip: IpAddr, +} + +impl UserSnatBinding { + pub fn new(user_id: Id, location_id: Id, public_ip: IpAddr) -> Self { + Self { + id: NoId, + user_id, + location_id, + public_ip, + } + } +} + +impl UserSnatBinding { + pub async fn find_binding<'e, E>( + executor: E, + location_id: Id, + user_id: Id, + ) -> Result + where + E: PgExecutor<'e>, + { + let binding = query_as!(Self, + "SELECT id, user_id, location_id, \"public_ip\" \"public_ip: IpAddr\" FROM user_snat_binding WHERE location_id = $1 AND user_id = $2", + location_id, user_id + ).fetch_one(executor).await?; + + Ok(binding) + } + + pub async fn all_for_location<'e, E>( + executor: E, + location_id: Id, + ) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { + let bindings = query_as!(Self, + "SELECT id, user_id, location_id, \"public_ip\" \"public_ip: IpAddr\" FROM user_snat_binding WHERE location_id = $1", + location_id + ).fetch_all(executor).await?; + + Ok(bindings) + } + + pub fn update_ip(&mut self, new_public_ip: IpAddr) { + self.public_ip = new_public_ip; + } +} diff --git a/crates/defguard_core/src/enterprise/firewall.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs similarity index 90% rename from crates/defguard_core/src/enterprise/firewall.rs rename to crates/defguard_core/src/enterprise/firewall/mod.rs index 90f2f6bbdc..513aa064fe 100644 --- a/crates/defguard_core/src/enterprise/firewall.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -15,10 +15,14 @@ use super::{ }; use crate::{ db::{models::error::ModelError, Device, Id, User, WireguardNetwork}, - enterprise::{db::models::acl::AliasKind, is_enterprise_enabled}, + enterprise::{ + db::models::{acl::AliasKind, snat::UserSnatBinding}, + is_enterprise_enabled, + }, grpc::proto::enterprise::firewall::{ ip_address::Address, port::Port as PortInner, FirewallConfig, FirewallPolicy, FirewallRule, IpAddress, IpRange, IpVersion, Port, PortRange as PortRangeProto, + SnatBinding as SnatBindingProto, }, }; @@ -64,10 +68,11 @@ pub async fn generate_firewall_rules_from_acls( // get relevant users for determining source IPs let users = get_source_users(allowed_users, &denied_users); + // prepare a list of user IDs + let user_ids: Vec = users.iter().map(|user| user.id).collect(); // get network IPs for devices belonging to those users - let user_device_ips = get_user_device_ips(&users, location_id, &mut *conn).await?; - + let user_device_ips = get_user_device_ips(&user_ids, location_id, &mut *conn).await?; // separate IPv4 and IPv6 user-device addresses let user_device_ips = user_device_ips .iter() @@ -323,13 +328,10 @@ fn get_source_users(allowed_users: Vec>, denied_users: &[User]) -> /// Fetches all IPs of devices belonging to specified users within a given location's VPN subnet. /// We specifically only fetch user devices since network devices are handled separately. async fn get_user_device_ips<'e, E: sqlx::PgExecutor<'e>>( - users: &[User], + user_ids: &[Id], location_id: Id, executor: E, ) -> Result>, SqlxError> { - // prepare a list of user IDs - let user_ids: Vec = users.iter().map(|user| user.id).collect(); - // fetch network IPs query_scalar!( "SELECT wireguard_ips \"wireguard_ips: Vec\" \ @@ -779,6 +781,86 @@ fn merge_port_ranges(port_ranges: Vec) -> Vec { .collect() } +/// Converts user SNAT bindings into SNAT config to be sent to a gateway as part of `FirewallConfig`. +/// +/// To generate the final SNAT binding we need to find all user devices +/// and get their IPs to generate a list of source addresses for a firewall rule. +async fn generate_user_snat_bindings_for_location( + location_id: Id, + conn: &mut PgConnection, +) -> Result, SqlxError> { + debug!("Generating SNAT bindings for location {location_id}"); + + let user_snat_bindings = UserSnatBinding::all_for_location(&mut *conn, location_id).await?; + + // check if there are any bindings configured for this location + if user_snat_bindings.is_empty() { + debug!("No user SNAT bindings configured for location {location_id}"); + return Ok(Vec::new()); + } + + // initialize output list + let mut bindings = Vec::new(); + + // process each user SNAT binding + for user_binding in user_snat_bindings { + let user_id = user_binding.user_id; + + debug!( + "Processing SNAT binding for user {user_id} with public IP {}", + user_binding.public_ip + ); + + // determine IP protocol version based on public IP + let is_ipv4 = user_binding.public_ip.is_ipv4(); + + // fetch all device IPs for this specific user in the location + let user_device_ips = get_user_device_ips(&[user_id], location_id, &mut *conn).await?; + + // separate IPv4 and IPv6 user-device addresses + let (user_device_ips_v4, user_device_ips_v6) = user_device_ips + .iter() + .flatten() + .partition(|ip| ip.is_ipv4()); + + // convert device IPs into source addresses for a firewall rule + let source_addrs = if is_ipv4 { + get_source_addrs(user_device_ips_v4, Vec::new(), IpVersion::Ipv4) + } else { + get_source_addrs(user_device_ips_v6, Vec::new(), IpVersion::Ipv6) + }; + + if source_addrs.is_empty() { + debug!( + "No compatible device IPs found for user {user_id} in location {location_id} with public IP {}, skipping SNAT binding", user_binding.public_ip + ); + continue; + } + + // create the SNAT binding proto + let snat_binding = SnatBindingProto { + id: user_binding.id, + source_addrs, + public_ip: user_binding.public_ip.to_string(), + comment: Some(format!("User {user_id} SNAT binding {}", user_binding.id)), + }; + + debug!( + "Created SNAT binding for user {user_id} in location {location_id}: {snat_binding:?}", + ); + + // add to output list + bindings.push(snat_binding); + } + + debug!( + "Generated {} SNAT bindings for location {location_id}", + bindings.len(), + ); + + Ok(bindings) +} + impl WireguardNetwork { /// Fetches all active ACL rules for a given location. /// Filters out rules which are disabled, expired or have not been deployed yet. @@ -846,9 +928,11 @@ impl WireguardNetwork { }; let firewall_rules = generate_firewall_rules_from_acls(self.id, location_acls, &mut *conn).await?; + let snat_bindings = generate_user_snat_bindings_for_location(self.id, &mut *conn).await?; let firewall_config = FirewallConfig { default_policy: default_policy.into(), rules: firewall_rules, + snat_bindings, }; debug!("Firewall config generated for location {self}: {firewall_config:?}"); diff --git a/crates/defguard_core/src/enterprise/mod.rs b/crates/defguard_core/src/enterprise/mod.rs index 2b7c276ff1..679296908e 100644 --- a/crates/defguard_core/src/enterprise/mod.rs +++ b/crates/defguard_core/src/enterprise/mod.rs @@ -7,6 +7,7 @@ pub mod handlers; pub mod ldap; pub mod license; pub mod limits; +pub mod snat; mod utils; use license::{get_cached_license, validate_license}; diff --git a/crates/defguard_core/src/enterprise/snat/error.rs b/crates/defguard_core/src/enterprise/snat/error.rs new file mode 100644 index 0000000000..83ee044ab5 --- /dev/null +++ b/crates/defguard_core/src/enterprise/snat/error.rs @@ -0,0 +1,36 @@ +use reqwest::StatusCode; +use thiserror::Error; + +use crate::error::WebError; + +#[derive(Debug, Error)] +pub enum UserSnatBindingError { + #[error("Binding not found")] + BindingNotFound, + #[error("Binding already exists")] + BindingAlreadyExists, + #[error("Database error")] + DbError { source: sqlx::Error }, +} + +impl From for UserSnatBindingError { + fn from(value: sqlx::Error) -> Self { + match value { + sqlx::Error::RowNotFound => Self::BindingNotFound, + sqlx::Error::Database(err) if err.constraint() == Some("user_location") => { + Self::BindingAlreadyExists + } + _ => Self::DbError { source: value }, + } + } +} + +impl From for WebError { + fn from(value: UserSnatBindingError) -> Self { + match value { + UserSnatBindingError::BindingNotFound => WebError::ObjectNotFound(value.to_string()), + UserSnatBindingError::BindingAlreadyExists => WebError::Http(StatusCode::CONFLICT), + UserSnatBindingError::DbError { source } => WebError::DbError(source.to_string()), + } + } +} diff --git a/crates/defguard_core/src/enterprise/snat/handlers.rs b/crates/defguard_core/src/enterprise/snat/handlers.rs new file mode 100644 index 0000000000..64f6478b23 --- /dev/null +++ b/crates/defguard_core/src/enterprise/snat/handlers.rs @@ -0,0 +1,248 @@ +use axum::{ + extract::{Path, State}, + Json, +}; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use sqlx::query_as; +use std::net::IpAddr; +use utoipa::ToSchema; + +use crate::{ + appstate::AppState, + auth::{AdminRole, SessionInfo}, + db::{GatewayEvent, Id, WireguardNetwork}, + enterprise::{db::models::snat::UserSnatBinding, handlers::LicenseInfo}, + handlers::{ApiResponse, ApiResult}, +}; + +/// List all SNAT bindings for a WireGuard location +#[utoipa::path( + get, + path = "/api/v1/network/{location_id}/snat", + tag = "SNAT", + params( + ("location_id" = Id, Path, description = "WireGuard location ID") + ), + responses( + (status = 200, description = "List of SNAT bindings", body = Vec), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Forbidden - Admin role required"), + (status = 500, description = "Internal server error") + ), + security( + ("cookie" = []), + ("api_token" = []) + ) +)] +pub async fn list_snat_bindings( + _license: LicenseInfo, + _admin_role: AdminRole, + session: SessionInfo, + Path(location_id): Path, + State(appstate): State, +) -> ApiResult { + let current_user = session.user.username; + + debug!("User {current_user} listing SNAT bindings for WireGuard location {location_id}"); + + let bindings = query_as!( + UserSnatBinding::, + "SELECT id, user_id, location_id, \"public_ip\" \"public_ip: IpAddr\" FROM user_snat_binding WHERE location_id = $1", + location_id + ) + .fetch_all(&appstate.pool) + .await?; + + Ok(ApiResponse { + json: json!(bindings), + status: StatusCode::OK, + }) +} + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +pub struct NewUserSnatBinding { + /// User ID to bind to the public IP + user_id: Id, + /// Public IP address for SNAT + #[schema(value_type = String)] + public_ip: IpAddr, +} + +/// Create a new SNAT binding for a user in a WireGuard location +#[utoipa::path( + post, + path = "/api/v1/network/{location_id}/snat", + tag = "SNAT", + params( + ("location_id" = Id, Path, description = "WireGuard location ID") + ), + request_body = NewUserSnatBinding, + responses( + (status = 201, description = "SNAT binding created successfully", body = UserSnatBinding), + (status = 400, description = "Bad request - Invalid input data"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Forbidden - Admin role required"), + (status = 409, description = "Conflict - Binding already exists"), + (status = 500, description = "Internal server error") + ), + security( + ("cookie" = []), + ("api_token" = []) + ) +)] +pub async fn create_snat_binding( + _license: LicenseInfo, + _admin_role: AdminRole, + session: SessionInfo, + Path(location_id): Path, + State(appstate): State, + Json(data): Json, +) -> ApiResult { + let current_user = session.user.username; + + debug!("User {current_user} creating new SNAT binding for WireGuard location {location_id} with {data:?}"); + + let snat_binding = UserSnatBinding::new(data.user_id, location_id, data.public_ip); + + let binding = snat_binding.save(&appstate.pool).await?; + + // trigger firewall config update on relevant gateways + let mut conn = appstate.pool.acquire().await?; + if let Some(location) = WireguardNetwork::find_by_id(&appstate.pool, location_id).await? { + if let Some(firewall_config) = location.try_get_firewall_config(&mut conn).await? { + debug!("Sending firewall config update for location {location} affected by adding new SNAT binding"); + appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + location_id, + firewall_config, + )); + } + } + + Ok(ApiResponse { + json: json!(binding), + status: StatusCode::CREATED, + }) +} + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +pub struct EditUserSnatBinding { + /// New public IP address for SNAT + #[schema(value_type = String)] + public_ip: IpAddr, +} + +/// Modify an existing SNAT binding for a user in a WireGuard location +#[utoipa::path( + put, + path = "/api/v1/network/{location_id}/snat/{user_id}", + tag = "SNAT", + params( + ("location_id" = Id, Path, description = "WireGuard location ID"), + ("user_id" = Id, Path, description = "User ID") + ), + request_body = EditUserSnatBinding, + responses( + (status = 200, description = "SNAT binding updated successfully", body = UserSnatBinding), + (status = 400, description = "Bad request - Invalid input data"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Forbidden - Admin role required"), + (status = 404, description = "Not found - SNAT binding does not exist"), + (status = 500, description = "Internal server error") + ), + security( + ("cookie" = []), + ("api_token" = []) + ) +)] +pub async fn modify_snat_binding( + _license: LicenseInfo, + _admin_role: AdminRole, + session: SessionInfo, + Path((location_id, user_id)): Path<(Id, Id)>, + State(appstate): State, + Json(data): Json, +) -> ApiResult { + let current_user = session.user.username; + + debug!("User {current_user} updating SNAT binding for user {user_id} and WireGuard location {location_id} with {data:?}"); + + // fetch existing binding + let mut snat_binding = + UserSnatBinding::find_binding(&appstate.pool, location_id, user_id).await?; + + // update public IP + snat_binding.update_ip(data.public_ip); + snat_binding.save(&appstate.pool).await?; + + // trigger firewall config update on relevant gateways + let mut conn = appstate.pool.acquire().await?; + if let Some(location) = WireguardNetwork::find_by_id(&appstate.pool, location_id).await? { + if let Some(firewall_config) = location.try_get_firewall_config(&mut conn).await? { + debug!("Sending firewall config update for location {location} affected by adding new SNAT binding"); + appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + location_id, + firewall_config, + )); + } + } + + Ok(ApiResponse { + json: json!(snat_binding), + status: StatusCode::OK, + }) +} + +/// Delete an existing SNAT binding for a user in a WireGuard location +#[utoipa::path( + delete, + path = "/api/v1/network/{location_id}/snat/{user_id}", + tag = "SNAT", + params( + ("location_id" = Id, Path, description = "WireGuard location ID"), + ("user_id" = Id, Path, description = "User ID") + ), + responses( + (status = 200, description = "SNAT binding deleted successfully"), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Forbidden - Admin role required"), + (status = 404, description = "Not found - SNAT binding does not exist"), + (status = 500, description = "Internal server error") + ), + security( + ("cookie" = []), + ("api_token" = []) + ) +)] +pub async fn delete_snat_binding( + _license: LicenseInfo, + _admin_role: AdminRole, + session: SessionInfo, + Path((location_id, user_id)): Path<(Id, Id)>, + State(appstate): State, +) -> ApiResult { + let current_user = session.user.username; + + debug!("User {current_user} deleting SNAT binding for user {user_id} and WireGuard location {location_id}"); + + // fetch existing binding + let snat_binding = UserSnatBinding::find_binding(&appstate.pool, location_id, user_id).await?; + + // delete binding + snat_binding.delete(&appstate.pool).await?; + + // trigger firewall config update on relevant gateways + let mut conn = appstate.pool.acquire().await?; + if let Some(location) = WireguardNetwork::find_by_id(&appstate.pool, location_id).await? { + if let Some(firewall_config) = location.try_get_firewall_config(&mut conn).await? { + debug!("Sending firewall config update for location {location} affected by adding new SNAT binding"); + appstate.send_wireguard_event(GatewayEvent::FirewallConfigChanged( + location_id, + firewall_config, + )); + } + } + + Ok(ApiResponse::default()) +} diff --git a/crates/defguard_core/src/enterprise/snat/mod.rs b/crates/defguard_core/src/enterprise/snat/mod.rs new file mode 100644 index 0000000000..00d363536e --- /dev/null +++ b/crates/defguard_core/src/enterprise/snat/mod.rs @@ -0,0 +1,2 @@ +pub mod error; +pub mod handlers; diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index bd5f5b72ee..99598c498a 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -14,23 +14,28 @@ use axum::{ }; use db::models::device::DeviceType; use defguard_web_ui::{index, svg, web_asset}; -use enterprise::handlers::{ - acl::{ - apply_acl_aliases, apply_acl_rules, create_acl_alias, create_acl_rule, delete_acl_alias, - delete_acl_rule, get_acl_alias, get_acl_rule, list_acl_aliases, list_acl_rules, - update_acl_alias, update_acl_rule, - }, - activity_log_stream::{ - create_activity_log_stream, delete_activity_log_stream, get_activity_log_stream, - modify_activity_log_stream, +use enterprise::{ + handlers::{ + acl::{ + apply_acl_aliases, apply_acl_rules, create_acl_alias, create_acl_rule, + delete_acl_alias, delete_acl_rule, get_acl_alias, get_acl_rule, list_acl_aliases, + list_acl_rules, update_acl_alias, update_acl_rule, + }, + activity_log_stream::{ + create_activity_log_stream, delete_activity_log_stream, get_activity_log_stream, + modify_activity_log_stream, + }, + api_tokens::{add_api_token, delete_api_token, fetch_api_tokens, rename_api_token}, + check_enterprise_info, + enterprise_settings::{get_enterprise_settings, patch_enterprise_settings}, + openid_login::{auth_callback, get_auth_info}, + openid_providers::{ + add_openid_provider, delete_openid_provider, get_current_openid_provider, + test_dirsync_connection, + }, }, - api_tokens::{add_api_token, delete_api_token, fetch_api_tokens, rename_api_token}, - check_enterprise_info, - enterprise_settings::{get_enterprise_settings, patch_enterprise_settings}, - openid_login::{auth_callback, get_auth_info}, - openid_providers::{ - add_openid_provider, delete_openid_provider, get_current_openid_provider, - test_dirsync_connection, + snat::handlers::{ + create_snat_binding, delete_snat_binding, list_snat_bindings, modify_snat_binding, }, }; use events::ApiEvent; @@ -180,6 +185,7 @@ pub(crate) fn server_config() -> &'static DefGuardConfig { pub(crate) const KEY_LENGTH: usize = 32; mod openapi { + use crate::enterprise::snat::handlers as snat; use db::{ models::device::{ModifyDevice, UserDevice}, AddDevice, UserDetails, UserInfo, @@ -240,6 +246,12 @@ mod openapi { network::delete_network, network::list_networks, network::network_details, + // /network/{location_id}/snat + snat::list_snat_bindings, + snat::create_snat_binding, + snat::modify_snat_binding, + snat::delete_snat_binding, + ), components( schemas( @@ -271,13 +283,16 @@ Available actions: - list all devices or user devices - CRUD mechanism for handling devices. "), - (name = "nework", description = " + (name = "network", description = " Endpoints that allow to control your networks. Available actions: - list all wireguard networks - CRUD mechanism for handling devices. "), + (name = "SNAT", description = " +Endpoints that allow you to control user SNAT bindings for your locations. + "), ) )] pub struct ApiDoc; @@ -577,6 +592,16 @@ pub fn build_webapp( .route("/network/{network_id}/token", get(create_network_token)) .route("/network/{network_id}/stats/users", get(devices_stats)) .route("/network/{network_id}/stats", get(network_stats)) + .route("/network/{location_id}/snat", get(list_snat_bindings)) + .route("/network/{location_id}/snat", post(create_snat_binding)) + .route( + "/network/{location_id}/snat/{user_id}", + put(modify_snat_binding), + ) + .route( + "/network/{location_id}/snat/{user_id}", + delete(delete_snat_binding), + ) .layer(Extension(gateway_state)), ); diff --git a/crates/model_derive/src/lib.rs b/crates/model_derive/src/lib.rs index 0602281931..299ba260a4 100644 --- a/crates/model_derive/src/lib.rs +++ b/crates/model_derive/src/lib.rs @@ -123,6 +123,8 @@ pub fn derive(input: TokenStream) -> TokenStream { if field_type == "secret" { // FIXME: don't hard-code struct name cs_aliased_fields.push_str("?: SecretString\""); + } else if field_type == "ip" { + cs_aliased_fields.push_str(": IpAddr\""); } else { cs_aliased_fields.push_str(": _\""); } @@ -153,6 +155,9 @@ pub fn derive(input: TokenStream) -> TokenStream { } else if tokens == "secret" { // FIXME: hard-coded struct name return Some(quote! { &self.#name as &Option }); + } else if tokens == "ip" { + // FIXME: hard-coded struct name + return Some(quote! { &self.#name as &IpAddr }); } else { return Some(quote! { &self.#name }); } diff --git a/flake.lock b/flake.lock index 69fe2db4c8..b227bc8e14 100644 --- a/flake.lock +++ b/flake.lock @@ -48,11 +48,11 @@ ] }, "locked": { - "lastModified": 1751338093, - "narHash": "sha256-/yd9nPcTfUZPFtwjRbdB5yGLdt3LTPqz6Ja63Joiahs=", + "lastModified": 1751423951, + "narHash": "sha256-AowKhJGplXRkAngSvb+32598DTiI6LOzhAnzgvbCtYM=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "6cfb7821732dac2d3e2dea857a5613d3b856c20c", + "rev": "1684ed5b15859b655caf41b467d046e29a994d04", "type": "github" }, "original": { diff --git a/proto b/proto index eb4ac0620f..c0aef68395 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit eb4ac0620f54bfa58669f2ac61ea5fce5c55b521 +Subproject commit c0aef68395720f46a7f038b6766de3bb30e02930