diff --git a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs index 213bbd1a18..ca198a64ed 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/destination.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/destination.rs @@ -5,7 +5,8 @@ use std::{ use defguard_common::db::{NoId, models::WireguardNetwork, setup_pool}; use defguard_proto::enterprise::firewall::{ - FirewallPolicy, IpAddress, IpRange, ip_address::Address, + FirewallPolicy, IpAddress, IpRange, Port, Protocol, ip_address::Address, + port::Port as PortInner, }; use rand::thread_rng; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; @@ -534,3 +535,478 @@ async fn test_manual_destination_merges_rule_and_component_alias_address_ranges( assert!(deny_rule.source_addrs.is_empty()); assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); } + +#[sqlx::test] +async fn test_any_port_preserves_destination_addresses_and_protocols( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; + + let acl_rule = AclRule { + name: "any port manual destination rule".to_string(), + state: RuleState::Applied, + allow_all_users: true, + addresses: vec!["192.168.50.0/24".parse().unwrap()], + ports: vec![ + crate::enterprise::db::models::acl::PortRange::new(22, 22).into(), + crate::enterprise::db::models::acl::PortRange::new(443, 443).into(), + ], + protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + any_address: false, + any_port: true, + any_protocol: false, + use_manual_destination_settings: true, + ..Default::default() + }; + + create_acl_rule( + &pool, + acl_rule, + vec![location.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + vec![( + IpAddr::V4(Ipv4Addr::new(192, 168, 60, 10)), + IpAddr::V4(Ipv4Addr::new(192, 168, 60, 20)), + )], + Vec::new(), + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + let expected_source_addrs = [ + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.1.1".to_string(), + end: "10.0.1.2".to_string(), + })), + }, + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.2.1".to_string(), + end: "10.0.2.2".to_string(), + })), + }, + ]; + let expected_destination_addrs = [ + IpAddress { + address: Some(Address::IpSubnet("192.168.50.0/24".to_string())), + }, + IpAddress { + address: Some(Address::IpSubnet("192.168.60.10/31".to_string())), + }, + IpAddress { + address: Some(Address::IpSubnet("192.168.60.12/30".to_string())), + }, + IpAddress { + address: Some(Address::IpSubnet("192.168.60.16/30".to_string())), + }, + IpAddress { + address: Some(Address::Ip("192.168.60.20".to_string())), + }, + ]; + + assert_eq!(generated_firewall_rules.len(), 2); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.source_addrs, expected_source_addrs); + assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); + assert!(allow_rule.destination_ports.is_empty()); + assert_eq!( + allow_rule.protocols, + [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] + ); + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + +#[sqlx::test] +async fn test_any_protocol_preserves_destination_addresses_and_ports( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; + + let acl_rule = AclRule { + name: "any protocol manual destination rule".to_string(), + state: RuleState::Applied, + allow_all_users: true, + addresses: vec![ + "192.168.70.0/24".parse().unwrap(), + "192.168.80.1/32".parse().unwrap(), + ], + ports: vec![ + crate::enterprise::db::models::acl::PortRange::new(80, 80).into(), + crate::enterprise::db::models::acl::PortRange::new(1000, 1005).into(), + ], + protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + any_address: false, + any_port: false, + any_protocol: true, + use_manual_destination_settings: true, + ..Default::default() + }; + + create_acl_rule( + &pool, + acl_rule, + vec![location.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + let expected_source_addrs = [ + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.1.1".to_string(), + end: "10.0.1.2".to_string(), + })), + }, + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.2.1".to_string(), + end: "10.0.2.2".to_string(), + })), + }, + ]; + let expected_destination_addrs = [ + IpAddress { + address: Some(Address::IpSubnet("192.168.70.0/24".to_string())), + }, + IpAddress { + address: Some(Address::Ip("192.168.80.1".to_string())), + }, + ]; + let expected_ports = [ + Port { + port: Some(PortInner::SinglePort(80)), + }, + Port { + port: Some(PortInner::PortRange( + defguard_proto::enterprise::firewall::PortRange { + start: 1000, + end: 1005, + }, + )), + }, + ]; + + assert_eq!(generated_firewall_rules.len(), 2); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.source_addrs, expected_source_addrs); + assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); + assert_eq!(allow_rule.destination_ports, expected_ports); + assert!(allow_rule.protocols.is_empty()); + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + +#[sqlx::test] +async fn test_destination_alias_any_port_preserves_addresses_and_protocols( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; + + let destination_alias = AclAlias { + name: "any port destination alias".to_string(), + kind: AliasKind::Destination, + addresses: vec!["192.168.90.0/24".parse().unwrap()], + ports: vec![ + crate::enterprise::db::models::acl::PortRange::new(22, 22).into(), + crate::enterprise::db::models::acl::PortRange::new(443, 443).into(), + ], + protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + any_address: false, + any_port: true, + any_protocol: false, + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + + AclAliasDestinationRange { + id: NoId, + alias_id: destination_alias.id, + start: IpAddr::V4(Ipv4Addr::new(192, 168, 91, 10)), + end: IpAddr::V4(Ipv4Addr::new(192, 168, 91, 20)), + } + .save(&pool) + .await + .unwrap(); + + let acl_rule = AclRule { + name: "any port destination alias rule".to_string(), + state: RuleState::Applied, + allow_all_users: true, + use_manual_destination_settings: false, + ..Default::default() + }; + + create_acl_rule( + &pool, + acl_rule, + vec![location.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + vec![destination_alias.id], + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + let expected_source_addrs = [ + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.1.1".to_string(), + end: "10.0.1.2".to_string(), + })), + }, + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.2.1".to_string(), + end: "10.0.2.2".to_string(), + })), + }, + ]; + let expected_destination_addrs = [ + IpAddress { + address: Some(Address::IpSubnet("192.168.90.0/24".to_string())), + }, + IpAddress { + address: Some(Address::IpSubnet("192.168.91.10/31".to_string())), + }, + IpAddress { + address: Some(Address::IpSubnet("192.168.91.12/30".to_string())), + }, + IpAddress { + address: Some(Address::IpSubnet("192.168.91.16/30".to_string())), + }, + IpAddress { + address: Some(Address::Ip("192.168.91.20".to_string())), + }, + ]; + + assert_eq!(generated_firewall_rules.len(), 2); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.source_addrs, expected_source_addrs); + assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); + assert!(allow_rule.destination_ports.is_empty()); + assert_eq!( + allow_rule.protocols, + [i32::from(Protocol::Tcp), i32::from(Protocol::Udp)] + ); + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + +#[sqlx::test] +async fn test_destination_alias_any_protocol_preserves_addresses_and_ports( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + create_test_users_and_devices(&mut rng, &pool, vec![&location]).await; + + let destination_alias = AclAlias { + name: "any protocol destination alias".to_string(), + kind: AliasKind::Destination, + addresses: vec![ + "192.168.110.0/24".parse().unwrap(), + "192.168.120.1/32".parse().unwrap(), + ], + ports: vec![ + crate::enterprise::db::models::acl::PortRange::new(80, 80).into(), + crate::enterprise::db::models::acl::PortRange::new(1000, 1005).into(), + ], + protocols: vec![Protocol::Tcp.into(), Protocol::Udp.into()], + any_address: false, + any_port: false, + any_protocol: true, + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + + let acl_rule = AclRule { + name: "any protocol destination alias rule".to_string(), + state: RuleState::Applied, + allow_all_users: true, + use_manual_destination_settings: false, + ..Default::default() + }; + + create_acl_rule( + &pool, + acl_rule, + vec![location.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + vec![destination_alias.id], + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + let expected_source_addrs = [ + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.1.1".to_string(), + end: "10.0.1.2".to_string(), + })), + }, + IpAddress { + address: Some(Address::IpRange(IpRange { + start: "10.0.2.1".to_string(), + end: "10.0.2.2".to_string(), + })), + }, + ]; + let expected_destination_addrs = [ + IpAddress { + address: Some(Address::IpSubnet("192.168.110.0/24".to_string())), + }, + IpAddress { + address: Some(Address::Ip("192.168.120.1".to_string())), + }, + ]; + let expected_ports = [ + Port { + port: Some(PortInner::SinglePort(80)), + }, + Port { + port: Some(PortInner::PortRange( + defguard_proto::enterprise::firewall::PortRange { + start: 1000, + end: 1005, + }, + )), + }, + ]; + + assert_eq!(generated_firewall_rules.len(), 2); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.source_addrs, expected_source_addrs); + assert_eq!(allow_rule.destination_addrs, expected_destination_addrs); + assert_eq!(allow_rule.destination_ports, expected_ports); + assert!(allow_rule.protocols.is_empty()); + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, expected_destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} diff --git a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs index 6a78e4b09e..fccbd40a4a 100644 --- a/crates/defguard_core/src/enterprise/firewall/tests/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/tests/mod.rs @@ -75,6 +75,87 @@ fn random_network_device_with_id(rng: &mut R, id: Id) -> Device { device } +fn expected_ipv4_source_range_for_user(user_id: Id) -> IpAddress { + let user_octet = user_id as u8; + IpAddress { + address: Some(Address::IpRange(IpRange { + start: format!("10.0.{user_octet}.1"), + end: format!("10.0.{user_octet}.2"), + })), + } +} + +async fn create_test_user_with_devices( + rng: &mut R, + pool: &PgPool, + test_locations: &[&WireguardNetwork], +) -> User { + let user: User = rng.r#gen(); + let user = user.save(pool).await.unwrap(); + + for device_number in 1u8..=2 { + let device = Device { + id: NoId, + name: format!("device-{}-{device_number}", user.id), + user_id: user.id, + device_type: DeviceType::User, + description: None, + wireguard_pubkey: String::default(), + created: NaiveDateTime::default(), + configured: true, + }; + let device = device.save(pool).await.unwrap(); + + for location in test_locations { + let wireguard_ips = location + .address() + .iter() + .map(|subnet| match subnet { + IpNetwork::V4(ipv4_network) => { + let octets = ipv4_network.network().octets(); + IpAddr::V4(Ipv4Addr::new( + octets[0], + octets[1], + user.id as u8, + device_number, + )) + } + IpNetwork::V6(ipv6_network) => { + let mut octets = ipv6_network.network().octets(); + octets[14] = user.id as u8; + octets[15] = device_number; + IpAddr::V6(Ipv6Addr::from(octets)) + } + }) + .collect(); + + WireguardNetworkDevice { + device_id: device.id, + wireguard_network_id: location.id, + wireguard_ips, + } + .insert(pool) + .await + .unwrap(); + } + } + + user +} + +async fn add_users_to_group(pool: &PgPool, group_id: Id, users: &[&User]) { + for user in users { + query!( + "INSERT INTO group_user (user_id, group_id) VALUES ($1, $2)", + user.id, + group_id + ) + .execute(pool) + .await + .unwrap(); + } +} + async fn create_test_users_and_devices( rng: &mut ThreadRng, pool: &PgPool, @@ -2105,6 +2186,570 @@ async fn test_no_allowed_users_ipv4(_: PgPoolOptions, options: PgConnectOptions) } } +#[sqlx::test] +async fn test_allow_all_groups_expands_all_group_members_into_firewall_sources( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + let grouped_allowed_user_a = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let grouped_allowed_user_b = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let grouped_denied_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let explicitly_allowed_ungrouped_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let ungrouped_blocked_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + + let first_group = Group { + name: "allow-all-groups-first".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + let second_group = Group { + name: "allow-all-groups-second".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + + add_users_to_group( + &pool, + first_group.id, + &[&grouped_allowed_user_a, &grouped_denied_user], + ) + .await; + add_users_to_group(&pool, second_group.id, &[&grouped_allowed_user_b]).await; + + let acl_rule = AclRule { + name: "allow all groups source expansion".into(), + state: RuleState::Applied, + allow_all_groups: true, + addresses: vec!["192.168.10.0/24".parse().unwrap()], + ports: vec![PortRange::new(443, 443).into()], + protocols: vec![Protocol::Tcp.into()], + any_address: false, + any_port: false, + any_protocol: false, + use_manual_destination_settings: true, + ..Default::default() + }; + + create_acl_rule( + &pool, + acl_rule, + vec![location.id], + vec![explicitly_allowed_ungrouped_user.id], + vec![grouped_denied_user.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + assert_eq!(generated_firewall_rules.len(), 2); + + let mut expected_allowed_user_ids = vec![ + grouped_allowed_user_a.id, + grouped_allowed_user_b.id, + explicitly_allowed_ungrouped_user.id, + ]; + expected_allowed_user_ids.sort_unstable(); + let expected_source_addrs: Vec<_> = expected_allowed_user_ids + .into_iter() + .map(expected_ipv4_source_range_for_user) + .collect(); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.source_addrs, expected_source_addrs); + assert_eq!( + allow_rule.destination_addrs, + [IpAddress { + address: Some(Address::IpSubnet("192.168.10.0/24".to_string())), + }] + ); + assert_eq!( + allow_rule.destination_ports, + [Port { + port: Some(PortInner::SinglePort(443)), + }] + ); + assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert!( + allow_rule + .source_addrs + .iter() + .all(|addr| addr != &expected_ipv4_source_range_for_user(grouped_denied_user.id)) + ); + assert!( + allow_rule + .source_addrs + .iter() + .all(|addr| addr != &expected_ipv4_source_range_for_user(ungrouped_blocked_user.id)) + ); + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + +#[sqlx::test] +async fn test_allow_all_groups_deduplicates_shared_group_members_before_source_resolution( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + let shared_grouped_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let first_group_only_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let second_group_only_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let explicitly_allowed_ungrouped_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let ungrouped_blocked_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + + let first_group = Group { + name: "allow-all-groups-dedup-first".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + let second_group = Group { + name: "allow-all-groups-dedup-second".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + + add_users_to_group( + &pool, + first_group.id, + &[&shared_grouped_user, &first_group_only_user], + ) + .await; + add_users_to_group( + &pool, + second_group.id, + &[&shared_grouped_user, &second_group_only_user], + ) + .await; + + let acl_rule = AclRule { + name: "allow all groups dedup source expansion".into(), + state: RuleState::Applied, + allow_all_groups: true, + addresses: vec!["192.168.30.0/24".parse().unwrap()], + ports: vec![PortRange::new(443, 443).into()], + protocols: vec![Protocol::Tcp.into()], + any_address: false, + any_port: false, + any_protocol: false, + use_manual_destination_settings: true, + ..Default::default() + }; + + let acl_rule_info = create_acl_rule( + &pool, + acl_rule, + vec![location.id], + vec![explicitly_allowed_ungrouped_user.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + + let allowed_users = acl_rule_info + .get_all_allowed_users(&mut conn) + .await + .unwrap(); + let denied_users = acl_rule_info.get_all_denied_users(&mut conn).await.unwrap(); + let mut resolved_source_user_ids: Vec<_> = + super::get_source_users(allowed_users, &denied_users) + .into_iter() + .map(|user| user.id) + .collect(); + + let mut expected_source_user_ids = vec![ + shared_grouped_user.id, + first_group_only_user.id, + second_group_only_user.id, + explicitly_allowed_ungrouped_user.id, + ]; + expected_source_user_ids.sort_unstable(); + resolved_source_user_ids.sort_unstable(); + + assert_eq!(resolved_source_user_ids, expected_source_user_ids); + assert_eq!(resolved_source_user_ids.len(), 4); + + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + assert_eq!(generated_firewall_rules.len(), 2); + + let expected_source_addrs: Vec<_> = expected_source_user_ids + .into_iter() + .map(expected_ipv4_source_range_for_user) + .collect(); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!(allow_rule.source_addrs, expected_source_addrs); + assert_eq!( + allow_rule.destination_addrs, + [IpAddress { + address: Some(Address::IpSubnet("192.168.30.0/24".to_string())), + }] + ); + assert_eq!( + allow_rule.destination_ports, + [Port { + port: Some(PortInner::SinglePort(443)), + }] + ); + assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + assert!( + allow_rule + .source_addrs + .iter() + .all(|addr| addr != &expected_ipv4_source_range_for_user(ungrouped_blocked_user.id)) + ); + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + +#[sqlx::test] +async fn test_deny_all_groups_excludes_members_of_every_group_from_firewall_sources( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + let grouped_denied_user_a = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let grouped_denied_user_b = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let grouped_denied_user_c = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let ungrouped_allowed_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let explicitly_denied_ungrouped_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + + let first_group = Group { + name: "deny-all-groups-first".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + let second_group = Group { + name: "deny-all-groups-second".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + + add_users_to_group( + &pool, + first_group.id, + &[&grouped_denied_user_a, &grouped_denied_user_c], + ) + .await; + add_users_to_group(&pool, second_group.id, &[&grouped_denied_user_b]).await; + + let acl_rule = AclRule { + name: "deny all groups source filtering".into(), + state: RuleState::Applied, + allow_all_users: true, + deny_all_groups: true, + addresses: vec!["192.168.20.0/24".parse().unwrap()], + ports: vec![PortRange::new(8443, 8443).into()], + protocols: vec![Protocol::Tcp.into()], + any_address: false, + any_port: false, + any_protocol: false, + use_manual_destination_settings: true, + ..Default::default() + }; + + create_acl_rule( + &pool, + acl_rule, + vec![location.id], + Vec::new(), + vec![explicitly_denied_ungrouped_user.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + assert_eq!(generated_firewall_rules.len(), 2); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!( + allow_rule.source_addrs, + [expected_ipv4_source_range_for_user( + ungrouped_allowed_user.id + )] + ); + assert_eq!( + allow_rule.destination_addrs, + [IpAddress { + address: Some(Address::IpSubnet("192.168.20.0/24".to_string())), + }] + ); + assert_eq!( + allow_rule.destination_ports, + [Port { + port: Some(PortInner::SinglePort(8443)), + }] + ); + assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + for denied_user in [ + grouped_denied_user_a.id, + grouped_denied_user_b.id, + grouped_denied_user_c.id, + explicitly_denied_ungrouped_user.id, + ] { + assert!( + allow_rule + .source_addrs + .iter() + .all(|addr| addr != &expected_ipv4_source_range_for_user(denied_user)) + ); + } + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + +#[sqlx::test] +async fn test_deny_all_groups_deduplicates_shared_group_members_before_source_filtering( + _: PgPoolOptions, + options: PgConnectOptions, +) { + set_test_license_business(); + let pool = setup_pool(options).await; + let mut rng = thread_rng(); + + let mut location = WireguardNetwork::default() + .set_address(["10.0.0.1/16".parse().unwrap()]) + .unwrap(); + location.acl_enabled = true; + let location = location.save(&pool).await.unwrap(); + + let shared_grouped_denied_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let first_group_only_denied_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let second_group_only_denied_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let ungrouped_allowed_user = create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + let explicitly_denied_ungrouped_user = + create_test_user_with_devices(&mut rng, &pool, &[&location]).await; + + let first_group = Group { + name: "deny-all-groups-dedup-first".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + let second_group = Group { + name: "deny-all-groups-dedup-second".into(), + ..Default::default() + } + .save(&pool) + .await + .unwrap(); + + add_users_to_group( + &pool, + first_group.id, + &[&shared_grouped_denied_user, &first_group_only_denied_user], + ) + .await; + add_users_to_group( + &pool, + second_group.id, + &[&shared_grouped_denied_user, &second_group_only_denied_user], + ) + .await; + + let acl_rule = AclRule { + name: "deny all groups dedup source filtering".into(), + state: RuleState::Applied, + allow_all_users: true, + deny_all_groups: true, + addresses: vec!["192.168.40.0/24".parse().unwrap()], + ports: vec![PortRange::new(8443, 8443).into()], + protocols: vec![Protocol::Tcp.into()], + any_address: false, + any_port: false, + any_protocol: false, + use_manual_destination_settings: true, + ..Default::default() + }; + + let acl_rule_info = create_acl_rule( + &pool, + acl_rule, + vec![location.id], + Vec::new(), + vec![explicitly_denied_ungrouped_user.id], + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + ) + .await; + + let mut conn = pool.acquire().await.unwrap(); + + let allowed_users = acl_rule_info + .get_all_allowed_users(&mut conn) + .await + .unwrap(); + let denied_users = acl_rule_info.get_all_denied_users(&mut conn).await.unwrap(); + + let mut denied_user_ids: Vec<_> = denied_users.iter().map(|user| user.id).collect(); + let mut expected_denied_user_ids = vec![ + shared_grouped_denied_user.id, + first_group_only_denied_user.id, + second_group_only_denied_user.id, + explicitly_denied_ungrouped_user.id, + ]; + denied_user_ids.sort_unstable(); + expected_denied_user_ids.sort_unstable(); + + assert_eq!(denied_user_ids, expected_denied_user_ids); + assert_eq!(denied_user_ids.len(), 4); + + let mut resolved_source_user_ids: Vec<_> = + super::get_source_users(allowed_users, &denied_users) + .into_iter() + .map(|user| user.id) + .collect(); + resolved_source_user_ids.sort_unstable(); + assert_eq!(resolved_source_user_ids, [ungrouped_allowed_user.id]); + + let generated_firewall_rules = try_get_location_firewall_config(&location, &mut conn) + .await + .unwrap() + .unwrap() + .rules; + + assert_eq!(generated_firewall_rules.len(), 2); + + let allow_rule = &generated_firewall_rules[0]; + assert_eq!(allow_rule.verdict, i32::from(FirewallPolicy::Allow)); + assert_eq!( + allow_rule.source_addrs, + [expected_ipv4_source_range_for_user( + ungrouped_allowed_user.id + )] + ); + assert_eq!( + allow_rule.destination_addrs, + [IpAddress { + address: Some(Address::IpSubnet("192.168.40.0/24".to_string())), + }] + ); + assert_eq!( + allow_rule.destination_ports, + [Port { + port: Some(PortInner::SinglePort(8443)), + }] + ); + assert_eq!(allow_rule.protocols, [i32::from(Protocol::Tcp)]); + for denied_user in expected_denied_user_ids { + assert!( + allow_rule + .source_addrs + .iter() + .all(|addr| addr != &expected_ipv4_source_range_for_user(denied_user)) + ); + } + + let deny_rule = &generated_firewall_rules[1]; + assert_eq!(deny_rule.verdict, i32::from(FirewallPolicy::Deny)); + assert!(deny_rule.source_addrs.is_empty()); + assert_eq!(deny_rule.destination_addrs, allow_rule.destination_addrs); + assert!(deny_rule.destination_ports.is_empty()); + assert!(deny_rule.protocols.is_empty()); +} + #[sqlx::test] async fn test_empty_manual_destination_only_acl(_: PgPoolOptions, options: PgConnectOptions) { set_test_license_business();