diff --git a/Cargo.lock b/Cargo.lock index 3086429d..5db96863 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,9 +38,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.21" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", "anstyle-parse", @@ -53,15 +53,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" -version = "0.2.7" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" dependencies = [ "utf8parse", ] @@ -131,7 +131,7 @@ dependencies = [ "memchr", "serde", "serde_derive", - "winnow", + "winnow 0.7.15", ] [[package]] @@ -303,9 +303,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.56" +version = "1.2.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1" dependencies = [ "find-msvc-tools", "jobserver", @@ -362,9 +362,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", "clap_derive", @@ -372,9 +372,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstream", "anstyle", @@ -384,9 +384,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.55" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" dependencies = [ "heck", "proc-macro2", @@ -396,15 +396,15 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "core-foundation" @@ -514,7 +514,7 @@ dependencies = [ [[package]] name = "defguard-gateway" -version = "1.6.3" +version = "1.6.4" dependencies = [ "axum", "base64", @@ -535,7 +535,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-stream", - "toml", + "toml 1.1.0+spec-1.1.0", "tonic", "tonic-prost", "tonic-prost-build", @@ -694,9 +694,9 @@ checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "env_filter" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", "regex", @@ -704,9 +704,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.9" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ "anstream", "anstyle", @@ -1273,9 +1273,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" @@ -1612,9 +1612,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" [[package]] name = "num_threads" @@ -1939,9 +1939,9 @@ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" dependencies = [ "portable-atomic", ] @@ -2035,9 +2035,9 @@ dependencies = [ [[package]] name = "pulldown-cmark" -version = "0.13.1" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6" +checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" dependencies = [ "bitflags", "memchr", @@ -2201,9 +2201,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "ring", "rustls-pki-types", @@ -2346,9 +2346,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.0.4" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +checksum = "876ac351060d4f882bb1032b6369eb0aef79ad9df1ea8bc404874d8cc3d0cd98" dependencies = [ "serde_core", ] @@ -2392,9 +2392,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] name = "siphasher" @@ -2678,10 +2678,23 @@ dependencies = [ "indexmap", "serde_core", "serde_spanned", - "toml_datetime", + "toml_datetime 0.7.5+spec-1.1.0", "toml_parser", "toml_writer", - "winnow", + "winnow 0.7.15", +] + +[[package]] +name = "toml" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8195ca05e4eb728f4ba94f3e3291661320af739c4e43779cbdfae82ab239fcc" +dependencies = [ + "serde_core", + "serde_spanned", + "toml_datetime 1.1.0+spec-1.1.0", + "toml_parser", + "winnow 1.0.0", ] [[package]] @@ -2693,20 +2706,29 @@ dependencies = [ "serde_core", ] +[[package]] +name = "toml_datetime" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f" +dependencies = [ + "serde_core", +] + [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.1.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011" dependencies = [ - "winnow", + "winnow 1.0.0", ] [[package]] name = "toml_writer" -version = "1.0.6+spec-1.1.0" +version = "1.1.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" +checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed" [[package]] name = "tonic" @@ -2856,9 +2878,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "matchers", "nu-ansi-term", @@ -2938,7 +2960,7 @@ dependencies = [ "serde", "tempfile", "textwrap", - "toml", + "toml 0.9.12+spec-1.1.0", "uniffi_internal_macros", "uniffi_meta", "uniffi_pipeline", @@ -2994,7 +3016,7 @@ dependencies = [ "quote", "serde", "syn", - "toml", + "toml 0.9.12+spec-1.1.0", "uniffi_meta", ] @@ -3424,6 +3446,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" + [[package]] name = "wireguard-nt" version = "0.5.0" diff --git a/Cargo.toml b/Cargo.toml index 3f2497ce..bee0126c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "defguard-gateway" -version = "1.6.3" +version = "1.6.4" edition = "2024" [dependencies] @@ -20,7 +20,7 @@ syslog = "7.0" thiserror = "2.0" tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] } tokio-stream = { version = "0.1", features = [] } -toml = { version = "0.9", default-features = false, features = [ +toml = { version = "1.1", default-features = false, features = [ "parse", "serde", ] } diff --git a/src/enterprise/firewall/api.rs b/src/enterprise/firewall/api.rs index 15b29691..777d6dbb 100644 --- a/src/enterprise/firewall/api.rs +++ b/src/enterprise/firewall/api.rs @@ -1,27 +1,35 @@ #[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "netbsd"))] use std::fs::{File, OpenOptions}; -#[cfg(target_os = "linux")] -use nftnl::Batch; - use super::{FirewallError, FirewallRule, Policy, SnatBinding}; -#[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "netbsd"))] +#[cfg(all( + test, + any(target_os = "freebsd", target_os = "macos", target_os = "netbsd") +))] +const DEV_PF: &str = "/dev/null"; +#[cfg(all( + not(test), + any(target_os = "freebsd", target_os = "macos", target_os = "netbsd") +))] const DEV_PF: &str = "/dev/pf"; #[allow(dead_code)] -pub struct FirewallApi { +pub(crate) struct FirewallApi { pub(crate) ifname: String, #[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "netbsd"))] pub(crate) file: File, #[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "netbsd"))] pub(crate) default_policy: Policy, #[cfg(target_os = "linux")] - pub(crate) batch: Option, + pub(crate) socket: mnl::Socket, } impl FirewallApi { - pub fn new>(ifname: S) -> Result { + pub(crate) fn new(ifname: S) -> Result + where + S: Into, + { Ok(Self { ifname: ifname.into(), #[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "netbsd"))] @@ -29,7 +37,9 @@ impl FirewallApi { #[cfg(any(target_os = "freebsd", target_os = "macos", target_os = "netbsd"))] default_policy: Policy::Deny, #[cfg(target_os = "linux")] - batch: None, + socket: mnl::Socket::new(mnl::Bus::Netfilter).map_err(|err| { + FirewallError::NetlinkError(format!("Failed to create socket: {err:?}")) + })?, }) } } @@ -43,7 +53,7 @@ pub(crate) trait FirewallManagementApi { fn cleanup(&mut self) -> Result<(), FirewallError>; /// Add firewall rules. - fn add_rules(&mut self, rules: Vec) -> Result<(), FirewallError>; + fn add_rules(&mut self, rules: &[FirewallRule]) -> Result<(), FirewallError>; /// Setup Network Address Translation using POSTROUTING chain rules fn setup_nat( @@ -51,10 +61,4 @@ pub(crate) trait FirewallManagementApi { masquerade_enabled: bool, snat_bindings: &[SnatBinding], ) -> Result<(), FirewallError>; - - /// Begin rule transaction. - fn begin(&mut self) -> Result<(), FirewallError>; - - /// Commit rule transaction. - fn commit(&mut self) -> Result<(), FirewallError>; } diff --git a/src/enterprise/firewall/iprange.rs b/src/enterprise/firewall/iprange.rs index b1b4acd6..6242ab5b 100644 --- a/src/enterprise/firewall/iprange.rs +++ b/src/enterprise/firewall/iprange.rs @@ -24,10 +24,10 @@ pub enum IpAddrRangeError { impl fmt::Display for IpAddrRangeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::MixedTypes => write!(f, "mixed IPv4 and IPv6 addresses"), - Self::WrongOrder => write!(f, "wrong order: higher address preceeds lower"), - } + f.write_str(match self { + Self::MixedTypes => "mixed IPv4 and IPv6 addresses", + Self::WrongOrder => "wrong order: higher address preceeds lower", + }) } } diff --git a/src/enterprise/firewall/mod.rs b/src/enterprise/firewall/mod.rs index 2e4b634e..dee55ab4 100644 --- a/src/enterprise/firewall/mod.rs +++ b/src/enterprise/firewall/mod.rs @@ -177,7 +177,7 @@ impl fmt::Display for Protocol { Self::Udp => "udp", Self::IcmpV6 => "icmp6", }; - write!(f, "{protocol}") + f.write_str(protocol) } } diff --git a/src/enterprise/firewall/nftables/mod.rs b/src/enterprise/firewall/nftables/mod.rs index 476cde47..1ce2b007 100644 --- a/src/enterprise/firewall/nftables/mod.rs +++ b/src/enterprise/firewall/nftables/mod.rs @@ -1,4 +1,4 @@ -pub mod netfilter; +mod netfilter; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -134,11 +134,8 @@ fn next_ip(ip: IpAddr) -> IpAddr { } impl FirewallApi { - fn add_rule(&mut self, rule: FirewallRule) -> Result<(), FirewallError> { + fn add_rule(&mut self, rule: &FirewallRule) -> Result<(), FirewallError> { debug!("Applying the following Defguard ACL rule: {rule:?}"); - let Some(ref mut batch) = self.batch else { - return Err(FirewallError::TransactionNotStarted); - }; let mut filter_rules = Vec::new(); debug!( @@ -146,8 +143,8 @@ impl FirewallApi { destination ports and protocols." ); - let source_addrs = merge_addrs(rule.source_addrs)?; - let dest_addrs = merge_addrs(rule.destination_addrs)?; + let source_addrs = merge_addrs(rule.source_addrs.clone())?; + let dest_addrs = merge_addrs(rule.destination_addrs.clone())?; if rule.destination_ports.is_empty() { debug!( @@ -218,7 +215,7 @@ impl FirewallApi { } } - apply_filter_rules(&filter_rules, batch, &self.ifname)?; + apply_filter_rules(&filter_rules, &self.ifname, &self.socket)?; debug!( "Applied firewall rules for Defguard ACL rule ID: {}", @@ -230,37 +227,32 @@ impl FirewallApi { impl FirewallManagementApi for FirewallApi { /// Sets up the firewall with the given default policy and priority. Drops the previous table. - /// - /// This function also begins a batch of operations which can be applied later using the [`apply`] method. - /// This allows for making atomic changes to the firewall rules. fn setup( &mut self, default_policy: Policy, priority: Option, ) -> Result<(), FirewallError> { debug!("Initializing firewall, VPN interface: {}", self.ifname); - if let Some(batch) = &mut self.batch { - drop_table(batch, &self.ifname); - init_firewall(default_policy, priority, batch, &self.ifname); - debug!("Allowing all established traffic"); - ignore_unrelated_traffic(batch, &self.ifname)?; - allow_established_traffic(batch, &self.ifname)?; - debug!("Allowed all established traffic"); - debug!("Initialized firewall"); - Ok(()) - } else { - Err(FirewallError::TransactionNotStarted) - } + let mut batch = Batch::new(); + drop_table(&mut batch, &self.ifname); + init_firewall(default_policy, priority, &mut batch, &self.ifname); + debug!("Allowing all established traffic"); + ignore_unrelated_traffic(&mut batch, &self.ifname)?; + allow_established_traffic(&mut batch, &self.ifname)?; + send_batch(&batch.finalize(), &self.socket)?; + debug!("Allowed all established traffic"); + debug!("Initialized firewall"); + + Ok(()) } /// Cleans up the whole Defguard table. fn cleanup(&mut self) -> Result<(), FirewallError> { debug!("Cleaning up all previous firewall rules, if any"); - if let Some(batch) = &mut self.batch { - drop_table(batch, &self.ifname); - } else { - return Err(FirewallError::TransactionNotStarted); - } + let mut batch = Batch::new(); + drop_table(&mut batch, &self.ifname); + send_batch(&batch.finalize(), &self.socket)?; + debug!("Cleaned up all previous firewall rules"); Ok(()) } @@ -271,20 +263,18 @@ impl FirewallManagementApi for FirewallApi { snat_bindings: &[SnatBinding], ) -> Result<(), FirewallError> { debug!( - "Setting up POSTROUTING chain rules with masquerade status: {masquerade_enabled} and SNAT bindings: {snat_bindings:?}" + "Setting up POSTROUTING chain rules with masquerade status: {masquerade_enabled} and \ + SNAT bindings: {snat_bindings:?}" ); - if let Some(batch) = &mut self.batch { - set_nat_rules(batch, &self.ifname, masquerade_enabled, snat_bindings)?; - } else { - return Err(FirewallError::TransactionNotStarted); - } + let mut batch = Batch::new(); + set_nat_rules(&mut batch, &self.ifname, masquerade_enabled, snat_bindings)?; debug!("Finished POSTROUTING chain rules setup"); Ok(()) } - fn add_rules(&mut self, rules: Vec) -> Result<(), FirewallError> { + fn add_rules(&mut self, rules: &[FirewallRule]) -> Result<(), FirewallError> { debug!("Applying the following Defguard ACL rules: {rules:?}"); for rule in rules { self.add_rule(rule)?; @@ -292,36 +282,6 @@ impl FirewallManagementApi for FirewallApi { debug!("Applied all Defguard ACL rules"); Ok(()) } - - fn begin(&mut self) -> Result<(), FirewallError> { - if self.batch.is_none() { - debug!("Starting new firewall transaction"); - let batch = Batch::new(); - self.batch = Some(batch); - debug!("Firewall transaction successfully started"); - Ok(()) - } else { - Err(FirewallError::TransactionFailed( - "There is another firewall transaction already in progress. Commit or \ - rollback it before starting a new one." - .to_string(), - )) - } - } - - /// Apply whole firewall configuration and send it in one go to the kernel. - fn commit(&mut self) -> Result<(), FirewallError> { - if let Some(batch) = self.batch.take() { - debug!("Committing firewall transaction"); - let finalized = batch.finalize(); - debug!("Firewall batch finalized, sending to kernel"); - send_batch(&finalized)?; - debug!("Firewall transaction successfully committed to kernel"); - Ok(()) - } else { - Err(FirewallError::TransactionNotStarted) - } - } } #[cfg(test)] @@ -366,7 +326,7 @@ mod tests { #[test] fn test_merge_addrs_empty() { - let addrs: Vec
= vec![]; + let addrs = Vec::new(); let result = merge_addrs(addrs).unwrap(); assert!(result.is_empty()); } diff --git a/src/enterprise/firewall/nftables/netfilter.rs b/src/enterprise/firewall/nftables/netfilter.rs index 2c7f6fcd..ba7d3e24 100644 --- a/src/enterprise/firewall/nftables/netfilter.rs +++ b/src/enterprise/firewall/nftables/netfilter.rs @@ -688,18 +688,15 @@ pub(super) fn set_nat_rules( batch.add(&snat_rule, MsgType::Add); } - // add MASQUERADE rule - let masquerade_rule = MasqueradeRule { - oifname: LOOPBACK_IFACE.to_string(), - negated_oifname: true, - counter: true, - } - .to_chain_rule(&nat_chain, batch)?; - if masquerade_enabled { + let masquerade_rule = MasqueradeRule { + oifname: LOOPBACK_IFACE.to_string(), + negated_oifname: true, + counter: true, + } + .to_chain_rule(&nat_chain, batch)?; + batch.add(&masquerade_rule, MsgType::Add); - } else { - batch.add(&masquerade_rule, MsgType::Del); } Ok(()) @@ -823,26 +820,28 @@ impl Chains { pub(super) fn apply_filter_rules( rules: &[FilterRule], - batch: &mut Batch, ifname: &str, + socket: &mnl::Socket, ) -> Result<(), FirewallError> { let table = Tables::Defguard(ProtoFamily::Inet).to_table(ifname); + let mut batch = Batch::new(); batch.add(&table, MsgType::Add); let forward_chain = Chains::Forward.to_chain(&table); batch.add(&forward_chain, MsgType::Add); for rule in rules { - let chain_rule = rule.to_chain_rule(&forward_chain, batch)?; + let chain_rule = rule.to_chain_rule(&forward_chain, &mut batch)?; batch.add(&chain_rule, MsgType::Add); } - Ok(()) + send_batch(&batch.finalize(), socket) } -pub(crate) fn send_batch(batch: &FinalizedBatch) -> Result<(), FirewallError> { - let socket = mnl::Socket::new(mnl::Bus::Netfilter) - .map_err(|err| FirewallError::NetlinkError(format!("Failed to create socket: {err:?}")))?; +pub(crate) fn send_batch( + batch: &FinalizedBatch, + socket: &mnl::Socket, +) -> Result<(), FirewallError> { socket.send_all(batch).map_err(|err| { FirewallError::NetlinkError(format!("Failed to send batch through socket: {err:?}")) })?; @@ -851,29 +850,31 @@ pub(crate) fn send_batch(batch: &FinalizedBatch) -> Result<(), FirewallError> { let mut buffer = vec![0; nft_nlmsg_maxsize() as usize]; let mut expected_seqs = batch.sequence_numbers(); - for message in socket.recv(&mut buffer).map_err(|err| { - FirewallError::NetlinkError(format!("Failed reading message from socket: {err:?}")) - })? { - let Ok(message) = message else { - warn!("Invalid netlink message"); - continue; - }; - let Some(seq) = expected_seqs.next() else { - warn!("Unexpected ACK in netlink messages"); - continue; - }; - match mnl::cb_run(message, seq, portid) { - Ok(mnl::CbResult::Stop) => { - debug!("Received stop signal from netlink callback"); - break; - } - Ok(mnl::CbResult::Ok) => { - debug!("Received OK signal from netlink callback"); - } - Err(err) => { - return Err(FirewallError::NetlinkError(format!( - "There was an error while sending netlink messages: {err:?}" - ))); + while !expected_seqs.is_empty() { + for message in socket.recv(&mut buffer).map_err(|err| { + FirewallError::NetlinkError(format!("Failed reading message from socket: {err:?}")) + })? { + let Ok(message) = message else { + warn!("Invalid netlink message"); + continue; + }; + let Some(seq) = expected_seqs.next() else { + warn!("Unexpected ACK in netlink messages"); + continue; + }; + match mnl::cb_run(message, seq, portid) { + Ok(mnl::CbResult::Stop) => { + debug!("Received stop signal from netlink callback"); + break; + } + Ok(mnl::CbResult::Ok) => { + debug!("Received OK signal from netlink callback"); + } + Err(err) => { + return Err(FirewallError::NetlinkError(format!( + "Failed to receive netlink callback: {err:?}" + ))); + } } } } diff --git a/src/enterprise/firewall/packetfilter/api.rs b/src/enterprise/firewall/packetfilter/api.rs index 0bcf6a66..d5d928e7 100644 --- a/src/enterprise/firewall/packetfilter/api.rs +++ b/src/enterprise/firewall/packetfilter/api.rs @@ -26,10 +26,10 @@ impl FirewallManagementApi for FirewallApi { } /// Add firewall `rules`. - fn add_rules(&mut self, rules: Vec) -> Result<(), FirewallError> { + fn add_rules(&mut self, rules: &[FirewallRule]) -> Result<(), FirewallError> { let anchor = &self.anchor(); // Begin transaction. - debug!("Begin pf transaction."); + debug!("Begin pf transaction"); let mut elements = [IocTransElement::new(RuleSet::Filter, anchor)]; let mut ioc_trans = IocTrans::new(elements.as_mut_slice()); // This will create an anchor if it doesn't exist. @@ -42,8 +42,8 @@ impl FirewallManagementApi for FirewallApi { // Create first rule from the default policy. if let Err(err) = self.add_rule_policy(ticket, pool_ticket, anchor) { - error!("Default policy rule can't be added."); - debug!("Rollback pf transaction."); + error!("Default policy rule can't be added"); + debug!("Rollback pf transaction"); // Rule cannot be added, so rollback. unsafe { pf_rollback(self.fd(), &raw mut ioc_trans)?; @@ -51,10 +51,10 @@ impl FirewallManagementApi for FirewallApi { } } - for mut rule in rules { - if let Err(err) = self.add_rule(&mut rule, ticket, pool_ticket, anchor) { - error!("Firewall rule {} can't be added.", &rule.id); - debug!("Rollback pf transaction."); + for rule in rules { + if let Err(err) = self.add_rule(rule, ticket, pool_ticket, anchor) { + error!("Firewall rule {} can't be added", &rule.id); + debug!("Rollback pf transaction"); // Rule cannot be added, so rollback. unsafe { pf_rollback(self.fd(), &raw mut ioc_trans)?; @@ -64,7 +64,7 @@ impl FirewallManagementApi for FirewallApi { } // Commit transaction. - debug!("Commit pf transaction."); + debug!("Commit pf transaction"); unsafe { pf_commit(self.file.as_raw_fd(), &raw mut ioc_trans).unwrap(); } @@ -80,16 +80,4 @@ impl FirewallManagementApi for FirewallApi { ) -> Result<(), FirewallError> { Ok(()) } - - /// Begin rule transaction. - fn begin(&mut self) -> Result<(), FirewallError> { - // TODO: remove this no-op. - Ok(()) - } - - /// Commit rule transaction. - fn commit(&mut self) -> Result<(), FirewallError> { - // TODO: remove this no-op. - Ok(()) - } } diff --git a/src/enterprise/firewall/packetfilter/calls.rs b/src/enterprise/firewall/packetfilter/calls.rs index 499ce16c..e0c85a41 100644 --- a/src/enterprise/firewall/packetfilter/calls.rs +++ b/src/enterprise/firewall/packetfilter/calls.rs @@ -328,6 +328,19 @@ impl Pool { unsafe { uninit.assume_init() } } + + /// Insert `PoolAddr` at the end of the list. Take ownership of the given `PoolAddr`. + pub(super) fn insert_pool_addr(&mut self, pool_addr: &mut PoolAddr) { + // TODO: Traverse tail queue; for now assume empty tail queue. + assert!( + self.list.tqh_first.is_null(), + "Expected one entry in PoolAddr TailQueue." + ); + self.list.tqh_first = &raw mut *pool_addr; + self.list.tqh_last = &raw mut pool_addr.entries.tqe_next; + pool_addr.entries.tqe_next = ptr::null_mut(); + pool_addr.entries.tqe_prev = &raw mut self.list.tqh_first; + } } impl Drop for Pool { @@ -835,13 +848,13 @@ ioctl_readwrite!(pf_rollback, b'D', 83, IocTrans); #[cfg(test)] mod tests { - use ipnetwork::{Ipv4Network, Ipv6Network}; - use std::{ mem::align_of, net::{Ipv4Addr, Ipv6Addr}, }; + use ipnetwork::{Ipv4Network, Ipv6Network}; + use super::*; #[test] diff --git a/src/enterprise/firewall/packetfilter/mod.rs b/src/enterprise/firewall/packetfilter/mod.rs index ea0787b9..3286c8b7 100644 --- a/src/enterprise/firewall/packetfilter/mod.rs +++ b/src/enterprise/firewall/packetfilter/mod.rs @@ -49,7 +49,7 @@ impl FirewallApi { } fn add_rule_policy( - &mut self, + &self, ticket: u32, pool_ticket: u32, anchor: &str, @@ -68,8 +68,8 @@ impl FirewallApi { /// Add a single firewall `rule`. fn add_rule( - &mut self, - rule: &mut FirewallRule, + &self, + rule: &FirewallRule, ticket: u32, pool_ticket: u32, anchor: &str, diff --git a/src/enterprise/firewall/packetfilter/rule.rs b/src/enterprise/firewall/packetfilter/rule.rs index 26458f18..3404c090 100644 --- a/src/enterprise/firewall/packetfilter/rule.rs +++ b/src/enterprise/firewall/packetfilter/rule.rs @@ -53,7 +53,7 @@ impl fmt::Display for Action { Self::Redirect => "rdr", Self::NoRedirect => "block rdr", }; - write!(f, "{action}") + f.write_str(action) } } @@ -87,7 +87,7 @@ impl fmt::Display for Direction { Self::In => "in", Self::Out => "out", }; - write!(f, "{direction}") + f.write_str(direction) } } @@ -146,7 +146,7 @@ impl fmt::Display for State { Self::Modulate => "modulate state", Self::SynProxy => "synproxy state", }; - write!(f, "{state}") + f.write_str(state) } } @@ -237,7 +237,7 @@ impl PacketFilterRule { } /// Expand `FirewallRule` into a set of `PacketFilterRule`s. - pub(super) fn from_firewall_rule(ifname: &str, fr: &mut FirewallRule) -> Vec { + pub(super) fn from_firewall_rule(ifname: &str, fr: &FirewallRule) -> Vec { let mut rules = Vec::new(); let (action, state) = match fr.verdict { Policy::Allow => (Action::Pass, State::Normal), @@ -276,18 +276,23 @@ impl PacketFilterRule { } } - if fr.destination_ports.is_empty() { - fr.destination_ports.push(Port::Any); - } - - if fr.protocols.is_empty() { - fr.protocols.push(Protocol::Any); - } + // Packet filter needs "any port" when ports are absent. + let destination_ports = if fr.destination_ports.is_empty() { + &[Port::Any] + } else { + fr.destination_ports.as_slice() + }; + // Packet filter needs "any protocol" when protocols are absent. + let protocols = if fr.protocols.is_empty() { + &[Protocol::Any] + } else { + fr.protocols.as_slice() + }; for from in &from_addrs { for to in &to_addrs { - for to_port in &fr.destination_ports { - for proto in &fr.protocols { + for to_port in destination_ports { + for proto in protocols { let rule = Self { from: *from, from_port: Port::Any, @@ -295,7 +300,7 @@ impl PacketFilterRule { to_port: *to_port, action, direction: Direction::In, - // Enable quick to match NFTables behaviour. + // Enable "quick" to match NFTables behaviour. quick: true, log: PF_LOG, state, @@ -322,7 +327,7 @@ impl fmt::Display for PacketFilterRule { write!(f, "{} {}", self.action, self.direction)?; // TODO: log if self.quick { - write!(f, " quick")?; + f.write_str(" quick")?; } if let Some(interface) = &self.interface { write!(f, " on {interface}")?; @@ -331,13 +336,13 @@ impl fmt::Display for PacketFilterRule { if let Some(from) = self.from { write!(f, " {from}")?; } else { - write!(f, " any")?; + f.write_str(" any")?; } write!(f, " {} to", self.from_port)?; if let Some(to) = self.to { write!(f, " {to}")?; } else { - write!(f, " any")?; + f.write_str(" any")?; } // TODO: tcp_flags/tcp_flags_set write!(f, " {} {}", self.to_port, self.state)?; diff --git a/src/error.rs b/src/error.rs index aceb3c4f..c0b9bf5e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,9 +9,6 @@ pub enum GatewayError { #[error("Command {command} execution failed. Error: {error}")] CommandExecutionFailed { command: String, error: String }, - #[error("WireGuard key error")] - KeyDecode(#[from] base64::DecodeError), - #[error("Logger error")] Logger(#[from] log::SetLoggerError), diff --git a/src/gateway.rs b/src/gateway.rs index 78fe18cf..dd876653 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -124,9 +124,6 @@ pub struct Gateway { interface_configuration: Option, peers: HashMap, wgapi: Arc>, - #[cfg_attr(not(target_os = "linux"), allow(unused))] - firewall_api: FirewallApi, - #[cfg_attr(not(target_os = "linux"), allow(unused))] firewall_config: Option, pub connected: Arc, client: GatewayClientType, @@ -138,7 +135,6 @@ impl Gateway { pub fn new( config: Config, wgapi: impl WireguardInterfaceApi + Send + Sync + 'static, - firewall_api: FirewallApi, ) -> Result { let client = Self::setup_client(&config)?; Ok(Self { @@ -149,7 +145,6 @@ impl Gateway { connected: Arc::new(AtomicBool::new(false)), client, stats_thread: None, - firewall_api, firewall_config: None, core_info: None, }) @@ -388,13 +383,10 @@ impl Gateway { "Received firewall configuration is different than current one. \ Reconfiguring firewall..." ); - self.firewall_api.begin()?; - self.firewall_api - .setup(fw_config.default_policy, self.config.fw_priority)?; - self.firewall_api - .setup_nat(self.config.masquerade, &fw_config.snat_bindings)?; - self.firewall_api.add_rules(fw_config.rules.clone())?; - self.firewall_api.commit()?; + let mut firewall_api = FirewallApi::new(&self.config.ifname)?; + firewall_api.setup(fw_config.default_policy, self.config.fw_priority)?; + firewall_api.setup_nat(self.config.masquerade, &fw_config.snat_bindings)?; + firewall_api.add_rules(&fw_config.rules)?; self.firewall_config = Some(fw_config.clone()); info!("Reconfigured firewall with new configuration"); } else { @@ -405,10 +397,9 @@ impl Gateway { } } else { debug!("Received firewall configuration is empty, cleaning up firewall rules..."); - self.firewall_api.begin()?; - self.firewall_api.cleanup()?; - self.firewall_api.setup_nat(self.config.masquerade, &[])?; - self.firewall_api.commit()?; + let mut firewall_api = FirewallApi::new(&self.config.ifname)?; + firewall_api.cleanup()?; + firewall_api.setup_nat(self.config.masquerade, &[])?; self.firewall_config = None; debug!("Cleaned up firewall rules"); } @@ -701,9 +692,8 @@ impl Gateway { } else { #[cfg(target_os = "linux")] if !self.config.disable_firewall_management && self.config.masquerade { - self.firewall_api.begin()?; - self.firewall_api.setup_nat(self.config.masquerade, &[])?; - self.firewall_api.commit()?; + let mut firewall_api = FirewallApi::new(&self.config.ifname)?; + firewall_api.setup_nat(self.config.masquerade, &[])?; } } @@ -791,7 +781,6 @@ mod tests { let wgapi = WG::new("wg0").unwrap(); let config = Config::default(); let client = Gateway::setup_client(&config).unwrap(); - let firewall_api = FirewallApi::new("wg0").unwrap(); let gateway = Gateway { config, interface_configuration: Some(old_config.clone()), @@ -800,7 +789,6 @@ mod tests { connected: Arc::new(AtomicBool::new(false)), client, stats_thread: None, - firewall_api, firewall_config: None, core_info: None, }; @@ -988,7 +976,6 @@ mod tests { connected: Arc::new(AtomicBool::new(false)), client, stats_thread: None, - firewall_api: FirewallApi::new("test_interface").unwrap(), firewall_config: None, core_info: None, }; @@ -1057,7 +1044,6 @@ mod tests { connected: Arc::new(AtomicBool::new(false)), client, stats_thread: None, - firewall_api: FirewallApi::new("test_interface").unwrap(), firewall_config: None, core_info: None, }; diff --git a/src/main.rs b/src/main.rs index 7ad00ac1..7c1cc229 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,8 @@ use std::{fs::File, io::Write, process, sync::Arc}; use defguard_gateway::{ - VERSION, config::get_config, enterprise::firewall::api::FirewallApi, error::GatewayError, - execute_command, gateway::Gateway, init_syslog, server::run_server, + VERSION, config::get_config, error::GatewayError, execute_command, gateway::Gateway, + init_syslog, server::run_server, }; use defguard_version::Version; #[cfg(not(any(target_os = "macos", target_os = "netbsd")))] @@ -40,16 +40,15 @@ async fn main() -> Result<(), GatewayError> { } let ifname = config.ifname.clone(); - let firewall_api = FirewallApi::new(&ifname)?; let mut gateway = if config.userspace { let wgapi = WGApi::::new(ifname)?; - Gateway::new(config.clone(), wgapi, firewall_api)? + Gateway::new(config.clone(), wgapi)? } else { #[cfg(not(any(target_os = "macos", target_os = "netbsd")))] { let wgapi = WGApi::::new(ifname)?; - Gateway::new(config.clone(), wgapi, firewall_api)? + Gateway::new(config.clone(), wgapi)? } #[cfg(any(target_os = "macos", target_os = "netbsd"))] {