From d9a8d21637a141e06bae1f3ad9df6dd6dd0ebb80 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Mon, 4 Aug 2025 10:51:15 +0200 Subject: [PATCH 1/5] restore merging of ips --- src/enterprise/firewall/iprange.rs | 17 + src/enterprise/firewall/mod.rs | 95 ++++ src/enterprise/firewall/nftables/mod.rs | 499 +++++++++++++++++- src/enterprise/firewall/nftables/netfilter.rs | 79 +-- 4 files changed, 595 insertions(+), 95 deletions(-) diff --git a/src/enterprise/firewall/iprange.rs b/src/enterprise/firewall/iprange.rs index 33eab39d..b1b4acd6 100644 --- a/src/enterprise/firewall/iprange.rs +++ b/src/enterprise/firewall/iprange.rs @@ -76,6 +76,23 @@ impl IpAddrRange { Self::V6(_) => true, } } + + /// Returns the start of the range. + pub fn start(&self) -> IpAddr { + match self { + Self::V4(range) => IpAddr::V4(*range.start()), + Self::V6(range) => IpAddr::V6(*range.start()), + } + } + + /// Returns the end of the range. + /// If the range is empty, returns the start of the range. + pub fn end(&self) -> IpAddr { + match self { + Self::V4(range) => IpAddr::V4(*range.end()), + Self::V6(range) => IpAddr::V6(*range.end()), + } + } } impl Iterator for IpAddrRange { diff --git a/src/enterprise/firewall/mod.rs b/src/enterprise/firewall/mod.rs index 5f564697..21660de5 100644 --- a/src/enterprise/firewall/mod.rs +++ b/src/enterprise/firewall/mod.rs @@ -26,6 +26,20 @@ pub(crate) enum Address { } impl Address { + pub fn first(&self) -> IpAddr { + match self { + Address::Network(network) => network.ip(), + Address::Range(range) => range.start(), + } + } + + pub fn last(&self) -> IpAddr { + match self { + Address::Network(network) => max_address(network), + Address::Range(range) => range.end(), + } + } + pub fn from_proto(ip: &proto::enterprise::firewall::IpAddress) -> Result { match &ip.address { Some(proto::enterprise::firewall::ip_address::Address::Ip(ip)) => { @@ -322,3 +336,84 @@ pub enum FirewallError { #[error("Firewall transaction failed: {0}")] TransactionFailed(String), } + +/// Get the max address in a network. +/// +/// - In IPv4 this is the broadcast address. +/// - In IPv6 this is just the last address in the network. +pub fn max_address(network: &IpNetwork) -> IpAddr { + match network { + IpNetwork::V4(network) => { + let addr = network.ip().to_bits(); + let mask = network.mask().to_bits(); + + IpAddr::V4(Ipv4Addr::from(addr | !mask)) + } + IpNetwork::V6(network) => { + let addr = network.ip().to_bits(); + let mask = network.mask().to_bits(); + + IpAddr::V6(Ipv6Addr::from(addr | !mask)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_max_address_ipv4_24() { + let network = IpNetwork::V4(Ipv4Network::from_str("192.168.1.0/24").unwrap()); + let max = max_address(&network); + assert_eq!(max, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 255))); + } + + #[test] + fn test_max_address_ipv4_16() { + let network = IpNetwork::V4(Ipv4Network::from_str("10.1.0.0/16").unwrap()); + let max = max_address(&network); + assert_eq!(max, IpAddr::V4(Ipv4Addr::new(10, 1, 255, 255))); + } + + #[test] + fn test_max_address_ipv4_8() { + let network = IpNetwork::V4(Ipv4Network::from_str("172.16.0.0/8").unwrap()); + let max = max_address(&network); + assert_eq!(max, IpAddr::V4(Ipv4Addr::new(172, 255, 255, 255))); + } + + #[test] + fn test_max_address_ipv4_32() { + let network = IpNetwork::V4(Ipv4Network::from_str("192.168.1.1/32").unwrap()); + let max = max_address(&network); + assert_eq!(max, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))); + } + + #[test] + fn test_max_address_ipv6_64() { + let network = IpNetwork::V6(Ipv6Network::from_str("2001:db8::/64").unwrap()); + let max = max_address(&network); + assert_eq!( + max, + IpAddr::V6(Ipv6Addr::from_str("2001:db8::ffff:ffff:ffff:ffff").unwrap()) + ); + } + + #[test] + fn test_max_address_ipv6_128() { + let network = IpNetwork::V6(Ipv6Network::from_str("2001:db8::1/128").unwrap()); + let max = max_address(&network); + assert_eq!(max, IpAddr::V6(Ipv6Addr::from_str("2001:db8::1").unwrap())); + } + + #[test] + fn test_max_address_ipv6_48() { + let network = IpNetwork::V6(Ipv6Network::from_str("2001:db8:1234::/48").unwrap()); + let max = max_address(&network); + assert_eq!( + max, + IpAddr::V6(Ipv6Addr::from_str("2001:db8:1234:ffff:ffff:ffff:ffff:ffff").unwrap()) + ); + } +} diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index d5e5574b..fd529231 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -1,6 +1,9 @@ pub mod netfilter; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + sync::atomic::{AtomicU32, Ordering}, +}; use netfilter::{ allow_established_traffic, apply_filter_rules, drop_table, ignore_unrelated_traffic, @@ -12,6 +15,7 @@ use super::{ api::{FirewallApi, FirewallManagementApi}, Address, FirewallError, FirewallRule, Policy, Port, Protocol, }; +use crate::enterprise::firewall::iprange::IpAddrRange; static SET_ID_COUNTER: AtomicU32 = AtomicU32::new(0); @@ -51,29 +55,137 @@ struct FilterRule<'a> { negated_iifname: bool, } +fn merge_addrs<'a>(addrs: &'a [Address]) -> Vec
{ + debug!("merge_addrs called with input: {:?}", addrs); + + if addrs.is_empty() { + debug!("No addresses provided, returning empty vector."); + return Vec::new(); + } + + let mut merged_addrs = Vec::new(); + + // sort them by their .first() address + let mut addrs_sorted = Vec::from_iter(addrs.iter()); + addrs_sorted.sort_by(|a, b| { + a.first() + .partial_cmp(&b.first()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + debug!("Sorted addresses: {:?}", addrs_sorted); + + let first_addr = addrs_sorted.remove(0); + let mut current_range_start = first_addr.first(); + let mut current_range_end = first_addr.last(); + + debug!( + "Starting merge loop with initial range: {:?} - {:?}", + current_range_start, current_range_end + ); + + for addr in addrs_sorted { + let first_addr = addr.first(); + let last_addr = addr.last(); + + debug!( + "Checking addr: {:?} - {:?} against current_range_end: {:?}", + first_addr, last_addr, current_range_end + ); + + // Check if ranges overlap or are adjacent + if first_addr <= current_range_end || next_ip(current_range_end) == first_addr { + // Ranges overlap or are adjacent, merge them + if last_addr > current_range_end { + debug!( + "Extending current_range_end from {:?} to {:?}", + current_range_end, last_addr + ); + current_range_end = last_addr; + } + } else { + // Ranges don't overlap and aren't adjacent, push current range and start new one + debug!( + "Pushing merged range: {:?} - {:?}", + current_range_start, current_range_end + ); + merged_addrs.push(Address::Range( + IpAddrRange::new(current_range_start, current_range_end).unwrap(), + )); + current_range_start = first_addr; + current_range_end = last_addr; + } + } + + // Push the last range + debug!( + "Pushing final merged range: {:?} - {:?}", + current_range_start, current_range_end + ); + merged_addrs.push(Address::Range( + IpAddrRange::new(current_range_start, current_range_end).unwrap(), + )); + + debug!("Prepared addresses: {:?}", merged_addrs); + + merged_addrs +} + +/// Returns the next IP address in sequence, handling overflow via wrapping +fn next_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V4(ipv4) => { + let octets = ipv4.octets(); + let mut num: u32 = ((octets[0] as u32) << 24) + | ((octets[1] as u32) << 16) + | ((octets[2] as u32) << 8) + | octets[3] as u32; + num = num.wrapping_add(1); + IpAddr::V4(Ipv4Addr::from(num)) + } + IpAddr::V6(ipv6) => { + let segments = ipv6.segments(); + let mut num: u128 = ((segments[0] as u128) << 112) + | ((segments[1] as u128) << 96) + | ((segments[2] as u128) << 80) + | ((segments[3] as u128) << 64) + | ((segments[4] as u128) << 48) + | ((segments[5] as u128) << 32) + | ((segments[6] as u128) << 16) + | segments[7] as u128; + num = num.wrapping_add(1); + IpAddr::V6(Ipv6Addr::from(num)) + } + } +} + impl FirewallApi { fn add_rule(&mut self, rule: FirewallRule) -> Result<(), FirewallError> { debug!("Applying the following Defguard ACL rule: {rule:?}"); - let mut rules = Vec::new(); let batch = if let Some(ref mut batch) = self.batch { batch } else { return Err(FirewallError::TransactionNotStarted); }; + let mut filter_rules = Vec::new(); debug!( "The rule will be split into multiple nftables rules based on the specified \ destination ports and protocols." ); + + let source_addrs = merge_addrs(&rule.source_addrs); + let dest_addrs = merge_addrs(&rule.destination_addrs); + if rule.destination_ports.is_empty() { debug!( "No destination ports specified, applying single aggregate nftables rule for \ every protocol." ); let rule = FilterRule { - src_ips: &rule.source_addrs, - dest_ips: &rule.destination_addrs, - protocols: rule.protocols, + src_ips: &source_addrs, + dest_ips: &dest_addrs, + protocols: rule.protocols.clone(), action: rule.verdict, counter: true, defguard_rule_id: rule.id, @@ -81,19 +193,19 @@ impl FirewallApi { comment: rule.comment.clone(), ..Default::default() }; - rules.push(rule); + filter_rules.push(rule); } else if !rule.protocols.is_empty() { debug!( "Destination ports and protocols specified, applying individual nftables rules \ for each protocol." ); - for protocol in rule.protocols { + for protocol in rule.protocols.clone() { debug!("Applying rule for protocol: {protocol:?}"); if protocol.supports_ports() { debug!("Protocol supports ports, rule."); let rule = FilterRule { - src_ips: &rule.source_addrs, - dest_ips: &rule.destination_addrs, + src_ips: &source_addrs, + dest_ips: &dest_addrs, dest_ports: &rule.destination_ports, protocols: vec![protocol], action: rule.verdict, @@ -103,15 +215,15 @@ impl FirewallApi { comment: rule.comment.clone(), ..Default::default() }; - rules.push(rule); + filter_rules.push(rule); } else { debug!( "Protocol does not support ports, applying nftables rule and ignoring \ destination ports." ); let rule = FilterRule { - src_ips: &rule.source_addrs, - dest_ips: &rule.destination_addrs, + src_ips: &source_addrs, + dest_ips: &dest_addrs, protocols: vec![protocol], action: rule.verdict, counter: true, @@ -120,7 +232,7 @@ impl FirewallApi { comment: rule.comment.clone(), ..Default::default() }; - rules.push(rule); + filter_rules.push(rule); } } } else { @@ -131,8 +243,8 @@ impl FirewallApi { for protocol in [Protocol::Tcp, Protocol::Udp] { debug!("Applying nftables rule for protocol: {protocol:?}"); let rule = FilterRule { - src_ips: &rule.source_addrs, - dest_ips: &rule.destination_addrs, + src_ips: &source_addrs, + dest_ips: &dest_addrs, dest_ports: &rule.destination_ports, protocols: vec![protocol], action: rule.verdict, @@ -142,11 +254,11 @@ impl FirewallApi { comment: rule.comment.clone(), ..Default::default() }; - rules.push(rule); + filter_rules.push(rule); } } - apply_filter_rules(rules, batch, &self.ifname)?; + apply_filter_rules(filter_rules, batch, &self.ifname)?; debug!( "Applied firewall rules for Defguard ACL rule ID: {}", @@ -257,3 +369,356 @@ impl FirewallManagementApi for FirewallApi { } } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use ipnetwork::IpNetwork; + + use super::*; + use crate::proto::enterprise::firewall::FirewallRule as ProtoFirewallRule; + + #[test] + fn test_sorting() { + let mut addrs = vec![ + Address::Network(IpNetwork::from_str("10.10.10.11/24").unwrap()), + Address::Network(IpNetwork::from_str("10.10.10.12/24").unwrap()), + Address::Network(IpNetwork::from_str("10.10.11.10/32").unwrap()), + Address::Network(IpNetwork::from_str("10.10.11.11/32").unwrap()), + Address::Network(IpNetwork::from_str("10.10.10.10/24").unwrap()), + Address::Network(IpNetwork::from_str("10.10.11.12/32").unwrap()), + ]; + + addrs.sort_by(|a, b| { + a.first() + .partial_cmp(&b.first()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + assert_eq!( + addrs, + vec![ + Address::Network(IpNetwork::from_str("10.10.10.10/24").unwrap()), + Address::Network(IpNetwork::from_str("10.10.10.11/24").unwrap()), + Address::Network(IpNetwork::from_str("10.10.10.12/24").unwrap()), + Address::Network(IpNetwork::from_str("10.10.11.10/32").unwrap()), + Address::Network(IpNetwork::from_str("10.10.11.11/32").unwrap()), + Address::Network(IpNetwork::from_str("10.10.11.12/32").unwrap()), + ] + ); + + let _prepared_addrs = merge_addrs(&addrs); + } + + #[test] + fn test_merge_addrs_empty() { + let addrs: Vec
= vec![]; + let result = merge_addrs(&addrs); + assert!(result.is_empty()); + } + + #[test] + fn test_merge_addrs_single_address() { + let addrs = vec![Address::Network( + IpNetwork::from_str("192.168.1.10/32").unwrap(), + )]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.10").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_overlapping_ranges() { + let addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.20").unwrap(), + ) + .unwrap(), + ), + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.15").unwrap(), + IpAddr::from_str("192.168.1.25").unwrap(), + ) + .unwrap(), + ), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.25").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_adjacent_ranges() { + let addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.20").unwrap(), + ) + .unwrap(), + ), + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.21").unwrap(), + IpAddr::from_str("192.168.1.30").unwrap(), + ) + .unwrap(), + ), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.30").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_adjacent_single_addresses() { + let addrs = vec![ + Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.12").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_non_adjacent_ranges() { + let addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.20").unwrap(), + ) + .unwrap(), + ), + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.30").unwrap(), + IpAddr::from_str("192.168.1.40").unwrap(), + ) + .unwrap(), + ), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 2); + if let Address::Range(range1) = &result[0] { + assert_eq!(range1.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range1.end(), IpAddr::from_str("192.168.1.20").unwrap()); + } else { + panic!("Expected Address::Range"); + } + if let Address::Range(range2) = &result[1] { + assert_eq!(range2.start(), IpAddr::from_str("192.168.1.30").unwrap()); + assert_eq!(range2.end(), IpAddr::from_str("192.168.1.40").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_mixed_networks_and_ranges() { + let addrs = vec![ + Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.11").unwrap(), + IpAddr::from_str("192.168.1.15").unwrap(), + ) + .unwrap(), + ), + Address::Network(IpNetwork::from_str("192.168.1.16/32").unwrap()), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.16").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_unsorted_input() { + let addrs = vec![ + Address::Network(IpNetwork::from_str("192.168.1.13/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.13").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_non_adjacent_singles() { + let addrs = vec![ + Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), + ]; + let result = merge_addrs(&addrs); + + // These should result in 3 separate ranges: 10-11, 15, 20 + assert_eq!(result.len(), 3); + + if let Address::Range(range1) = &result[0] { + assert_eq!(range1.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range1.end(), IpAddr::from_str("192.168.1.11").unwrap()); + } else { + panic!("Expected Address::Range"); + } + + if let Address::Range(range2) = &result[1] { + assert_eq!(range2.start(), IpAddr::from_str("192.168.1.15").unwrap()); + assert_eq!(range2.end(), IpAddr::from_str("192.168.1.15").unwrap()); + } else { + panic!("Expected Address::Range"); + } + + if let Address::Range(range3) = &result[2] { + assert_eq!(range3.start(), IpAddr::from_str("192.168.1.20").unwrap()); + assert_eq!(range3.end(), IpAddr::from_str("192.168.1.20").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_ipv6() { + let addrs = vec![ + Address::Network(IpNetwork::from_str("2001:db8::1/128").unwrap()), + Address::Network(IpNetwork::from_str("2001:db8::2/128").unwrap()), + Address::Network(IpNetwork::from_str("2001:db8::3/128").unwrap()), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("2001:db8::1").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("2001:db8::3").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_merge_addrs_contained_ranges() { + let addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.30").unwrap(), + ) + .unwrap(), + ), + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.15").unwrap(), + IpAddr::from_str("192.168.1.20").unwrap(), + ) + .unwrap(), + ), + ]; + let result = merge_addrs(&addrs); + + assert_eq!(result.len(), 1); + if let Address::Range(range) = &result[0] { + assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); + assert_eq!(range.end(), IpAddr::from_str("192.168.1.30").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } + + #[test] + fn test_next_ip_ipv4() { + assert_eq!( + next_ip(IpAddr::from_str("192.168.1.10").unwrap()), + IpAddr::from_str("192.168.1.11").unwrap() + ); + + // Test overflow + assert_eq!( + next_ip(IpAddr::from_str("255.255.255.255").unwrap()), + IpAddr::from_str("0.0.0.0").unwrap() + ); + } + + #[test] + fn test_next_ip_ipv6() { + assert_eq!( + next_ip(IpAddr::from_str("2001:db8::1").unwrap()), + IpAddr::from_str("2001:db8::2").unwrap() + ); + + // Test overflow + assert_eq!( + next_ip(IpAddr::from_str("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff").unwrap()), + IpAddr::from_str("::").unwrap() + ); + } + + #[test] + fn test_merge_addrs_large_gap() { + let addrs = vec![ + Address::Network(IpNetwork::from_str("192.168.1.1/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()), + ]; + let result = merge_addrs(&addrs); + + // Should not merge since there's a large gap + assert_eq!(result.len(), 2); + if let Address::Range(range1) = &result[0] { + assert_eq!(range1.start(), IpAddr::from_str("192.168.1.1").unwrap()); + assert_eq!(range1.end(), IpAddr::from_str("192.168.1.1").unwrap()); + } else { + panic!("Expected Address::Range"); + } + if let Address::Range(range2) = &result[1] { + assert_eq!(range2.start(), IpAddr::from_str("192.168.1.100").unwrap()); + assert_eq!(range2.end(), IpAddr::from_str("192.168.1.100").unwrap()); + } else { + panic!("Expected Address::Range"); + } + } +} diff --git a/src/enterprise/firewall/nftables/netfilter.rs b/src/enterprise/firewall/nftables/netfilter.rs index 7174d9b0..5be57aeb 100644 --- a/src/enterprise/firewall/nftables/netfilter.rs +++ b/src/enterprise/firewall/nftables/netfilter.rs @@ -5,7 +5,6 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; -use ipnetwork::IpNetwork; #[cfg(test)] use ipnetwork::{Ipv4Network, Ipv6Network}; use nftnl::{ @@ -16,7 +15,7 @@ use nftnl::{ }; use super::{get_set_id, Address, FilterRule, Policy, Port, Protocol, State}; -use crate::enterprise::firewall::{iprange::IpAddrRange, FirewallError}; +use crate::enterprise::firewall::{iprange::IpAddrRange, max_address, FirewallError}; const FILTER_TABLE: &str = "filter"; const NAT_TABLE: &str = "nat"; @@ -809,27 +808,6 @@ fn socket_recv<'a>( } } -/// Get the max address in a network. -/// -/// - In IPv4 this is the broadcast address. -/// - In IPv6 this is just the last address in the network. -fn max_address(network: &IpNetwork) -> IpAddr { - match network { - IpNetwork::V4(network) => { - let addr = network.ip().to_bits(); - let mask = network.mask().to_bits(); - - IpAddr::V4(Ipv4Addr::from(addr | !mask)) - } - IpNetwork::V6(network) => { - let addr = network.ip().to_bits(); - let mask = network.mask().to_bits(); - - IpAddr::V6(Ipv6Addr::from(addr | !mask)) - } - } -} - fn new_anon_set( table: &Table, family: ProtoFamily, @@ -991,59 +969,4 @@ mod tests { increment_bytes(&mut ip); assert_eq!(ip, [0, 0, 0, 0, 0, 0, 0, 0]); } - - #[test] - fn test_max_address_ipv4_24() { - let network = IpNetwork::V4(Ipv4Network::from_str("192.168.1.0/24").unwrap()); - let max = max_address(&network); - assert_eq!(max, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 255))); - } - - #[test] - fn test_max_address_ipv4_16() { - let network = IpNetwork::V4(Ipv4Network::from_str("10.1.0.0/16").unwrap()); - let max = max_address(&network); - assert_eq!(max, IpAddr::V4(Ipv4Addr::new(10, 1, 255, 255))); - } - - #[test] - fn test_max_address_ipv4_8() { - let network = IpNetwork::V4(Ipv4Network::from_str("172.16.0.0/8").unwrap()); - let max = max_address(&network); - assert_eq!(max, IpAddr::V4(Ipv4Addr::new(172, 255, 255, 255))); - } - - #[test] - fn test_max_address_ipv4_32() { - let network = IpNetwork::V4(Ipv4Network::from_str("192.168.1.1/32").unwrap()); - let max = max_address(&network); - assert_eq!(max, IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))); - } - - #[test] - fn test_max_address_ipv6_64() { - let network = IpNetwork::V6(Ipv6Network::from_str("2001:db8::/64").unwrap()); - let max = max_address(&network); - assert_eq!( - max, - IpAddr::V6(Ipv6Addr::from_str("2001:db8::ffff:ffff:ffff:ffff").unwrap()) - ); - } - - #[test] - fn test_max_address_ipv6_128() { - let network = IpNetwork::V6(Ipv6Network::from_str("2001:db8::1/128").unwrap()); - let max = max_address(&network); - assert_eq!(max, IpAddr::V6(Ipv6Addr::from_str("2001:db8::1").unwrap())); - } - - #[test] - fn test_max_address_ipv6_48() { - let network = IpNetwork::V6(Ipv6Network::from_str("2001:db8:1234::/48").unwrap()); - let max = max_address(&network); - assert_eq!( - max, - IpAddr::V6(Ipv6Addr::from_str("2001:db8:1234:ffff:ffff:ffff:ffff:ffff").unwrap()) - ); - } } From 605041d0d2f188763ad4c3791e6eaf7e442e5855 Mon Sep 17 00:00:00 2001 From: Maciek <19913370+wojcik91@users.noreply.github.com> Date: Tue, 5 Aug 2025 16:15:22 +0200 Subject: [PATCH 2/5] merge adjacent subnets for nft (#185) * merge adjacent elements * update new tests --- src/enterprise/firewall/nftables/mod.rs | 292 ++++++------------ src/enterprise/firewall/nftables/netfilter.rs | 4 - 2 files changed, 95 insertions(+), 201 deletions(-) diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index fd529231..ec3b1976 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -13,6 +13,7 @@ use nftnl::Batch; use super::{ api::{FirewallApi, FirewallManagementApi}, + iprange::IpAddrRangeError, Address, FirewallError, FirewallRule, Policy, Port, Protocol, }; use crate::enterprise::firewall::iprange::IpAddrRange; @@ -55,80 +56,68 @@ struct FilterRule<'a> { negated_iifname: bool, } -fn merge_addrs<'a>(addrs: &'a [Address]) -> Vec
{ - debug!("merge_addrs called with input: {:?}", addrs); +/// Merges any contiguous subets or addres ranges into an address range. +/// +/// This reflects the way `nft` CLI handles such cases. +/// Otherwise first address in any subnet after the first is not matched. +/// For example if we use `172.30.0.2/31, 172.30.0.4/31` as `saddr` in a rule, +/// then 172.30.0.4 will not be matched. +fn merge_addrs(addrs: Vec
) -> Result, IpAddrRangeError> { + debug!( + "Merging any contiguous subnets found within address list: {:?}", + addrs + ); if addrs.is_empty() { debug!("No addresses provided, returning empty vector."); - return Vec::new(); + return Ok(Vec::new()); } let mut merged_addrs = Vec::new(); - - // sort them by their .first() address - let mut addrs_sorted = Vec::from_iter(addrs.iter()); - addrs_sorted.sort_by(|a, b| { - a.first() - .partial_cmp(&b.first()) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - debug!("Sorted addresses: {:?}", addrs_sorted); - - let first_addr = addrs_sorted.remove(0); - let mut current_range_start = first_addr.first(); - let mut current_range_end = first_addr.last(); - - debug!( - "Starting merge loop with initial range: {:?} - {:?}", - current_range_start, current_range_end - ); - - for addr in addrs_sorted { - let first_addr = addr.first(); - let last_addr = addr.last(); - - debug!( - "Checking addr: {:?} - {:?} against current_range_end: {:?}", - first_addr, last_addr, current_range_end - ); - - // Check if ranges overlap or are adjacent - if first_addr <= current_range_end || next_ip(current_range_end) == first_addr { - // Ranges overlap or are adjacent, merge them - if last_addr > current_range_end { - debug!( - "Extending current_range_end from {:?} to {:?}", - current_range_end, last_addr - ); - current_range_end = last_addr; + let mut current_address = None; + + // we can assume addresses coming from the core + // are already sorted and non-overlapping + for next_address in addrs { + match ¤t_address { + None => { + debug!("Initializing current address with: {next_address:?}"); + current_address = Some(next_address); + } + Some(previous_address) => { + let previous_range_start = previous_address.first(); + let previous_range_end = previous_address.last(); + let next_ip = next_ip(previous_range_end); + + let next_range_start = next_address.first(); + let next_range_end = next_address.last(); + + // check if range is adjacent to current address + if next_range_start == next_ip { + // replace current address with a combined range + debug!("Merging {next_address:?} with {current_address:?}"); + current_address = Some(Address::Range(IpAddrRange::new( + previous_range_start, + next_range_end, + )?)); + } else { + // push previous address to result and replace with next address + merged_addrs.push(previous_address.clone()); + current_address = Some(next_address); + }; } - } else { - // Ranges don't overlap and aren't adjacent, push current range and start new one - debug!( - "Pushing merged range: {:?} - {:?}", - current_range_start, current_range_end - ); - merged_addrs.push(Address::Range( - IpAddrRange::new(current_range_start, current_range_end).unwrap(), - )); - current_range_start = first_addr; - current_range_end = last_addr; } } - // Push the last range - debug!( - "Pushing final merged range: {:?} - {:?}", - current_range_start, current_range_end - ); - merged_addrs.push(Address::Range( - IpAddrRange::new(current_range_start, current_range_end).unwrap(), - )); + // push last remaining address to results + if let Some(address) = current_address { + debug!("Pushing last remaining address into results: {address:?}"); + merged_addrs.push(address) + } debug!("Prepared addresses: {:?}", merged_addrs); - merged_addrs + Ok(merged_addrs) } /// Returns the next IP address in sequence, handling overflow via wrapping @@ -174,8 +163,8 @@ impl FirewallApi { destination ports and protocols." ); - let source_addrs = merge_addrs(&rule.source_addrs); - let dest_addrs = merge_addrs(&rule.destination_addrs); + let source_addrs = merge_addrs(rule.source_addrs)?; + let dest_addrs = merge_addrs(rule.destination_addrs)?; if rule.destination_ports.is_empty() { debug!( @@ -377,7 +366,6 @@ mod tests { use ipnetwork::IpNetwork; use super::*; - use crate::proto::enterprise::firewall::FirewallRule as ProtoFirewallRule; #[test] fn test_sorting() { @@ -408,13 +396,13 @@ mod tests { ] ); - let _prepared_addrs = merge_addrs(&addrs); + let _prepared_addrs = merge_addrs(addrs).unwrap(); } #[test] fn test_merge_addrs_empty() { let addrs: Vec
= vec![]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert!(result.is_empty()); } @@ -423,44 +411,9 @@ mod tests { let addrs = vec![Address::Network( IpNetwork::from_str("192.168.1.10/32").unwrap(), )]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs.clone()).unwrap(); - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.10").unwrap()); - } else { - panic!("Expected Address::Range"); - } - } - - #[test] - fn test_merge_addrs_overlapping_ranges() { - let addrs = vec![ - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.10").unwrap(), - IpAddr::from_str("192.168.1.20").unwrap(), - ) - .unwrap(), - ), - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.15").unwrap(), - IpAddr::from_str("192.168.1.25").unwrap(), - ) - .unwrap(), - ), - ]; - let result = merge_addrs(&addrs); - - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.25").unwrap()); - } else { - panic!("Expected Address::Range"); - } + assert_eq!(result, addrs); } #[test] @@ -481,7 +434,7 @@ mod tests { .unwrap(), ), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -499,7 +452,7 @@ mod tests { Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -528,7 +481,7 @@ mod tests { .unwrap(), ), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 2); if let Address::Range(range1) = &result[0] { @@ -558,7 +511,7 @@ mod tests { ), Address::Network(IpNetwork::from_str("192.168.1.16/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -569,58 +522,29 @@ mod tests { } } - #[test] - fn test_merge_addrs_unsorted_input() { - let addrs = vec![ - Address::Network(IpNetwork::from_str("192.168.1.13/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), - ]; - let result = merge_addrs(&addrs); - - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.13").unwrap()); - } else { - panic!("Expected Address::Range"); - } - } - #[test] fn test_merge_addrs_non_adjacent_singles() { let addrs = vec![ - Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); // These should result in 3 separate ranges: 10-11, 15, 20 - assert_eq!(result.len(), 3); - - if let Address::Range(range1) = &result[0] { - assert_eq!(range1.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range1.end(), IpAddr::from_str("192.168.1.11").unwrap()); - } else { - panic!("Expected Address::Range"); - } - - if let Address::Range(range2) = &result[1] { - assert_eq!(range2.start(), IpAddr::from_str("192.168.1.15").unwrap()); - assert_eq!(range2.end(), IpAddr::from_str("192.168.1.15").unwrap()); - } else { - panic!("Expected Address::Range"); - } - - if let Address::Range(range3) = &result[2] { - assert_eq!(range3.start(), IpAddr::from_str("192.168.1.20").unwrap()); - assert_eq!(range3.end(), IpAddr::from_str("192.168.1.20").unwrap()); - } else { - panic!("Expected Address::Range"); - } + let expected_addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.11").unwrap(), + ) + .unwrap(), + ), + Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), + ]; + assert_eq!(result, expected_addrs); } #[test] @@ -630,7 +554,7 @@ mod tests { Address::Network(IpNetwork::from_str("2001:db8::2/128").unwrap()), Address::Network(IpNetwork::from_str("2001:db8::3/128").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -641,35 +565,6 @@ mod tests { } } - #[test] - fn test_merge_addrs_contained_ranges() { - let addrs = vec![ - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.10").unwrap(), - IpAddr::from_str("192.168.1.30").unwrap(), - ) - .unwrap(), - ), - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.15").unwrap(), - IpAddr::from_str("192.168.1.20").unwrap(), - ) - .unwrap(), - ), - ]; - let result = merge_addrs(&addrs); - - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.30").unwrap()); - } else { - panic!("Expected Address::Range"); - } - } - #[test] fn test_next_ip_ipv4() { assert_eq!( @@ -699,26 +594,29 @@ mod tests { } #[test] - fn test_merge_addrs_large_gap() { + fn test_merge_addrs_gap() { let addrs = vec![ Address::Network(IpNetwork::from_str("192.168.1.1/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs.clone()).unwrap(); - // Should not merge since there's a large gap - assert_eq!(result.len(), 2); - if let Address::Range(range1) = &result[0] { - assert_eq!(range1.start(), IpAddr::from_str("192.168.1.1").unwrap()); - assert_eq!(range1.end(), IpAddr::from_str("192.168.1.1").unwrap()); - } else { - panic!("Expected Address::Range"); - } - if let Address::Range(range2) = &result[1] { - assert_eq!(range2.start(), IpAddr::from_str("192.168.1.100").unwrap()); - assert_eq!(range2.end(), IpAddr::from_str("192.168.1.100").unwrap()); - } else { - panic!("Expected Address::Range"); - } + // Should not merge since there's a gap + assert_eq!(result, addrs); + + let addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.20").unwrap(), + ) + .unwrap(), + ), + Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()), + ]; + let result = merge_addrs(addrs.clone()).unwrap(); + + // Should not merge since there's a gap + assert_eq!(result, addrs); } } diff --git a/src/enterprise/firewall/nftables/netfilter.rs b/src/enterprise/firewall/nftables/netfilter.rs index 5be57aeb..b5d8505c 100644 --- a/src/enterprise/firewall/nftables/netfilter.rs +++ b/src/enterprise/firewall/nftables/netfilter.rs @@ -1,12 +1,8 @@ -#[cfg(test)] -use std::str::FromStr; use std::{ ffi::{CStr, CString}, net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; -#[cfg(test)] -use ipnetwork::{Ipv4Network, Ipv6Network}; use nftnl::{ expr::{Expression, InterfaceName}, nft_expr, nftnl_sys, From f20852c5e3bbafcdeaa02ecd9b49d61f783bbc89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 5 Aug 2025 16:28:07 +0200 Subject: [PATCH 3/5] linter fixes --- src/enterprise/firewall/mod.rs | 4 ++++ src/enterprise/firewall/nftables/mod.rs | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/enterprise/firewall/mod.rs b/src/enterprise/firewall/mod.rs index 21660de5..ef4304d2 100644 --- a/src/enterprise/firewall/mod.rs +++ b/src/enterprise/firewall/mod.rs @@ -26,6 +26,8 @@ pub(crate) enum Address { } impl Address { + // FIXME: remove after merging nft hotfix into dev + #[allow(dead_code)] pub fn first(&self) -> IpAddr { match self { Address::Network(network) => network.ip(), @@ -33,6 +35,8 @@ impl Address { } } + // FIXME: remove after merging nft hotfix into dev + #[allow(dead_code)] pub fn last(&self) -> IpAddr { match self { Address::Network(network) => max_address(network), diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index ec3b1976..ad4ce731 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -64,7 +64,7 @@ struct FilterRule<'a> { /// then 172.30.0.4 will not be matched. fn merge_addrs(addrs: Vec
) -> Result, IpAddrRangeError> { debug!( - "Merging any contiguous subnets found within address list: {:?}", + "Merging any contiguous subnets and ranges found within address list: {:?}", addrs ); @@ -115,7 +115,7 @@ fn merge_addrs(addrs: Vec
) -> Result, IpAddrRangeError> { merged_addrs.push(address) } - debug!("Prepared addresses: {:?}", merged_addrs); + debug!("Prepared addresses: {merged_addrs:?}"); Ok(merged_addrs) } From 25ea461d8ca86a1bc94a481a7c771ac30dbe81c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 5 Aug 2025 16:36:14 +0200 Subject: [PATCH 4/5] linter fix --- src/enterprise/firewall/nftables/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index ad4ce731..8259356b 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -63,10 +63,7 @@ struct FilterRule<'a> { /// For example if we use `172.30.0.2/31, 172.30.0.4/31` as `saddr` in a rule, /// then 172.30.0.4 will not be matched. fn merge_addrs(addrs: Vec
) -> Result, IpAddrRangeError> { - debug!( - "Merging any contiguous subnets and ranges found within address list: {:?}", - addrs - ); + debug!("Merging any contiguous subnets and ranges found within address list: {addrs:?}"); if addrs.is_empty() { debug!("No addresses provided, returning empty vector."); From 3af29b486a42a6e82e612616e919335897d92d97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 6 Aug 2025 09:22:33 +0200 Subject: [PATCH 5/5] review fixes --- src/enterprise/firewall/nftables/mod.rs | 62 ++++++++----------------- 1 file changed, 19 insertions(+), 43 deletions(-) diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index 8259356b..614b9831 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -121,26 +121,14 @@ fn merge_addrs(addrs: Vec
) -> Result, IpAddrRangeError> { fn next_ip(ip: IpAddr) -> IpAddr { match ip { IpAddr::V4(ipv4) => { - let octets = ipv4.octets(); - let mut num: u32 = ((octets[0] as u32) << 24) - | ((octets[1] as u32) << 16) - | ((octets[2] as u32) << 8) - | octets[3] as u32; - num = num.wrapping_add(1); - IpAddr::V4(Ipv4Addr::from(num)) + let ip_u32 = ipv4.to_bits(); + let next_ip_u32 = ip_u32.wrapping_add(1); + IpAddr::V4(Ipv4Addr::from(next_ip_u32)) } IpAddr::V6(ipv6) => { - let segments = ipv6.segments(); - let mut num: u128 = ((segments[0] as u128) << 112) - | ((segments[1] as u128) << 96) - | ((segments[2] as u128) << 80) - | ((segments[3] as u128) << 64) - | ((segments[4] as u128) << 48) - | ((segments[5] as u128) << 32) - | ((segments[6] as u128) << 16) - | segments[7] as u128; - num = num.wrapping_add(1); - IpAddr::V6(Ipv6Addr::from(num)) + let ip_u128 = ipv6.to_bits(); + let next_ip_u128 = ip_u128.wrapping_add(1); + IpAddr::V6(Ipv6Addr::from(next_ip_u128)) } } } @@ -187,39 +175,27 @@ impl FirewallApi { ); for protocol in rule.protocols.clone() { debug!("Applying rule for protocol: {protocol:?}"); + let mut filter_rule = FilterRule { + src_ips: &source_addrs, + dest_ips: &dest_addrs, + protocols: vec![protocol], + action: rule.verdict, + counter: true, + defguard_rule_id: rule.id, + v4: rule.ipv4, + comment: rule.comment.clone(), + ..Default::default() + }; if protocol.supports_ports() { debug!("Protocol supports ports, rule."); - let rule = FilterRule { - src_ips: &source_addrs, - dest_ips: &dest_addrs, - dest_ports: &rule.destination_ports, - protocols: vec![protocol], - action: rule.verdict, - counter: true, - defguard_rule_id: rule.id, - v4: rule.ipv4, - comment: rule.comment.clone(), - ..Default::default() - }; - filter_rules.push(rule); + filter_rule.dest_ports = &rule.destination_ports; } else { debug!( "Protocol does not support ports, applying nftables rule and ignoring \ destination ports." ); - let rule = FilterRule { - src_ips: &source_addrs, - dest_ips: &dest_addrs, - protocols: vec![protocol], - action: rule.verdict, - counter: true, - defguard_rule_id: rule.id, - v4: rule.ipv4, - comment: rule.comment.clone(), - ..Default::default() - }; - filter_rules.push(rule); } + filter_rules.push(filter_rule); } } else { debug!(