diff --git a/.sqlx/query-e2bf2ef722c117869f16695d5899aedfc6e35d5afa8c38e462186596348267d8.json b/.sqlx/query-e2bf2ef722c117869f16695d5899aedfc6e35d5afa8c38e462186596348267d8.json new file mode 100644 index 0000000000..8a9093e246 --- /dev/null +++ b/.sqlx/query-e2bf2ef722c117869f16695d5899aedfc6e35d5afa8c38e462186596348267d8.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT wn.id AS location_id, wn.name AS location_name, d.id AS device_id, d.name AS device_name, wnd.wireguard_ips AS \"wireguard_ips: Vec\" FROM wireguard_network wn JOIN wireguard_network_device wnd ON wnd.wireguard_network_id = wn.id JOIN device d ON d.id = wnd.device_id JOIN \"user\" u ON d.user_id = u.id WHERE u.username = $1 ORDER BY wn.name, d.name", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "location_id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "location_name", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "device_id", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "device_name", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "wireguard_ips: Vec", + "type_info": "InetArray" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "e2bf2ef722c117869f16695d5899aedfc6e35d5afa8c38e462186596348267d8" +} diff --git a/Cargo.lock b/Cargo.lock index d4ffc9266d..d022b44641 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1380,6 +1380,7 @@ dependencies = [ "defguard_common", "defguard_mail", "defguard_proto", + "defguard_static_ip", "defguard_version", "defguard_web_ui", "futures", @@ -1618,6 +1619,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "defguard_static_ip" +version = "0.0.0" +dependencies = [ + "anyhow", + "axum", + "defguard_common", + "serde", + "serde_json", + "sqlx", + "thiserror 2.0.18", + "tokio", + "tracing", +] + [[package]] name = "defguard_version" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 7910a0505e..2315ef7223 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,10 @@ resolver = "2" [workspace.dependencies] # internal crates + +defguard_setup = { path = "./crates/defguard_setup", version = "0.0.0" } defguard_common = { path = "./crates/defguard_common", version = "2.0.0" } +defguard_static_ip = { path = "./crates/defguard_static_ip", version = "0.0.0" } defguard_core = { path = "./crates/defguard_core", version = "0.0.0" } defguard_event_logger = { path = "./crates/defguard_event_logger", version = "0.0.0" } defguard_event_router = { path = "./crates/defguard_event_router", version = "0.0.0" } @@ -26,7 +29,6 @@ defguard_vpn_stats_purge = { path = "./crates/defguard_vpn_stats_purge", version defguard_web_ui = { path = "./crates/defguard_web_ui", version = "0.0.0" } defguard_certs = { path = "./crates/defguard_certs", version = "0.0.0" } defguard_grpc_tls = { path = "./crates/defguard_grpc_tls", version = "0.0.0" } -defguard_setup = { path = "./crates/defguard_setup", version = "0.0.0" } model_derive = { path = "./crates/model_derive", version = "0.0.0" } # external dependencies @@ -58,7 +60,11 @@ ipnetwork = "0.20" jsonwebkey = { version = "0.4", features = ["pkcs-convert"] } jsonwebtoken = { version = "10.3", features = ["rust_crypto"] } ldap3 = { version = "0.12", default-features = false, features = ["tls"] } -lettre = { version = "0.11", default-features = false, features = ["builder", "smtp-transport", "tokio1-rustls-tls"] } +lettre = { version = "0.11", default-features = false, features = [ + "builder", + "smtp-transport", + "tokio1-rustls-tls", +] } matches = "0.1" md4 = "0.10" openidconnect = { version = "4.0", default-features = false, features = [ diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 0d9208657a..9e517d0384 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -201,7 +201,7 @@ pub enum NetworkAddressError { "Location {0} has no network that could contain IP address {1}, available networks: {2:?}" )] NoContainingNetwork(String, IpAddr, Vec), - #[error("IP address {1} is reserved for gateway in location {0}")] + #[error("IP address {1} is reserved for Gateway in location {0}")] ReservedForGateway(String, IpAddr), #[error("IP address {1} is network broadcast address in location {0}")] IsBroadcastAddress(String, IpAddr), diff --git a/crates/defguard_common/src/utils.rs b/crates/defguard_common/src/utils.rs index a80460de0b..ae5965921a 100644 --- a/crates/defguard_common/src/utils.rs +++ b/crates/defguard_common/src/utils.rs @@ -1,4 +1,7 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + use ipnetwork::IpNetwork; +use serde::Serialize; /// Parse a string with comma-separated IP addresses. /// Invalid addresses will be silently ignored. @@ -23,3 +26,133 @@ pub fn parse_network_address_list(ips: &str) -> Vec { }) .collect() } + +#[derive(Debug, Serialize, PartialEq)] +pub struct SplitIp { + network_part: String, + modifiable_part: String, + network_prefix: String, + ip: String, +} + +/// Splits the IP address (IPv4 or IPv6) into three parts: network part, modifiable part and prefix +/// The network part is the part that can't be changed by the user. +/// This is to display an IP address in the UI like this: 192.168.(1.1)/16, where the part in the parenthesis can be changed by the user. +/// The algorithm works as follows: +/// 1. Get the network address, last address and IP address segments, e.g. 192.1.1.1 would be [192, 1, 1, 1] +/// 2. Iterate over the segments and compare the last address and network segments, as long as the current segments are equal, append the segment to the network part. +/// If they are not equal, we found the first modifiable segment (one of the segments of an address that may change between hosts in the same network), +/// append the rest of the segments to the modifiable part. +/// 3. Join the segments with the delimiter and return the network part, modifiable part and the network prefix +pub fn split_ip(ip: &IpAddr, network: &IpNetwork) -> SplitIp { + let network_addr = network.network(); + let network_prefix = network.prefix(); + + let ip_segments = match ip { + IpAddr::V4(ip) => ip.octets().iter().map(|x| u16::from(*x)).collect(), + IpAddr::V6(ip) => ip.segments().to_vec(), + }; + + let last_addr_segments = match network { + IpNetwork::V4(net) => { + let last_ip = u32::from(net.ip()) | (!u32::from(net.mask())); + let last_ip: Ipv4Addr = last_ip.into(); + last_ip.octets().iter().map(|x| u16::from(*x)).collect() + } + IpNetwork::V6(net) => { + let last_ip = u128::from(net.ip()) | (!u128::from(net.mask())); + let last_ip: Ipv6Addr = last_ip.into(); + last_ip.segments().to_vec() + } + }; + + let network_segments = match network_addr { + IpAddr::V4(ip) => ip.octets().iter().map(|x| u16::from(*x)).collect(), + IpAddr::V6(ip) => ip.segments().to_vec(), + }; + + let mut network_part = String::new(); + let mut modifiable_part = String::new(); + let delimiter = if ip.is_ipv4() { "." } else { ":" }; + let formatter = |x: &u16| { + if ip.is_ipv4() { + x.to_string() + } else { + format!("{x:04x}") + } + }; + + for (i, ((last_addr_segment, network_segment), ip_segment)) in last_addr_segments + .iter() + .zip(network_segments.iter()) + .zip(ip_segments.iter()) + .enumerate() + { + if last_addr_segment != network_segment { + let parts = ip_segments.split_at(i).1; + let joined = parts + .iter() + .map(formatter) + .collect::>() + .join(delimiter); + modifiable_part.push_str(&joined); + break; + } + let formatted = formatter(ip_segment); + network_part.push_str(&formatted); + network_part.push_str(delimiter); + } + + SplitIp { + ip: ip.to_string(), + network_part, + modifiable_part, + network_prefix: network_prefix.to_string(), + } +} + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use super::*; + + #[test] + fn test_ip_splitter() { + let net = split_ip( + &IpAddr::from_str("192.168.3.1").unwrap(), + &IpNetwork::from_str("192.168.3.1/30").unwrap(), + ); + + assert_eq!(net.network_part, "192.168.3."); + assert_eq!(net.modifiable_part, "1"); + assert_eq!(net.network_prefix, "30"); + + let net = split_ip( + &IpAddr::from_str("192.168.5.7").unwrap(), + &IpNetwork::from_str("192.168.3.1/24").unwrap(), + ); + + assert_eq!(net.network_part, "192.168.5."); + assert_eq!(net.modifiable_part, "7"); + assert_eq!(net.network_prefix, "24"); + + let net = split_ip( + &IpAddr::from_str("2001:0db8:85a3::8a2e:0370:7334").unwrap(), + &IpNetwork::from_str("2001:0db8:85a3::8a2e:0370:7334/64").unwrap(), + ); + + assert_eq!(net.network_part, "2001:0db8:85a3:0000:"); + assert_eq!(net.modifiable_part, "0000:8a2e:0370:7334"); + assert_eq!(net.network_prefix, "64"); + + let net = split_ip( + &IpAddr::from_str("2001:0db8::0010:8a2e:0370:aaaa").unwrap(), + &IpNetwork::from_str("2001:db8::10:8a2e:370:aaa8/125").unwrap(), + ); + + assert_eq!(net.network_part, "2001:0db8:0000:0000:0010:8a2e:0370:"); + assert_eq!(net.modifiable_part, "aaaa"); + assert_eq!(net.network_prefix, "125"); + } +} diff --git a/crates/defguard_core/Cargo.toml b/crates/defguard_core/Cargo.toml index 15ec63bf76..0927fce948 100644 --- a/crates/defguard_core/Cargo.toml +++ b/crates/defguard_core/Cargo.toml @@ -16,6 +16,7 @@ defguard_web_ui = { workspace = true } defguard_version = { workspace = true } model_derive = { workspace = true } defguard_certs = { workspace = true } +defguard_static_ip = { workspace = true } # external dependencies anyhow = { workspace = true } diff --git a/crates/defguard_core/src/error.rs b/crates/defguard_core/src/error.rs index 09c2d450cb..10d824d490 100644 --- a/crates/defguard_core/src/error.rs +++ b/crates/defguard_core/src/error.rs @@ -7,6 +7,7 @@ use defguard_common::{ types::UrlParseError, }; use defguard_mail::templates::TemplateError; +use defguard_static_ip::error::StaticIpError; use thiserror::Error; use tokio::sync::mpsc::error::SendError; use utoipa::ToSchema; @@ -85,6 +86,9 @@ pub enum WebError { #[error(transparent)] #[schema(value_type=Object)] UrlParseError(#[from] UrlParseError), + #[error(transparent)] + #[schema(value_type=Object)] + StaticIpError(#[from] StaticIpError), } impl From for WebError { diff --git a/crates/defguard_core/src/handlers/mod.rs b/crates/defguard_core/src/handlers/mod.rs index bb5ff434ba..f455bb47ad 100644 --- a/crates/defguard_core/src/handlers/mod.rs +++ b/crates/defguard_core/src/handlers/mod.rs @@ -13,6 +13,7 @@ use defguard_common::{ }, types::user_info::UserInfo, }; +use defguard_static_ip::error::StaticIpError; use serde_json::{Value, json}; use sqlx::PgPool; use utoipa::ToSchema; @@ -41,6 +42,7 @@ pub(crate) mod pagination; pub mod proxy; pub mod settings; pub(crate) mod ssh_authorized_keys; +pub(crate) mod static_ips; pub(crate) mod support; pub(crate) mod updates; pub mod user; @@ -121,6 +123,22 @@ impl From for ApiResponse { StatusCode::INTERNAL_SERVER_ERROR, ) } + WebError::StaticIpError(err) => match err { + StaticIpError::InvalidIpAssignment(err) => { + ApiResponse::new(json!({"msg": err.to_string()}), StatusCode::BAD_REQUEST) + } + StaticIpError::NetworkNotFound(_) | StaticIpError::DeviceNotInNetwork(_, _) => { + error!("{err}"); + ApiResponse::new(json!({"msg": err.to_string()}), StatusCode::BAD_REQUEST) + } + StaticIpError::SqlxError(_) => { + error!("{err}"); + ApiResponse::new( + json!({"msg": "Internal server error"}), + StatusCode::INTERNAL_SERVER_ERROR, + ) + } + }, WebError::AclError(err) => match err { AclError::ParseIntError(_) | AclError::IpNetworkError(_) diff --git a/crates/defguard_core/src/handlers/network_devices.rs b/crates/defguard_core/src/handlers/network_devices.rs index 1120c730fc..10a3fa94e8 100644 --- a/crates/defguard_core/src/handlers/network_devices.rs +++ b/crates/defguard_core/src/handlers/network_devices.rs @@ -1,5 +1,5 @@ use std::{ - net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr}, + net::{AddrParseError, IpAddr}, str::FromStr, }; @@ -18,9 +18,9 @@ use defguard_common::{ wireguard::NetworkAddressError, }, }, + utils::{SplitIp, split_ip}, }; use defguard_mail::templates::{TemplateLocation, new_device_added_mail}; -use ipnetwork::IpNetwork; use serde_json::json; use sqlx::PgConnection; @@ -755,131 +755,3 @@ pub async fn modify_network_device( })?; Ok(ApiResponse::json(network_device_info, StatusCode::OK)) } - -#[derive(Debug, Serialize)] -struct SplitIp { - network_part: String, - modifiable_part: String, - network_prefix: String, - ip: String, -} - -/// Splits the IP address (IPv4 or IPv6) into three parts: network part, modifiable part and prefix -/// The network part is the part that can't be changed by the user. -/// This is to display an IP address in the UI like this: 192.168.(1.1)/16, where the part in the parenthesis can be changed by the user. -/// The algorithm works as follows: -/// 1. Get the network address, last address and IP address segments, e.g. 192.1.1.1 would be [192, 1, 1, 1] -/// 2. Iterate over the segments and compare the last address and network segments, as long as the current segments are equal, append the segment to the network part. -/// If they are not equal, we found the first modifiable segment (one of the segments of an address that may change between hosts in the same network), -/// append the rest of the segments to the modifiable part. -/// 3. Join the segments with the delimiter and return the network part, modifiable part and the network prefix -fn split_ip(ip: &IpAddr, network: &IpNetwork) -> SplitIp { - let network_addr = network.network(); - let network_prefix = network.prefix(); - - let ip_segments = match ip { - IpAddr::V4(ip) => ip.octets().iter().map(|x| u16::from(*x)).collect(), - IpAddr::V6(ip) => ip.segments().to_vec(), - }; - - let last_addr_segments = match network { - IpNetwork::V4(net) => { - let last_ip = u32::from(net.ip()) | (!u32::from(net.mask())); - let last_ip: Ipv4Addr = last_ip.into(); - last_ip.octets().iter().map(|x| u16::from(*x)).collect() - } - IpNetwork::V6(net) => { - let last_ip = u128::from(net.ip()) | (!u128::from(net.mask())); - let last_ip: Ipv6Addr = last_ip.into(); - last_ip.segments().to_vec() - } - }; - - let network_segments = match network_addr { - IpAddr::V4(ip) => ip.octets().iter().map(|x| u16::from(*x)).collect(), - IpAddr::V6(ip) => ip.segments().to_vec(), - }; - - let mut network_part = String::new(); - let mut modifiable_part = String::new(); - let delimiter = if ip.is_ipv4() { "." } else { ":" }; - let formatter = |x: &u16| { - if ip.is_ipv4() { - x.to_string() - } else { - format!("{x:04x}") - } - }; - - for (i, ((last_addr_segment, network_segment), ip_segment)) in last_addr_segments - .iter() - .zip(network_segments.iter()) - .zip(ip_segments.iter()) - .enumerate() - { - if last_addr_segment != network_segment { - let parts = ip_segments.split_at(i).1; - let joined = parts - .iter() - .map(formatter) - .collect::>() - .join(delimiter); - modifiable_part.push_str(&joined); - break; - } - let formatted = formatter(ip_segment); - network_part.push_str(&formatted); - network_part.push_str(delimiter); - } - - SplitIp { - ip: ip.to_string(), - network_part, - modifiable_part, - network_prefix: network_prefix.to_string(), - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_ip_splitter() { - let net = split_ip( - &IpAddr::from_str("192.168.3.1").unwrap(), - &IpNetwork::from_str("192.168.3.1/30").unwrap(), - ); - - assert_eq!(net.network_part, "192.168.3."); - assert_eq!(net.modifiable_part, "1"); - assert_eq!(net.network_prefix, "30"); - - let net = split_ip( - &IpAddr::from_str("192.168.5.7").unwrap(), - &IpNetwork::from_str("192.168.3.1/24").unwrap(), - ); - - assert_eq!(net.network_part, "192.168.5."); - assert_eq!(net.modifiable_part, "7"); - assert_eq!(net.network_prefix, "24"); - - let net = split_ip( - &IpAddr::from_str("2001:0db8:85a3::8a2e:0370:7334").unwrap(), - &IpNetwork::from_str("2001:0db8:85a3::8a2e:0370:7334/64").unwrap(), - ); - - assert_eq!(net.network_part, "2001:0db8:85a3:0000:"); - assert_eq!(net.modifiable_part, "0000:8a2e:0370:7334"); - assert_eq!(net.network_prefix, "64"); - - let net = split_ip( - &IpAddr::from_str("2001:0db8::0010:8a2e:0370:aaaa").unwrap(), - &IpNetwork::from_str("2001:db8::10:8a2e:370:aaa8/125").unwrap(), - ); - - assert_eq!(net.network_part, "2001:0db8:0000:0000:0010:8a2e:0370:"); - assert_eq!(net.modifiable_part, "aaaa"); - assert_eq!(net.network_prefix, "125"); - } -} diff --git a/crates/defguard_core/src/handlers/static_ips.rs b/crates/defguard_core/src/handlers/static_ips.rs new file mode 100644 index 0000000000..fe2fac0afd --- /dev/null +++ b/crates/defguard_core/src/handlers/static_ips.rs @@ -0,0 +1,92 @@ +use std::net::IpAddr; + +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, +}; +use defguard_common::db::Id; +use defguard_static_ip::{LocationDevices, get_ips_for_user}; +use serde::Serialize; + +use crate::{ + appstate::AppState, + auth::{AdminRole, SessionInfo}, + handlers::{ApiResponse, ApiResult}, +}; + +#[derive(Serialize)] +pub struct LocationDevicesResponse { + pub locations: Vec, +} + +pub async fn get_all_user_device_ips( + _admin_role: AdminRole, + _session: SessionInfo, + Path(username): Path, + State(state): State, +) -> ApiResult { + let locations = get_ips_for_user(&username, &state.pool).await?; + Ok(ApiResponse::json( + LocationDevicesResponse { locations }, + StatusCode::OK, + )) +} + +#[derive(Deserialize)] +pub struct StaticIpAssignment { + pub device_id: i64, + pub location_id: Id, + pub ips: Vec, +} + +pub async fn assign_static_ips( + _admin_role: AdminRole, + _session: SessionInfo, + State(state): State, + Json(payload): Json>, +) -> ApiResult { + let mut transaction = state.pool.begin().await?; + for assignment in payload { + defguard_static_ip::assign_static_ips( + assignment.device_id, + assignment.ips, + assignment.location_id, + &mut transaction, + ) + .await?; + } + transaction.commit().await?; + Ok(ApiResponse { + json: serde_json::json!({"message": "Static IPs assigned successfully"}), + status: StatusCode::OK, + }) +} + +#[derive(Deserialize)] +pub struct ValidateIpAssignmentRequest { + pub device_id: i64, + pub ip: IpAddr, + pub location: Id, +} + +pub async fn validate_ip_assignment( + _admin_role: AdminRole, + _session: SessionInfo, + State(state): State, + Json(payload): Json, +) -> ApiResult { + let mut transaction = state.pool.begin().await?; + defguard_static_ip::validate_ip( + payload.device_id, + payload.ip, + payload.location, + &mut transaction, + ) + .await?; + transaction.commit().await?; + Ok(ApiResponse { + json: serde_json::json!({"message": "IP assignment is valid"}), + status: StatusCode::OK, + }) +} diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index bb237404d6..00b8e07431 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -136,6 +136,7 @@ use crate::{ test_ldap_settings, update_settings, }, ssh_authorized_keys::get_authorized_keys, + static_ips::{assign_static_ips, get_all_user_device_ips, validate_ip_assignment}, support::{configuration, logs}, updates::outdated_components, user::{ @@ -467,6 +468,14 @@ pub fn build_webapp( ) .route("/device", get(list_devices)) .route("/device/user/{username}", get(list_user_devices)) + .route( + "/device/user/{username}/ip", + get(get_all_user_device_ips).post(assign_static_ips), + ) + .route( + "/device/user/{username}/ip/validate", + post(validate_ip_assignment), + ) // Network devices, as opposed to user devices .route( "/device/network", diff --git a/crates/defguard_mail/src/mail.rs b/crates/defguard_mail/src/mail.rs index de92c002c6..6977e3fb50 100644 --- a/crates/defguard_mail/src/mail.rs +++ b/crates/defguard_mail/src/mail.rs @@ -12,14 +12,13 @@ use tera::{Context, Tera, Value}; use thiserror::Error; use tracing::{debug, error, info, warn}; +use super::SmtpSettings; use crate::{ mail_context::MailContext, qr::qr_png, templates::{DEFAULT_LANG, TemplateError}, }; -use super::SmtpSettings; - #[derive(Debug)] pub struct Attachment { filename: String, diff --git a/crates/defguard_proxy_manager/src/handler.rs b/crates/defguard_proxy_manager/src/handler.rs index 5ee8fea4ad..06b24aba06 100644 --- a/crates/defguard_proxy_manager/src/handler.rs +++ b/crates/defguard_proxy_manager/src/handler.rs @@ -32,6 +32,7 @@ use defguard_core::{ }, version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, }; +use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; use defguard_proto::proxy::{ AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, InitialInfo, core_request, core_response, proxy_client::ProxyClient, @@ -66,7 +67,6 @@ use crate::{ ProxyError, ProxyTxSet, TEN_SECS, servers::{EnrollmentServer, PasswordResetServer}, }; -use defguard_grpc_tls::{certs as tls_certs, connector::HttpsSchemeConnector}; static VERSION_ZERO: Version = Version::new(0, 0, 0); diff --git a/crates/defguard_proxy_manager/src/lib.rs b/crates/defguard_proxy_manager/src/lib.rs index 351f279001..b4bd75255b 100644 --- a/crates/defguard_proxy_manager/src/lib.rs +++ b/crates/defguard_proxy_manager/src/lib.rs @@ -6,7 +6,6 @@ use std::{ use defguard_common::{db::models::proxy::Proxy, types::proxy::ProxyControlMessage}; use defguard_core::{events::BidiStreamEvent, grpc::GatewayEvent, version::IncompatibleComponents}; - use sqlx::PgPool; use tokio::{ select, diff --git a/crates/defguard_static_ip/Cargo.toml b/crates/defguard_static_ip/Cargo.toml new file mode 100644 index 0000000000..d6497aad64 --- /dev/null +++ b/crates/defguard_static_ip/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "defguard_static_ip" +version = "0.0.0" +edition.workspace = true +license-file.workspace = true +homepage.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +defguard_common.workspace = true +anyhow.workspace = true +axum.workspace = true +sqlx.workspace = true +tokio.workspace = true +serde_json.workspace = true +serde.workspace = true +thiserror.workspace = true +tracing.workspace = true diff --git a/crates/defguard_static_ip/src/error.rs b/crates/defguard_static_ip/src/error.rs new file mode 100644 index 0000000000..45dcbeb023 --- /dev/null +++ b/crates/defguard_static_ip/src/error.rs @@ -0,0 +1,14 @@ +use defguard_common::db::models::wireguard::NetworkAddressError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum StaticIpError { + #[error("Network (location) with ID {0} not found")] + NetworkNotFound(i64), + #[error("Device {0} is not assigned to network {1}")] + DeviceNotInNetwork(i64, i64), + #[error(transparent)] + InvalidIpAssignment(#[from] NetworkAddressError), + #[error(transparent)] + SqlxError(#[from] sqlx::Error), +} diff --git a/crates/defguard_static_ip/src/lib.rs b/crates/defguard_static_ip/src/lib.rs new file mode 100644 index 0000000000..d187eb0e05 --- /dev/null +++ b/crates/defguard_static_ip/src/lib.rs @@ -0,0 +1,614 @@ +use std::net::IpAddr; + +use defguard_common::{ + db::{ + Id, + models::{WireguardNetwork, device::WireguardNetworkDevice}, + }, + utils::{SplitIp, split_ip}, +}; +use serde::Serialize; +use sqlx::{PgConnection, PgPool, prelude::FromRow}; +use tracing::debug; + +use crate::error::StaticIpError; + +pub mod error; + +#[derive(Serialize)] +pub struct LocationDevices { + pub location_id: i64, + pub location_name: String, + pub devices: Vec, +} + +#[derive(Serialize)] +pub struct DeviceIps { + pub device_id: i64, + pub device_name: String, + pub wireguard_ips: Vec, +} + +#[derive(FromRow)] +struct DeviceIpRow { + location_id: i64, + location_name: String, + device_id: i64, + device_name: String, + wireguard_ips: Vec, +} + +pub async fn get_ips_for_user( + username: &str, + pool: &PgPool, +) -> Result, StaticIpError> { + debug!("Fetching static IPs for user {username}"); + let rows = sqlx::query_as!( + DeviceIpRow, + "SELECT \ + wn.id AS location_id, \ + wn.name AS location_name, \ + d.id AS device_id, \ + d.name AS device_name, \ + wnd.wireguard_ips AS \"wireguard_ips: Vec\" \ + FROM wireguard_network wn \ + JOIN wireguard_network_device wnd ON wnd.wireguard_network_id = wn.id \ + JOIN device d ON d.id = wnd.device_id \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE u.username = $1 \ + ORDER BY wn.name, d.name", + username + ) + .fetch_all(pool) + .await?; + + debug!( + "Found {} device-location assignments for user {username}", + rows.len() + ); + let mut locations: Vec = Vec::new(); + + for row in rows { + let network = WireguardNetwork::find_by_id(pool, row.location_id) + .await? + .ok_or(StaticIpError::NetworkNotFound(row.location_id))?; + + let wireguard_ips: Vec = row + .wireguard_ips + .iter() + .filter_map(|ip| { + network + .get_containing_network(*ip) + .map(|net| split_ip(ip, &net)) + }) + .collect(); + + let device = DeviceIps { + device_id: row.device_id, + device_name: row.device_name, + wireguard_ips, + }; + + match locations.last_mut() { + Some(loc) if loc.location_id == row.location_id => { + loc.devices.push(device); + } + _ => { + locations.push(LocationDevices { + location_id: row.location_id, + location_name: row.location_name, + devices: vec![device], + }); + } + } + } + + debug!( + "Returning IP data for {} location(s) for user {username}", + locations.len() + ); + Ok(locations) +} + +pub async fn assign_static_ips( + device_id: Id, + ips: Vec, + location: Id, + transaction: &mut PgConnection, +) -> Result<(), StaticIpError> { + debug!("Assigning static IPs {ips:?} to device {device_id} in location {location}"); + let network = WireguardNetwork::find_by_id(&mut *transaction, location) + .await? + .ok_or(StaticIpError::NetworkNotFound(location))?; + + let mut network_device = WireguardNetworkDevice::find(&mut *transaction, device_id, location) + .await? + .ok_or(StaticIpError::DeviceNotInNetwork(device_id, location))?; + + network + .can_assign_ips(transaction, &ips, Some(device_id)) + .await?; + + network_device.wireguard_ips = ips; + network_device.update(&mut *transaction).await?; + + debug!("Static IPs successfully assigned to device {device_id} in location {location}"); + Ok(()) +} + +pub async fn validate_ip( + device_id: Id, + ip: IpAddr, + location: Id, + transaction: &mut PgConnection, +) -> Result<(), StaticIpError> { + debug!("Validating IP {ip} for device {device_id} in location {location}"); + let network = WireguardNetwork::find_by_id(&mut *transaction, location) + .await? + .ok_or(StaticIpError::NetworkNotFound(location))?; + + let result = network + .can_assign_ips(transaction, &[ip], Some(device_id)) + .await + .map_err(StaticIpError::InvalidIpAssignment); + + if result.is_ok() { + debug!("IP {ip} is valid for device {device_id} in location {location}"); + } else { + debug!( + "IP {ip} is NOT valid for device {device_id} in location {location}: {:?}", + result + ); + } + + result +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use defguard_common::db::{ + models::{Device, DeviceType, User, WireguardNetwork, device::WireguardNetworkDevice}, + setup_pool, + }; + use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + + use super::*; + + #[sqlx::test] + async fn test_get_ips_for_user_groups_by_location(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + // Create a test user + let user = User::new( + "testuser", + Some("Test123!"), + "User", + "Test", + "test@example.com", + None, + ) + .save(&pool) + .await + .expect("Failed to create user"); + + // Create test locations + let mut location_a = WireguardNetwork { + name: "Location A".into(), + ..Default::default() + }; + location_a.try_set_address("10.0.1.1/24").unwrap(); + let location_a = location_a + .save(&pool) + .await + .expect("Failed to create Location A"); + + let mut location_b = WireguardNetwork { + name: "Location B".into(), + ..Default::default() + }; + location_b.try_set_address("10.0.2.1/24").unwrap(); + let location_b = location_b + .save(&pool) + .await + .expect("Failed to create Location B"); + + // Create test devices for the user + let device1 = Device::new( + "Device 1".into(), + "pubkey1".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device 1"); + + let device2 = Device::new( + "Device 2".into(), + "pubkey2".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device 2"); + + let device3 = Device::new( + "Device 3".into(), + "pubkey3".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device 3"); + + let device4 = Device::new( + "Device 4".into(), + "pubkey4".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device 4"); + + // Create network-device mappings with IPs + WireguardNetworkDevice::new( + location_a.id, + device1.id, + vec![ + IpAddr::from_str("10.0.1.2").unwrap(), + IpAddr::from_str("10.0.1.3").unwrap(), + ], + ) + .insert(&pool) + .await + .expect("Failed to assign device 1 to location A"); + + WireguardNetworkDevice::new( + location_a.id, + device2.id, + vec![IpAddr::from_str("10.0.1.4").unwrap()], + ) + .insert(&pool) + .await + .expect("Failed to assign device 2 to location A"); + + WireguardNetworkDevice::new( + location_b.id, + device3.id, + vec![IpAddr::from_str("10.0.2.2").unwrap()], + ) + .insert(&pool) + .await + .expect("Failed to assign device 3 to location B"); + + WireguardNetworkDevice::new( + location_b.id, + device4.id, + vec![ + IpAddr::from_str("10.0.2.3").unwrap(), + IpAddr::from_str("10.0.2.4").unwrap(), + ], + ) + .insert(&pool) + .await + .expect("Failed to assign device 4 to location B"); + + // Call the function + let result = get_ips_for_user("testuser", &pool).await; + assert!(result.is_ok()); + + let locations = result.unwrap(); + assert_eq!(locations.len(), 2); + + let net_a = location_a.address[0]; + let net_b = location_b.address[0]; + + // Verify Location A + assert_eq!(locations[0].location_name, "Location A"); + assert_eq!(locations[0].devices.len(), 2); + assert_eq!(locations[0].devices[0].device_name, "Device 1"); + assert_eq!( + locations[0].devices[0].wireguard_ips, + vec![ + split_ip(&IpAddr::from_str("10.0.1.2").unwrap(), &net_a), + split_ip(&IpAddr::from_str("10.0.1.3").unwrap(), &net_a), + ] + ); + assert_eq!(locations[0].devices[1].device_name, "Device 2"); + assert_eq!( + locations[0].devices[1].wireguard_ips, + vec![split_ip(&IpAddr::from_str("10.0.1.4").unwrap(), &net_a)] + ); + + // Verify Location B + assert_eq!(locations[1].location_name, "Location B"); + assert_eq!(locations[1].devices.len(), 2); + assert_eq!(locations[1].devices[0].device_name, "Device 3"); + assert_eq!(locations[1].devices[1].device_name, "Device 4"); + assert_eq!( + locations[1].devices[0].wireguard_ips, + vec![split_ip(&IpAddr::from_str("10.0.2.2").unwrap(), &net_b)] + ); + assert_eq!( + locations[1].devices[1].wireguard_ips, + vec![ + split_ip(&IpAddr::from_str("10.0.2.3").unwrap(), &net_b), + split_ip(&IpAddr::from_str("10.0.2.4").unwrap(), &net_b), + ] + ); + } + + #[sqlx::test] + async fn test_assign_static_ips_success(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user = User::new( + "assignuser", + Some("Test123!"), + "User", + "Test", + "assign@example.com", + None, + ) + .save(&pool) + .await + .expect("Failed to create user"); + + let mut network = WireguardNetwork { + name: "Assign Network".into(), + ..Default::default() + }; + network.try_set_address("10.0.0.1/24").unwrap(); + let network = network.save(&pool).await.expect("Failed to create network"); + + let device = Device::new( + "Assign Device".into(), + "assignpubkey1".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device"); + + WireguardNetworkDevice::new( + network.id, + device.id, + vec![IpAddr::from_str("10.0.0.2").unwrap()], + ) + .insert(&pool) + .await + .expect("Failed to assign device to network"); + + let new_ips = vec![IpAddr::from_str("10.0.0.10").unwrap()]; + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + assign_static_ips(device.id, new_ips.clone(), network.id, &mut conn) + .await + .expect("assign_static_ips should succeed"); + + let updated = WireguardNetworkDevice::find(&pool, device.id, network.id) + .await + .unwrap() + .expect("Network device entry should exist"); + assert_eq!(updated.wireguard_ips, new_ips); + } + + #[sqlx::test] + async fn test_assign_static_ips_network_not_found(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = assign_static_ips( + 1, + vec![IpAddr::from_str("10.0.0.2").unwrap()], + 9999, + &mut conn, + ) + .await; + + assert!(matches!(result, Err(StaticIpError::NetworkNotFound(9999)))); + } + + #[sqlx::test] + async fn test_assign_static_ips_device_not_in_network( + _: PgPoolOptions, + options: PgConnectOptions, + ) { + let pool = setup_pool(options).await; + + let user = User::new( + "nonetworkuser", + Some("Test123!"), + "User", + "Test", + "nonet@example.com", + None, + ) + .save(&pool) + .await + .expect("Failed to create user"); + + let mut network = WireguardNetwork { + name: "NoDevice Network".into(), + ..Default::default() + }; + network.try_set_address("10.0.0.1/24").unwrap(); + let network = network.save(&pool).await.expect("Failed to create network"); + + let device = Device::new( + "Unassigned Device".into(), + "unassignedpubkey".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device"); + + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = assign_static_ips( + device.id, + vec![IpAddr::from_str("10.0.0.2").unwrap()], + network.id, + &mut conn, + ) + .await; + + assert!(matches!( + result, + Err(StaticIpError::DeviceNotInNetwork(_, _)) + )); + } + + #[sqlx::test] + async fn test_assign_static_ips_ip_out_of_range(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user = User::new( + "rangeuser", + Some("Test123!"), + "User", + "Test", + "range@example.com", + None, + ) + .save(&pool) + .await + .expect("Failed to create user"); + + let mut network = WireguardNetwork { + name: "Range Network".into(), + ..Default::default() + }; + network.try_set_address("10.0.0.1/24").unwrap(); + let network = network.save(&pool).await.expect("Failed to create network"); + + let device = Device::new( + "Range Device".into(), + "rangepubkey".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device"); + + WireguardNetworkDevice::new( + network.id, + device.id, + vec![IpAddr::from_str("10.0.0.2").unwrap()], + ) + .insert(&pool) + .await + .expect("Failed to assign device to network"); + + // IP is outside the 10.0.0.0/24 range + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = assign_static_ips( + device.id, + vec![IpAddr::from_str("192.168.1.5").unwrap()], + network.id, + &mut conn, + ) + .await; + + assert!(matches!(result, Err(StaticIpError::InvalidIpAssignment(_)))); + } + + #[sqlx::test] + async fn test_assign_static_ips_ip_already_used(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + + let user = User::new( + "conflictuser", + Some("Test123!"), + "User", + "Test", + "conflict@example.com", + None, + ) + .save(&pool) + .await + .expect("Failed to create user"); + + let mut network = WireguardNetwork { + name: "Conflict Network".into(), + ..Default::default() + }; + network.try_set_address("10.0.0.1/24").unwrap(); + let network = network.save(&pool).await.expect("Failed to create network"); + + let device1 = Device::new( + "Conflict Device 1".into(), + "conflictpubkey1".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device 1"); + + let device2 = Device::new( + "Conflict Device 2".into(), + "conflictpubkey2".into(), + user.id, + DeviceType::User, + None, + true, + ) + .save(&pool) + .await + .expect("Failed to create device 2"); + + WireguardNetworkDevice::new( + network.id, + device1.id, + vec![IpAddr::from_str("10.0.0.3").unwrap()], + ) + .insert(&pool) + .await + .expect("Failed to assign device 1 to network"); + + WireguardNetworkDevice::new( + network.id, + device2.id, + vec![IpAddr::from_str("10.0.0.4").unwrap()], + ) + .insert(&pool) + .await + .expect("Failed to assign device 2 to network"); + + // Try to steal 10.0.0.3 which belongs to device1 + let mut conn = pool.acquire().await.expect("Failed to acquire connection"); + let result = assign_static_ips( + device2.id, + vec![IpAddr::from_str("10.0.0.3").unwrap()], + network.id, + &mut conn, + ) + .await; + + assert!(matches!(result, Err(StaticIpError::InvalidIpAssignment(_)))); + } +} diff --git a/deny.toml b/deny.toml index 558e412364..2877aa45ed 100644 --- a/deny.toml +++ b/deny.toml @@ -183,6 +183,10 @@ exceptions = [ "AGPL-3.0-only", "AGPL-3.0-or-later", ], crate = "defguard_setup" }, + { allow = [ + "AGPL-3.0-only", + "AGPL-3.0-or-later", + ], crate = "defguard_static_ip" }, ] # Some crates don't have (easily) machine readable licensing information, diff --git a/tools/defguard_generator/src/vpn_session_stats.rs b/tools/defguard_generator/src/vpn_session_stats.rs index b306abb647..7e22211822 100644 --- a/tools/defguard_generator/src/vpn_session_stats.rs +++ b/tools/defguard_generator/src/vpn_session_stats.rs @@ -6,6 +6,7 @@ use defguard_common::db::{ Id, models::{ WireguardNetwork, + device::WireguardNetworkDevice, gateway::Gateway, vpn_client_session::{VpnClientSession, VpnClientSessionState}, vpn_session_stats::VpnSessionStats, @@ -69,6 +70,27 @@ pub async fn generate_vpn_session_stats( let devices = prepare_user_devices(&pool, &mut rng, &user, config.devices_per_user as usize).await?; + // assign devices to the network if not already assigned + for device in &devices { + if WireguardNetworkDevice::find(&mut *transaction, device.id, location.id) + .await? + .is_none() + { + info!( + "Assigning device {} to network {} with auto-generated IP", + device.name, location.name + ); + device + .assign_next_network_ip(&mut transaction, &location, None, None) + .await?; + } else { + info!( + "Device {} already assigned to network {}", + device.name, location.name + ); + } + } + for device in devices { info!("Generating sessions for device {device}"); // generate requested number of sessions for a device diff --git a/web/messages/en/modal.json b/web/messages/en/modal.json index 26ad3a3b86..dce76d4065 100644 --- a/web/messages/en/modal.json +++ b/web/messages/en/modal.json @@ -122,5 +122,13 @@ "modal_ce_webhook_edit_title": "Edit Webhook", "modal_ce_webhook_events_title": "Trigger events", "modal_ce_webhook_events_text": "", - "modal_assign_users_groups_title": "Assign groups to selected users" + "modal_assign_users_groups_title": "Assign groups to selected users", + "modal_assign_user_ip_title": "{firstName} {lastName} IP settings", + "modal_assign_user_ip_title_fallback": "IP settings", + "modal_assign_user_ip_assignment_mode_title": "Single IP Assignment", + "modal_assign_user_ip_assignment_mode_description": "Assign and update IP addresses sequentially, managing each user's device individually.", + "modal_assign_user_ip_success": "{firstName} {lastName}'s IP addresses were successfully updated.", + "modal_assign_user_ip_error": "Failed to update IP addresses", + "modal_assign_user_ip_validation_error": "Invalid or already taken", + "modal_assign_user_ip_no_locations": "No locations available. Add a location first." } diff --git a/web/messages/en/users.json b/web/messages/en/users.json index 75741282ff..3e78b582b8 100644 --- a/web/messages/en/users.json +++ b/web/messages/en/users.json @@ -25,5 +25,6 @@ "users_row_menu_change_password": "Change password", "users_row_menu_edit_groups": "Edit groups", "users_row_menu_initiate_self_enrollment": "Initiate self-enrollment", + "users_row_menu_ip_settings": "User devices IP settings", "modal_edit_user_groups_title": "Edit user groups" } diff --git a/web/src/pages/UsersOverviewPage/UsersOverviewPage.tsx b/web/src/pages/UsersOverviewPage/UsersOverviewPage.tsx index 38335b2dec..55f99403bb 100644 --- a/web/src/pages/UsersOverviewPage/UsersOverviewPage.tsx +++ b/web/src/pages/UsersOverviewPage/UsersOverviewPage.tsx @@ -7,6 +7,7 @@ import { ChangePasswordModal } from '../../shared/components/modals/ChangePasswo import { TableSkeleton } from '../../shared/components/skeleton/TableSkeleton/TableSkeleton'; import { TablePageLayout } from '../../shared/layout/TablePageLayout/TablePageLayout'; import { AddUserModal } from './modals/AddUserModal/AddUserModal'; +import { AssignUserIPModal } from './modals/AssignUserIPModal/AssignUserIPModal'; import { AssignUsersToGroupsModal } from './modals/AssignUsersToGroupsModal/AssignUsersToGroupsModal'; import { EditUserModal } from './modals/EditUserModal/EditUserModal'; import { EnrollmentTokenModal } from './modals/EnrollmentTokenModal/EnrollmentTokenModal'; @@ -28,6 +29,7 @@ export const UsersOverviewPage = () => { + ); }; diff --git a/web/src/pages/UsersOverviewPage/UsersTable.tsx b/web/src/pages/UsersOverviewPage/UsersTable.tsx index 19a04992dd..26545dcb48 100644 --- a/web/src/pages/UsersOverviewPage/UsersTable.tsx +++ b/web/src/pages/UsersOverviewPage/UsersTable.tsx @@ -302,6 +302,18 @@ export const UsersTable = () => { }); }, }, + { + text: m.users_row_menu_ip_settings(), + icon: IconKind.Gateway, + testId: 'assign-ip', + onClick: async () => { + const response = await api.device.getUserDeviceIps(rowData.username); + openModal(ModalName.AssignUserIP, { + user: rowData, + locationData: response.data, + }); + }, + }, ], }, { diff --git a/web/src/pages/UsersOverviewPage/modals/AssignUserIPModal/AssignUserIPModal.tsx b/web/src/pages/UsersOverviewPage/modals/AssignUserIPModal/AssignUserIPModal.tsx new file mode 100644 index 0000000000..20b7fc0d1a --- /dev/null +++ b/web/src/pages/UsersOverviewPage/modals/AssignUserIPModal/AssignUserIPModal.tsx @@ -0,0 +1,280 @@ +import { useStore } from '@tanstack/react-form'; +import { useMutation } from '@tanstack/react-query'; +import axios from 'axios'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import z from 'zod'; +import { m } from '../../../../paraglide/messages'; +import api from '../../../../shared/api/api'; +import type { + LocationDevices, + LocationDevicesResponse, +} from '../../../../shared/api/types'; +import { IpAssignmentCard } from '../../../../shared/components/IpAssignmentCard/IpAssignmentCard'; +import { IpAssignmentDeviceSection } from '../../../../shared/components/IpAssignmentDeviceSection/IpAssignmentDeviceSection'; +import { Modal } from '../../../../shared/defguard-ui/components/Modal/Modal'; +import { ModalControls } from '../../../../shared/defguard-ui/components/ModalControls/ModalControls'; +import { SuggestedIpInput } from '../../../../shared/defguard-ui/components/SuggestedIPInput/SuggestedIPInput'; +import { Snackbar } from '../../../../shared/defguard-ui/providers/snackbar/snackbar'; +import { isPresent } from '../../../../shared/defguard-ui/utils/isPresent'; +import { useAppForm } from '../../../../shared/form'; +import { + closeModal, + subscribeCloseModal, + subscribeOpenModal, +} from '../../../../shared/hooks/modalControls/modalsSubjects'; +import { ModalName } from '../../../../shared/hooks/modalControls/modalTypes'; +import type { OpenAssignUserIPModal } from '../../../../shared/hooks/modalControls/types'; +import './style.scss'; +import { SizedBox } from '../../../../shared/defguard-ui/components/SizedBox/SizedBox'; +import { ThemeSpacing } from '../../../../shared/defguard-ui/types'; + +const modalNameValue = ModalName.AssignUserIP; + +type ModalData = OpenAssignUserIPModal; + +const formSchema = z.object({ + locations: z.array( + z.object({ + location_id: z.number(), + devices: z.array( + z.object({ + device_id: z.number(), + ips: z.array( + z.object({ + modifiable_part: z.string().trim(), + network_part: z.string(), + }), + ), + }), + ), + }), + ), +}); + +type FormFields = z.infer; + +export const AssignUserIPModal = () => { + const [isOpen, setOpen] = useState(false); + const [modalData, setModalData] = useState(null); + + useEffect(() => { + const openSub = subscribeOpenModal(modalNameValue, (data) => { + setModalData(data); + setOpen(true); + }); + const closeSub = subscribeCloseModal(modalNameValue, () => setOpen(false)); + return () => { + openSub.unsubscribe(); + closeSub.unsubscribe(); + }; + }, []); + + return ( + setOpen(false)} + afterClose={() => { + setModalData(null); + }} + > + {isPresent(modalData) && } + + ); +}; + +const ModalContent = ({ user, locationData }: ModalData) => { + return ( + + ); +}; + +type AssignmentFormProps = { + username: string; + firstName: string; + lastName: string; + locationData: LocationDevicesResponse; +}; + +const AssignmentForm = ({ + username, + firstName, + lastName, + locationData, +}: AssignmentFormProps) => { + const [openLocations, setOpenLocations] = useState>(() => new Set()); + + const defaultValues: FormFields = useMemo( + () => ({ + locations: locationData.locations.map((loc) => ({ + location_id: loc.location_id, + devices: loc.devices.map((dev) => ({ + device_id: dev.device_id, + ips: dev.wireguard_ips.map((ip) => ({ + modifiable_part: ip.modifiable_part, + network_part: ip.network_part, + })), + })), + })), + }), + [locationData], + ); + + const { mutateAsync: updateDevices } = useMutation({ + mutationFn: (formData: FormFields) => { + const assignments = formData.locations + .flatMap((loc) => + loc.devices.map((dev) => ({ + device_id: dev.device_id, + location_id: loc.location_id, + ips: dev.ips + .filter((e) => e.modifiable_part.length > 0) + .map((e) => `${e.network_part}${e.modifiable_part}`), + })), + ) + .filter((a) => a.ips.length > 0); + return api.device.assignUserDeviceIps(username, assignments); + }, + meta: { + invalidate: [['user-device-ips', username]], + }, + onSuccess: () => { + Snackbar.default(m.modal_assign_user_ip_success({ firstName, lastName })); + closeModal(modalNameValue); + }, + onError: (error) => { + console.error('Failed to update IP addresses:', error); + Snackbar.error(m.modal_assign_user_ip_error()); + }, + }); + + const form = useAppForm({ + defaultValues, + validators: { + onSubmit: formSchema, + }, + onSubmit: async ({ value }) => { + await updateDevices(value); + }, + }); + + const isSubmitting = useStore(form.store, (s) => s.isSubmitting); + + const toggleLocation = (locationId: number) => { + setOpenLocations((prev) => { + const next = new Set(prev); + if (next.has(locationId)) { + next.delete(locationId); + } else { + next.add(locationId); + } + return next; + }); + }; + + const validateIp = useCallback( + async (value: string, deviceId: number, locationId: number) => { + try { + await api.device.validateUserDeviceIp(username, { + device_id: deviceId, + ip: value, + location: locationId, + }); + return undefined; + } catch (e) { + return axios.isAxiosError(e) + ? (e.response?.data?.msg ?? m.modal_assign_user_ip_validation_error()) + : m.modal_assign_user_ip_validation_error(); + } + }, + [username], + ); + + return ( + +
+
+

{m.modal_assign_user_ip_assignment_mode_title()}

+ +

+ {m.modal_assign_user_ip_assignment_mode_description()} +

+
+ +
+ {locationData.locations.length === 0 && ( +

{m.modal_assign_user_ip_no_locations()}

+ )} + {locationData.locations.map((location: LocationDevices, locIdx) => ( + toggleLocation(location.location_id)} + > + {location.devices.map((deviceIps, devIdx) => ( + + {deviceIps.wireguard_ips.map((ipData, ipIdx) => ( + + validateIp( + `${ipData.network_part}${value}`, + deviceIps.device_id, + location.location_id, + ), + }} + > + {(field) => ( + field.handleChange(val ?? '')} + onBlur={field.handleBlur} + /> + )} + + ))} + + ))} + + ))} +
+ + form.handleSubmit(), + }} + cancelProps={{ + text: m.controls_cancel(), + disabled: isSubmitting, + onClick: () => closeModal(modalNameValue), + }} + /> +
+
+ ); +}; diff --git a/web/src/pages/UsersOverviewPage/modals/AssignUserIPModal/style.scss b/web/src/pages/UsersOverviewPage/modals/AssignUserIPModal/style.scss new file mode 100644 index 0000000000..4605ecedde --- /dev/null +++ b/web/src/pages/UsersOverviewPage/modals/AssignUserIPModal/style.scss @@ -0,0 +1,33 @@ +#assign-user-ip-modal { + .assign-user-ip-modal { + display: flex; + flex-direction: column; + gap: var(--spacing-xl); + min-height: 400px; + + .assignment-mode { + display: flex; + flex-direction: column; + + h3 { + font: var(--t-body-sm-400); + } + + .mode-description { + color: var(--fg-neutral); + font: var(--t-body-xs-400); + } + } + + .devices-list { + display: flex; + flex-direction: column; + gap: var(--spacing-md); + + .no-locations { + color: var(--fg-neutral); + font: var(--t-body-sm-400); + } + } + } +} diff --git a/web/src/shared/api/api.ts b/web/src/shared/api/api.ts index 7ad86695d3..e2e6542003 100644 --- a/web/src/shared/api/api.ts +++ b/web/src/shared/api/api.ts @@ -28,6 +28,7 @@ import type { AdminChangeUserPasswordRequest, ApiToken, ApplicationInfo, + AssignStaticIpsRequest, AuthKey, AvailableLocationIpResponse, ChangeAccountActiveRequest, @@ -57,6 +58,7 @@ import type { GroupsResponse, IpValidation, LicenseInfoResponse, + LocationDevicesResponse, LocationDevicesStats, LocationStats, LocationStatsRequest, @@ -87,6 +89,7 @@ import type { UserProfileResponse, UsersListItem, ValidateDeviceIpsRequest, + ValidateIpAssignmentRequest, WebauthnLoginStartResponse, WebauthnRegisterFinishRequest, WebauthnRegisterStartResponse, @@ -339,6 +342,12 @@ const api = { getDevices: () => client.get('/device'), getDeviceConfig: ({ deviceId, networkId }: { networkId: number; deviceId: number }) => client.get(`/network/${networkId}/device/${deviceId}/config`), + getUserDeviceIps: (username: string) => + client.get(`/device/user/${username}/ip`), + assignUserDeviceIps: (username: string, data: AssignStaticIpsRequest) => + client.post(`/device/user/${username}/ip`, data), + validateUserDeviceIp: (username: string, data: ValidateIpAssignmentRequest) => + client.post(`/device/user/${username}/ip/validate`, data), getDeviceConfigs: async (device: Device): Promise => { const networkConfigurations: AddDeviceResponseConfig[] = []; for (const network of device.networks) { diff --git a/web/src/shared/api/types.ts b/web/src/shared/api/types.ts index 0ec4a3cff5..c10ac33330 100644 --- a/web/src/shared/api/types.ts +++ b/web/src/shared/api/types.ts @@ -67,6 +67,36 @@ export interface AvailableLocationIP { export type AvailableLocationIpResponse = AvailableLocationIP[]; +export interface DeviceIps { + device_id: number; + device_name: string; + wireguard_ips: AvailableLocationIP[]; +} + +export interface LocationDevices { + location_id: number; + location_name: string; + devices: DeviceIps[]; +} + +export interface LocationDevicesResponse { + locations: LocationDevices[]; +} + +export interface StaticIpAssignment { + device_id: number; + location_id: number; + ips: string[]; +} + +export type AssignStaticIpsRequest = StaticIpAssignment[]; + +export interface ValidateIpAssignmentRequest { + device_id: number; + ip: string; + location: number; +} + export type AddUsersToGroupsRequest = { groups: string[]; users: number[]; diff --git a/web/src/shared/components/IpAssignmentCard/style.scss b/web/src/shared/components/IpAssignmentCard/style.scss index 9513fb4c40..bde2ce933a 100644 --- a/web/src/shared/components/IpAssignmentCard/style.scss +++ b/web/src/shared/components/IpAssignmentCard/style.scss @@ -37,3 +37,9 @@ padding-left: var(--spacing-3xl); } } + +.devices { + display: flex; + flex-direction: column; + gap: var(--spacing-xl); +} diff --git a/web/src/shared/defguard-ui b/web/src/shared/defguard-ui index 7db2a91bd1..96b3e1f592 160000 --- a/web/src/shared/defguard-ui +++ b/web/src/shared/defguard-ui @@ -1 +1 @@ -Subproject commit 7db2a91bd1744ab345a07e38d20c45724ccfa9ee +Subproject commit 96b3e1f5922f43b4d9de3c1b210c36c02d592f46 diff --git a/web/src/shared/hooks/modalControls/modalTypes.ts b/web/src/shared/hooks/modalControls/modalTypes.ts index 6ee2592add..3379b964da 100644 --- a/web/src/shared/hooks/modalControls/modalTypes.ts +++ b/web/src/shared/hooks/modalControls/modalTypes.ts @@ -4,6 +4,7 @@ import type { OpenAddApiTokenModal, OpenAddLocationModal, OpenAddNetworkDeviceModal, + OpenAssignUserIPModal, OpenAssignUsersToGroupsModal, OpenAuthKeyRenameModal, OpenCEGroupModal, @@ -55,6 +56,7 @@ export const ModalName = { EditLogStreaming: 'editLogStreaming', DeleteLogStreaming: 'deleteLogStreaming', SelfEnrollmentToken: 'selfEnrollmentToken', + AssignUserIP: 'assignUserIP', } as const; export type ModalNameValue = (typeof ModalName)[keyof typeof ModalName]; @@ -180,6 +182,10 @@ const modalOpenArgsSchema = z.discriminatedUnion('name', [ name: z.literal(ModalName.LicenseExpired), data: z.custom(), }), + z.object({ + name: z.literal(ModalName.AssignUserIP), + data: z.custom(), + }), ]); export type ModalOpenEvent = z.infer; diff --git a/web/src/shared/hooks/modalControls/types.ts b/web/src/shared/hooks/modalControls/types.ts index 6e5005d5bf..a67e8f1f76 100644 --- a/web/src/shared/hooks/modalControls/types.ts +++ b/web/src/shared/hooks/modalControls/types.ts @@ -4,6 +4,7 @@ import type { GroupInfo, LicenseInfo, LicenseTierValue, + LocationDevicesResponse, NetworkDevice, NetworkLocation, OpenIdClient, @@ -107,3 +108,8 @@ export interface OpenLicenseExpiredModal { export interface OpenAddLocationModal { license: LicenseInfo | null; } + +export interface OpenAssignUserIPModal { + user: User; + locationData: LocationDevicesResponse; +}