diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index fd529231..ec3b1976 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -13,6 +13,7 @@ use nftnl::Batch; use super::{ api::{FirewallApi, FirewallManagementApi}, + iprange::IpAddrRangeError, Address, FirewallError, FirewallRule, Policy, Port, Protocol, }; use crate::enterprise::firewall::iprange::IpAddrRange; @@ -55,80 +56,68 @@ struct FilterRule<'a> { negated_iifname: bool, } -fn merge_addrs<'a>(addrs: &'a [Address]) -> Vec
{ - debug!("merge_addrs called with input: {:?}", addrs); +/// Merges any contiguous subets or addres ranges into an address range. +/// +/// This reflects the way `nft` CLI handles such cases. +/// Otherwise first address in any subnet after the first is not matched. +/// For example if we use `172.30.0.2/31, 172.30.0.4/31` as `saddr` in a rule, +/// then 172.30.0.4 will not be matched. +fn merge_addrs(addrs: Vec
) -> Result, IpAddrRangeError> { + debug!( + "Merging any contiguous subnets found within address list: {:?}", + addrs + ); if addrs.is_empty() { debug!("No addresses provided, returning empty vector."); - return Vec::new(); + return Ok(Vec::new()); } let mut merged_addrs = Vec::new(); - - // sort them by their .first() address - let mut addrs_sorted = Vec::from_iter(addrs.iter()); - addrs_sorted.sort_by(|a, b| { - a.first() - .partial_cmp(&b.first()) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - debug!("Sorted addresses: {:?}", addrs_sorted); - - let first_addr = addrs_sorted.remove(0); - let mut current_range_start = first_addr.first(); - let mut current_range_end = first_addr.last(); - - debug!( - "Starting merge loop with initial range: {:?} - {:?}", - current_range_start, current_range_end - ); - - for addr in addrs_sorted { - let first_addr = addr.first(); - let last_addr = addr.last(); - - debug!( - "Checking addr: {:?} - {:?} against current_range_end: {:?}", - first_addr, last_addr, current_range_end - ); - - // Check if ranges overlap or are adjacent - if first_addr <= current_range_end || next_ip(current_range_end) == first_addr { - // Ranges overlap or are adjacent, merge them - if last_addr > current_range_end { - debug!( - "Extending current_range_end from {:?} to {:?}", - current_range_end, last_addr - ); - current_range_end = last_addr; + let mut current_address = None; + + // we can assume addresses coming from the core + // are already sorted and non-overlapping + for next_address in addrs { + match ¤t_address { + None => { + debug!("Initializing current address with: {next_address:?}"); + current_address = Some(next_address); + } + Some(previous_address) => { + let previous_range_start = previous_address.first(); + let previous_range_end = previous_address.last(); + let next_ip = next_ip(previous_range_end); + + let next_range_start = next_address.first(); + let next_range_end = next_address.last(); + + // check if range is adjacent to current address + if next_range_start == next_ip { + // replace current address with a combined range + debug!("Merging {next_address:?} with {current_address:?}"); + current_address = Some(Address::Range(IpAddrRange::new( + previous_range_start, + next_range_end, + )?)); + } else { + // push previous address to result and replace with next address + merged_addrs.push(previous_address.clone()); + current_address = Some(next_address); + }; } - } else { - // Ranges don't overlap and aren't adjacent, push current range and start new one - debug!( - "Pushing merged range: {:?} - {:?}", - current_range_start, current_range_end - ); - merged_addrs.push(Address::Range( - IpAddrRange::new(current_range_start, current_range_end).unwrap(), - )); - current_range_start = first_addr; - current_range_end = last_addr; } } - // Push the last range - debug!( - "Pushing final merged range: {:?} - {:?}", - current_range_start, current_range_end - ); - merged_addrs.push(Address::Range( - IpAddrRange::new(current_range_start, current_range_end).unwrap(), - )); + // push last remaining address to results + if let Some(address) = current_address { + debug!("Pushing last remaining address into results: {address:?}"); + merged_addrs.push(address) + } debug!("Prepared addresses: {:?}", merged_addrs); - merged_addrs + Ok(merged_addrs) } /// Returns the next IP address in sequence, handling overflow via wrapping @@ -174,8 +163,8 @@ impl FirewallApi { destination ports and protocols." ); - let source_addrs = merge_addrs(&rule.source_addrs); - let dest_addrs = merge_addrs(&rule.destination_addrs); + let source_addrs = merge_addrs(rule.source_addrs)?; + let dest_addrs = merge_addrs(rule.destination_addrs)?; if rule.destination_ports.is_empty() { debug!( @@ -377,7 +366,6 @@ mod tests { use ipnetwork::IpNetwork; use super::*; - use crate::proto::enterprise::firewall::FirewallRule as ProtoFirewallRule; #[test] fn test_sorting() { @@ -408,13 +396,13 @@ mod tests { ] ); - let _prepared_addrs = merge_addrs(&addrs); + let _prepared_addrs = merge_addrs(addrs).unwrap(); } #[test] fn test_merge_addrs_empty() { let addrs: Vec
= vec![]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert!(result.is_empty()); } @@ -423,44 +411,9 @@ mod tests { let addrs = vec![Address::Network( IpNetwork::from_str("192.168.1.10/32").unwrap(), )]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs.clone()).unwrap(); - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.10").unwrap()); - } else { - panic!("Expected Address::Range"); - } - } - - #[test] - fn test_merge_addrs_overlapping_ranges() { - let addrs = vec![ - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.10").unwrap(), - IpAddr::from_str("192.168.1.20").unwrap(), - ) - .unwrap(), - ), - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.15").unwrap(), - IpAddr::from_str("192.168.1.25").unwrap(), - ) - .unwrap(), - ), - ]; - let result = merge_addrs(&addrs); - - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.25").unwrap()); - } else { - panic!("Expected Address::Range"); - } + assert_eq!(result, addrs); } #[test] @@ -481,7 +434,7 @@ mod tests { .unwrap(), ), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -499,7 +452,7 @@ mod tests { Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -528,7 +481,7 @@ mod tests { .unwrap(), ), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 2); if let Address::Range(range1) = &result[0] { @@ -558,7 +511,7 @@ mod tests { ), Address::Network(IpNetwork::from_str("192.168.1.16/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -569,58 +522,29 @@ mod tests { } } - #[test] - fn test_merge_addrs_unsorted_input() { - let addrs = vec![ - Address::Network(IpNetwork::from_str("192.168.1.13/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), - ]; - let result = merge_addrs(&addrs); - - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.13").unwrap()); - } else { - panic!("Expected Address::Range"); - } - } - #[test] fn test_merge_addrs_non_adjacent_singles() { let addrs = vec![ - Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()), - Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); // These should result in 3 separate ranges: 10-11, 15, 20 - assert_eq!(result.len(), 3); - - if let Address::Range(range1) = &result[0] { - assert_eq!(range1.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range1.end(), IpAddr::from_str("192.168.1.11").unwrap()); - } else { - panic!("Expected Address::Range"); - } - - if let Address::Range(range2) = &result[1] { - assert_eq!(range2.start(), IpAddr::from_str("192.168.1.15").unwrap()); - assert_eq!(range2.end(), IpAddr::from_str("192.168.1.15").unwrap()); - } else { - panic!("Expected Address::Range"); - } - - if let Address::Range(range3) = &result[2] { - assert_eq!(range3.start(), IpAddr::from_str("192.168.1.20").unwrap()); - assert_eq!(range3.end(), IpAddr::from_str("192.168.1.20").unwrap()); - } else { - panic!("Expected Address::Range"); - } + let expected_addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.11").unwrap(), + ) + .unwrap(), + ), + Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()), + Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()), + ]; + assert_eq!(result, expected_addrs); } #[test] @@ -630,7 +554,7 @@ mod tests { Address::Network(IpNetwork::from_str("2001:db8::2/128").unwrap()), Address::Network(IpNetwork::from_str("2001:db8::3/128").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs).unwrap(); assert_eq!(result.len(), 1); if let Address::Range(range) = &result[0] { @@ -641,35 +565,6 @@ mod tests { } } - #[test] - fn test_merge_addrs_contained_ranges() { - let addrs = vec![ - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.10").unwrap(), - IpAddr::from_str("192.168.1.30").unwrap(), - ) - .unwrap(), - ), - Address::Range( - IpAddrRange::new( - IpAddr::from_str("192.168.1.15").unwrap(), - IpAddr::from_str("192.168.1.20").unwrap(), - ) - .unwrap(), - ), - ]; - let result = merge_addrs(&addrs); - - assert_eq!(result.len(), 1); - if let Address::Range(range) = &result[0] { - assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap()); - assert_eq!(range.end(), IpAddr::from_str("192.168.1.30").unwrap()); - } else { - panic!("Expected Address::Range"); - } - } - #[test] fn test_next_ip_ipv4() { assert_eq!( @@ -699,26 +594,29 @@ mod tests { } #[test] - fn test_merge_addrs_large_gap() { + fn test_merge_addrs_gap() { let addrs = vec![ Address::Network(IpNetwork::from_str("192.168.1.1/32").unwrap()), Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()), ]; - let result = merge_addrs(&addrs); + let result = merge_addrs(addrs.clone()).unwrap(); - // Should not merge since there's a large gap - assert_eq!(result.len(), 2); - if let Address::Range(range1) = &result[0] { - assert_eq!(range1.start(), IpAddr::from_str("192.168.1.1").unwrap()); - assert_eq!(range1.end(), IpAddr::from_str("192.168.1.1").unwrap()); - } else { - panic!("Expected Address::Range"); - } - if let Address::Range(range2) = &result[1] { - assert_eq!(range2.start(), IpAddr::from_str("192.168.1.100").unwrap()); - assert_eq!(range2.end(), IpAddr::from_str("192.168.1.100").unwrap()); - } else { - panic!("Expected Address::Range"); - } + // Should not merge since there's a gap + assert_eq!(result, addrs); + + let addrs = vec![ + Address::Range( + IpAddrRange::new( + IpAddr::from_str("192.168.1.10").unwrap(), + IpAddr::from_str("192.168.1.20").unwrap(), + ) + .unwrap(), + ), + Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()), + ]; + let result = merge_addrs(addrs.clone()).unwrap(); + + // Should not merge since there's a gap + assert_eq!(result, addrs); } } diff --git a/src/enterprise/firewall/nftables/netfilter.rs b/src/enterprise/firewall/nftables/netfilter.rs index 5be57aeb..b5d8505c 100644 --- a/src/enterprise/firewall/nftables/netfilter.rs +++ b/src/enterprise/firewall/nftables/netfilter.rs @@ -1,12 +1,8 @@ -#[cfg(test)] -use std::str::FromStr; use std::{ ffi::{CStr, CString}, net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; -#[cfg(test)] -use ipnetwork::{Ipv4Network, Ipv6Network}; use nftnl::{ expr::{Expression, InterfaceName}, nft_expr, nftnl_sys,