diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs
index fd529231..ec3b1976 100644
--- a/src/enterprise/firewall/nftables/mod.rs
+++ b/src/enterprise/firewall/nftables/mod.rs
@@ -13,6 +13,7 @@ use nftnl::Batch;
use super::{
api::{FirewallApi, FirewallManagementApi},
+ iprange::IpAddrRangeError,
Address, FirewallError, FirewallRule, Policy, Port, Protocol,
};
use crate::enterprise::firewall::iprange::IpAddrRange;
@@ -55,80 +56,68 @@ struct FilterRule<'a> {
negated_iifname: bool,
}
-fn merge_addrs<'a>(addrs: &'a [Address]) -> Vec
{
- debug!("merge_addrs called with input: {:?}", addrs);
+/// Merges any contiguous subets or addres ranges into an address range.
+///
+/// This reflects the way `nft` CLI handles such cases.
+/// Otherwise first address in any subnet after the first is not matched.
+/// For example if we use `172.30.0.2/31, 172.30.0.4/31` as `saddr` in a rule,
+/// then 172.30.0.4 will not be matched.
+fn merge_addrs(addrs: Vec) -> Result, IpAddrRangeError> {
+ debug!(
+ "Merging any contiguous subnets found within address list: {:?}",
+ addrs
+ );
if addrs.is_empty() {
debug!("No addresses provided, returning empty vector.");
- return Vec::new();
+ return Ok(Vec::new());
}
let mut merged_addrs = Vec::new();
-
- // sort them by their .first() address
- let mut addrs_sorted = Vec::from_iter(addrs.iter());
- addrs_sorted.sort_by(|a, b| {
- a.first()
- .partial_cmp(&b.first())
- .unwrap_or(std::cmp::Ordering::Equal)
- });
-
- debug!("Sorted addresses: {:?}", addrs_sorted);
-
- let first_addr = addrs_sorted.remove(0);
- let mut current_range_start = first_addr.first();
- let mut current_range_end = first_addr.last();
-
- debug!(
- "Starting merge loop with initial range: {:?} - {:?}",
- current_range_start, current_range_end
- );
-
- for addr in addrs_sorted {
- let first_addr = addr.first();
- let last_addr = addr.last();
-
- debug!(
- "Checking addr: {:?} - {:?} against current_range_end: {:?}",
- first_addr, last_addr, current_range_end
- );
-
- // Check if ranges overlap or are adjacent
- if first_addr <= current_range_end || next_ip(current_range_end) == first_addr {
- // Ranges overlap or are adjacent, merge them
- if last_addr > current_range_end {
- debug!(
- "Extending current_range_end from {:?} to {:?}",
- current_range_end, last_addr
- );
- current_range_end = last_addr;
+ let mut current_address = None;
+
+ // we can assume addresses coming from the core
+ // are already sorted and non-overlapping
+ for next_address in addrs {
+ match ¤t_address {
+ None => {
+ debug!("Initializing current address with: {next_address:?}");
+ current_address = Some(next_address);
+ }
+ Some(previous_address) => {
+ let previous_range_start = previous_address.first();
+ let previous_range_end = previous_address.last();
+ let next_ip = next_ip(previous_range_end);
+
+ let next_range_start = next_address.first();
+ let next_range_end = next_address.last();
+
+ // check if range is adjacent to current address
+ if next_range_start == next_ip {
+ // replace current address with a combined range
+ debug!("Merging {next_address:?} with {current_address:?}");
+ current_address = Some(Address::Range(IpAddrRange::new(
+ previous_range_start,
+ next_range_end,
+ )?));
+ } else {
+ // push previous address to result and replace with next address
+ merged_addrs.push(previous_address.clone());
+ current_address = Some(next_address);
+ };
}
- } else {
- // Ranges don't overlap and aren't adjacent, push current range and start new one
- debug!(
- "Pushing merged range: {:?} - {:?}",
- current_range_start, current_range_end
- );
- merged_addrs.push(Address::Range(
- IpAddrRange::new(current_range_start, current_range_end).unwrap(),
- ));
- current_range_start = first_addr;
- current_range_end = last_addr;
}
}
- // Push the last range
- debug!(
- "Pushing final merged range: {:?} - {:?}",
- current_range_start, current_range_end
- );
- merged_addrs.push(Address::Range(
- IpAddrRange::new(current_range_start, current_range_end).unwrap(),
- ));
+ // push last remaining address to results
+ if let Some(address) = current_address {
+ debug!("Pushing last remaining address into results: {address:?}");
+ merged_addrs.push(address)
+ }
debug!("Prepared addresses: {:?}", merged_addrs);
- merged_addrs
+ Ok(merged_addrs)
}
/// Returns the next IP address in sequence, handling overflow via wrapping
@@ -174,8 +163,8 @@ impl FirewallApi {
destination ports and protocols."
);
- let source_addrs = merge_addrs(&rule.source_addrs);
- let dest_addrs = merge_addrs(&rule.destination_addrs);
+ let source_addrs = merge_addrs(rule.source_addrs)?;
+ let dest_addrs = merge_addrs(rule.destination_addrs)?;
if rule.destination_ports.is_empty() {
debug!(
@@ -377,7 +366,6 @@ mod tests {
use ipnetwork::IpNetwork;
use super::*;
- use crate::proto::enterprise::firewall::FirewallRule as ProtoFirewallRule;
#[test]
fn test_sorting() {
@@ -408,13 +396,13 @@ mod tests {
]
);
- let _prepared_addrs = merge_addrs(&addrs);
+ let _prepared_addrs = merge_addrs(addrs).unwrap();
}
#[test]
fn test_merge_addrs_empty() {
let addrs: Vec = vec![];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
assert!(result.is_empty());
}
@@ -423,44 +411,9 @@ mod tests {
let addrs = vec![Address::Network(
IpNetwork::from_str("192.168.1.10/32").unwrap(),
)];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs.clone()).unwrap();
- assert_eq!(result.len(), 1);
- if let Address::Range(range) = &result[0] {
- assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap());
- assert_eq!(range.end(), IpAddr::from_str("192.168.1.10").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
- }
-
- #[test]
- fn test_merge_addrs_overlapping_ranges() {
- let addrs = vec![
- Address::Range(
- IpAddrRange::new(
- IpAddr::from_str("192.168.1.10").unwrap(),
- IpAddr::from_str("192.168.1.20").unwrap(),
- )
- .unwrap(),
- ),
- Address::Range(
- IpAddrRange::new(
- IpAddr::from_str("192.168.1.15").unwrap(),
- IpAddr::from_str("192.168.1.25").unwrap(),
- )
- .unwrap(),
- ),
- ];
- let result = merge_addrs(&addrs);
-
- assert_eq!(result.len(), 1);
- if let Address::Range(range) = &result[0] {
- assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap());
- assert_eq!(range.end(), IpAddr::from_str("192.168.1.25").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
+ assert_eq!(result, addrs);
}
#[test]
@@ -481,7 +434,7 @@ mod tests {
.unwrap(),
),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
assert_eq!(result.len(), 1);
if let Address::Range(range) = &result[0] {
@@ -499,7 +452,7 @@ mod tests {
Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()),
Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
assert_eq!(result.len(), 1);
if let Address::Range(range) = &result[0] {
@@ -528,7 +481,7 @@ mod tests {
.unwrap(),
),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
assert_eq!(result.len(), 2);
if let Address::Range(range1) = &result[0] {
@@ -558,7 +511,7 @@ mod tests {
),
Address::Network(IpNetwork::from_str("192.168.1.16/32").unwrap()),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
assert_eq!(result.len(), 1);
if let Address::Range(range) = &result[0] {
@@ -569,58 +522,29 @@ mod tests {
}
}
- #[test]
- fn test_merge_addrs_unsorted_input() {
- let addrs = vec![
- Address::Network(IpNetwork::from_str("192.168.1.13/32").unwrap()),
- Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()),
- Address::Network(IpNetwork::from_str("192.168.1.12/32").unwrap()),
- Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()),
- ];
- let result = merge_addrs(&addrs);
-
- assert_eq!(result.len(), 1);
- if let Address::Range(range) = &result[0] {
- assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap());
- assert_eq!(range.end(), IpAddr::from_str("192.168.1.13").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
- }
-
#[test]
fn test_merge_addrs_non_adjacent_singles() {
let addrs = vec![
- Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()),
Address::Network(IpNetwork::from_str("192.168.1.10/32").unwrap()),
- Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()),
Address::Network(IpNetwork::from_str("192.168.1.11/32").unwrap()),
+ Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()),
+ Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
// These should result in 3 separate ranges: 10-11, 15, 20
- assert_eq!(result.len(), 3);
-
- if let Address::Range(range1) = &result[0] {
- assert_eq!(range1.start(), IpAddr::from_str("192.168.1.10").unwrap());
- assert_eq!(range1.end(), IpAddr::from_str("192.168.1.11").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
-
- if let Address::Range(range2) = &result[1] {
- assert_eq!(range2.start(), IpAddr::from_str("192.168.1.15").unwrap());
- assert_eq!(range2.end(), IpAddr::from_str("192.168.1.15").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
-
- if let Address::Range(range3) = &result[2] {
- assert_eq!(range3.start(), IpAddr::from_str("192.168.1.20").unwrap());
- assert_eq!(range3.end(), IpAddr::from_str("192.168.1.20").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
+ let expected_addrs = vec![
+ Address::Range(
+ IpAddrRange::new(
+ IpAddr::from_str("192.168.1.10").unwrap(),
+ IpAddr::from_str("192.168.1.11").unwrap(),
+ )
+ .unwrap(),
+ ),
+ Address::Network(IpNetwork::from_str("192.168.1.15/32").unwrap()),
+ Address::Network(IpNetwork::from_str("192.168.1.20/32").unwrap()),
+ ];
+ assert_eq!(result, expected_addrs);
}
#[test]
@@ -630,7 +554,7 @@ mod tests {
Address::Network(IpNetwork::from_str("2001:db8::2/128").unwrap()),
Address::Network(IpNetwork::from_str("2001:db8::3/128").unwrap()),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs).unwrap();
assert_eq!(result.len(), 1);
if let Address::Range(range) = &result[0] {
@@ -641,35 +565,6 @@ mod tests {
}
}
- #[test]
- fn test_merge_addrs_contained_ranges() {
- let addrs = vec![
- Address::Range(
- IpAddrRange::new(
- IpAddr::from_str("192.168.1.10").unwrap(),
- IpAddr::from_str("192.168.1.30").unwrap(),
- )
- .unwrap(),
- ),
- Address::Range(
- IpAddrRange::new(
- IpAddr::from_str("192.168.1.15").unwrap(),
- IpAddr::from_str("192.168.1.20").unwrap(),
- )
- .unwrap(),
- ),
- ];
- let result = merge_addrs(&addrs);
-
- assert_eq!(result.len(), 1);
- if let Address::Range(range) = &result[0] {
- assert_eq!(range.start(), IpAddr::from_str("192.168.1.10").unwrap());
- assert_eq!(range.end(), IpAddr::from_str("192.168.1.30").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
- }
-
#[test]
fn test_next_ip_ipv4() {
assert_eq!(
@@ -699,26 +594,29 @@ mod tests {
}
#[test]
- fn test_merge_addrs_large_gap() {
+ fn test_merge_addrs_gap() {
let addrs = vec![
Address::Network(IpNetwork::from_str("192.168.1.1/32").unwrap()),
Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()),
];
- let result = merge_addrs(&addrs);
+ let result = merge_addrs(addrs.clone()).unwrap();
- // Should not merge since there's a large gap
- assert_eq!(result.len(), 2);
- if let Address::Range(range1) = &result[0] {
- assert_eq!(range1.start(), IpAddr::from_str("192.168.1.1").unwrap());
- assert_eq!(range1.end(), IpAddr::from_str("192.168.1.1").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
- if let Address::Range(range2) = &result[1] {
- assert_eq!(range2.start(), IpAddr::from_str("192.168.1.100").unwrap());
- assert_eq!(range2.end(), IpAddr::from_str("192.168.1.100").unwrap());
- } else {
- panic!("Expected Address::Range");
- }
+ // Should not merge since there's a gap
+ assert_eq!(result, addrs);
+
+ let addrs = vec![
+ Address::Range(
+ IpAddrRange::new(
+ IpAddr::from_str("192.168.1.10").unwrap(),
+ IpAddr::from_str("192.168.1.20").unwrap(),
+ )
+ .unwrap(),
+ ),
+ Address::Network(IpNetwork::from_str("192.168.1.100/32").unwrap()),
+ ];
+ let result = merge_addrs(addrs.clone()).unwrap();
+
+ // Should not merge since there's a gap
+ assert_eq!(result, addrs);
}
}
diff --git a/src/enterprise/firewall/nftables/netfilter.rs b/src/enterprise/firewall/nftables/netfilter.rs
index 5be57aeb..b5d8505c 100644
--- a/src/enterprise/firewall/nftables/netfilter.rs
+++ b/src/enterprise/firewall/nftables/netfilter.rs
@@ -1,12 +1,8 @@
-#[cfg(test)]
-use std::str::FromStr;
use std::{
ffi::{CStr, CString},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};
-#[cfg(test)]
-use ipnetwork::{Ipv4Network, Ipv6Network};
use nftnl::{
expr::{Expression, InterfaceName},
nft_expr, nftnl_sys,