Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

103 changes: 70 additions & 33 deletions crates/defguard_core/src/enterprise/db/models/acl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use defguard_common::db::{
wireguard::{LocationMfaMode, ServiceLocationMode},
},
};
use ipnetwork::{IpNetwork, IpNetworkError};
use ipnetwork::IpNetwork;
use model_derive::Model;
use sqlx::{
Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, error::ErrorKind,
Expand Down Expand Up @@ -563,6 +563,40 @@ pub(crate) struct ParsedDestination {
pub(crate) ranges: Vec<(IpAddr, IpAddr)>,
}

fn invalid_destination_range(range: &str) -> AclError {
error!("Failed to parse destination range token: \"{range}\"");
AclError::InvalidIpRangeError(range.to_string())
}

fn parse_destination_range(range: &str) -> Result<(IpAddr, IpAddr), AclError> {
let Some((start, end)) = range.split_once('-') else {
return Err(invalid_destination_range(range));
};

if start.is_empty() || end.is_empty() || end.contains('-') {
return Err(invalid_destination_range(range));
}

if start.contains('/') || end.contains('/') {
return Err(invalid_destination_range(range));
}

let start = start.parse::<IpAddr>()?;
let end = end.parse::<IpAddr>()?;

let is_non_increasing = match (&start, &end) {
(IpAddr::V4(start), IpAddr::V4(end)) => start.octets() >= end.octets(),
(IpAddr::V6(start), IpAddr::V6(end)) => start.octets() >= end.octets(),
_ => return Err(invalid_destination_range(range)),
};

if is_non_increasing {
return Err(invalid_destination_range(range));
}

Ok((start, end))
}

/// Perses a destination string into singular ip addresses or networks and address
/// ranges. We should be able to parse a string like this one:
/// `10.0.0.1/24, 10.1.1.10-10.1.1.20, 192.168.1.10, 10.1.1.1-10.10.1.1`
Expand All @@ -573,16 +607,11 @@ pub(crate) fn parse_destination_addresses(
let destination: String = destination.chars().filter(|c| !c.is_whitespace()).collect();
let mut result = ParsedDestination::default();
if !destination.is_empty() {
for v in destination.split(',') {
match v.split('-').collect::<Vec<_>>() {
l if l.len() == 1 => result.addrs.push(l[0].parse::<IpNetwork>()?),
l if l.len() == 2 => result
.ranges
.push((l[0].parse::<IpAddr>()?, l[1].parse::<IpAddr>()?)),
_ => {
error!("Failed to parse destination string: \"{destination}\"");
Err(IpNetworkError::InvalidAddr(destination.clone()))?;
}
for token in destination.split(',') {
if token.contains('-') {
result.ranges.push(parse_destination_range(token)?);
} else {
result.addrs.push(token.parse::<IpNetwork>()?);
}
}
}
Expand All @@ -594,6 +623,34 @@ pub(crate) fn parse_destination_addresses(
/// Parses ports string into singular ports and port ranges
/// We should be able to parse a string like this one:
/// `22, 23, 8000-9000, 80-90`
fn invalid_ports_format(ports: &str) -> AclError {
error!("Failed to parse ports string: \"{ports}\"");
AclError::InvalidPortsFormat(ports.to_string())
}

fn parse_port_token(port_token: &str, ports: &str) -> Result<PortRange, AclError> {
if port_token.is_empty() {
return Err(invalid_ports_format(ports));
}

let Some((start, end)) = port_token.split_once('-') else {
let port = port_token.parse::<u16>()?;
return Ok(PortRange::new(port, port));
};

if start.is_empty() || end.is_empty() || end.contains('-') {
return Err(invalid_ports_format(ports));
}

let start = start.parse::<u16>()?;
let end = end.parse::<u16>()?;
if start >= end {
return Err(invalid_ports_format(ports));
}

Ok(PortRange::new(start, end))
}

pub fn parse_ports(ports: &str) -> Result<Vec<PortRange>, AclError> {
debug!("Parsing ports string: {ports}");
let mut result = Vec::new();
Expand All @@ -602,22 +659,8 @@ pub fn parse_ports(ports: &str) -> Result<Vec<PortRange>, AclError> {
.filter(|c| !c.is_whitespace())
.collect::<String>();
if !ports.is_empty() {
for v in ports.split(',') {
match v.split('-').collect::<Vec<_>>() {
l if l.len() == 1 => {
let start = l[0].parse::<u16>()?;
result.push(PortRange::new(start, start));
}
l if l.len() == 2 => {
let start = l[0].parse::<u16>()?;
let end = l[1].parse::<u16>()?;
result.push(PortRange::new(start, end));
}
_ => {
error!("Failed to parse ports string: \"{ports}\"");
return Err(AclError::InvalidPortsFormat(ports.clone()));
}
}
for port_token in ports.split(',') {
result.push(parse_port_token(port_token, &ports)?);
}
}

Expand Down Expand Up @@ -745,12 +788,6 @@ impl AclRule<Id> {
let destination = parse_destination_addresses(&api_rule.addresses)?;
debug!("Creating related destination ranges for ACL rule {rule_id}");
for range in destination.ranges {
if range.1 <= range.0 {
return Err(AclError::InvalidIpRangeError(format!(
"{}-{}",
range.0, range.1
)));
}
let obj = AclRuleDestinationRange {
id: NoId,
rule_id,
Expand Down
115 changes: 107 additions & 8 deletions crates/defguard_core/src/enterprise/db/models/acl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ async fn test_allow_conflicting_sources(_: PgPoolOptions, options: PgConnectOpti

// create the rule
let rule = AclRule {
id: NoId,
name: "rule".to_string(),
enabled: true,
allow_all_users: false,
Expand Down Expand Up @@ -127,7 +126,6 @@ async fn test_rule_relations(_: PgPoolOptions, options: PgConnectOptions) {

// create the rule
let mut rule = AclRule {
id: NoId,
name: "rule".to_string(),
enabled: true,
allow_all_users: false,
Expand Down Expand Up @@ -437,15 +435,13 @@ async fn test_all_allowed_users(_: PgPoolOptions, options: PgConnectOptions) {

// Create test groups
let group_1 = Group {
id: NoId,
name: "group_1".into(),
..Default::default()
}
.save(&pool)
.await
.unwrap();
let group_2 = Group {
id: NoId,
name: "group_2".into(),
..Default::default()
}
Expand Down Expand Up @@ -476,7 +472,6 @@ async fn test_all_allowed_users(_: PgPoolOptions, options: PgConnectOptions) {

// Create ACL rule
let rule = AclRule {
id: NoId,
name: "test_rule".to_string(),
allow_all_users: false,
deny_all_users: false,
Expand Down Expand Up @@ -552,15 +547,13 @@ async fn test_all_denied_users(_: PgPoolOptions, options: PgConnectOptions) {

// Create test groups
let group_1 = Group {
id: NoId,
name: "group_1".into(),
..Default::default()
}
.save(&pool)
.await
.unwrap();
let group_2 = Group {
id: NoId,
name: "group_2".into(),
..Default::default()
}
Expand Down Expand Up @@ -591,7 +584,6 @@ async fn test_all_denied_users(_: PgPoolOptions, options: PgConnectOptions) {

// Create ACL rule
let rule = AclRule {
id: NoId,
name: "test_rule".to_string(),
allow_all_users: false,
deny_all_users: false,
Expand Down Expand Up @@ -656,3 +648,110 @@ async fn test_all_denied_users(_: PgPoolOptions, options: PgConnectOptions) {
assert!(denied_users.iter().any(|u| u.id == user_3.id));
assert!(!denied_users.iter().any(|u| u.id == user_4.id));
}

#[test]
fn test_parse_ports_rejects_non_increasing_ranges() {
assert!(matches!(
parse_ports("200-100"),
Err(AclError::InvalidPortsFormat(input)) if input == "200-100"
));
assert!(matches!(
parse_ports("100-100"),
Err(AclError::InvalidPortsFormat(input)) if input == "100-100"
));
}

#[test]
fn test_parse_ports_normalizes_whitespace_before_splitting() {
let parsed = parse_ports("10 - 20, 30, 1 2").unwrap();
let parsed = parsed
.into_iter()
.map(|range| (range.first_port(), range.last_port()))
.collect::<Vec<_>>();

assert_eq!(parsed, vec![(10, 20), (30, 30), (12, 12)]);
}

#[test]
fn test_parse_ports_allows_duplicate_endpoints() {
let parsed = parse_ports("10,10,10-20,20").unwrap();
let parsed = parsed
.into_iter()
.map(|range| (range.first_port(), range.last_port()))
.collect::<Vec<_>>();

assert_eq!(parsed, vec![(10, 10), (10, 10), (10, 20), (20, 20)]);
}

#[test]
fn test_parse_ports_rejects_malformed_range_tokens() {
assert!(matches!(
parse_ports("1-2-3"),
Err(AclError::InvalidPortsFormat(input)) if input == "1-2-3"
));
}

#[test]
fn test_parse_destination_addresses_allows_empty_and_strips_whitespace() {
let parsed = parse_destination_addresses(" \n\t ").unwrap();

assert!(parsed.addrs.is_empty());
assert!(parsed.ranges.is_empty());
}

#[test]
fn test_parse_destination_addresses_accepts_single_ips_cidrs_and_ranges() {
let parsed =
parse_destination_addresses(" 10.0.0.1 , 10.0.0.0/24 , 2001:db8::1-2001:db8::2 ").unwrap();

assert_eq!(
parsed.addrs,
vec![
"10.0.0.1".parse::<IpNetwork>().unwrap(),
"10.0.0.0/24".parse::<IpNetwork>().unwrap(),
]
);
assert_eq!(
parsed.ranges,
vec![(
"2001:db8::1".parse::<IpAddr>().unwrap(),
"2001:db8::2".parse::<IpAddr>().unwrap(),
)]
);
}

#[test]
fn test_parse_destination_addresses_rejects_invalid_ranges() {
for input in [
"10.0.0.2-10.0.0.1",
"10.0.0.1-10.0.0.1",
"10.0.0.1-2001:db8::1",
"10.0.0.1-10.0.0.2-10.0.0.3",
"10.0.0.0/24-10.0.0.2",
] {
assert!(matches!(
parse_destination_addresses(input),
Err(AclError::InvalidIpRangeError(range)) if range == input
));
}
}

#[test]
fn test_parse_destination_addresses_rejects_multi_slash_cidr_tokens() {
for input in ["10.0.0.1/24/25", "2001:db8::1/64/65"] {
assert!(matches!(
parse_destination_addresses(input),
Err(AclError::IpNetworkError(_))
));
}
}

#[test]
fn test_parse_destination_addresses_rejects_malformed_cidr_prefix_tokens() {
for input in ["10.0.0.1/1e1", "10.0.0.1/0x18", "2001:db8::1/64foo"] {
assert!(matches!(
parse_destination_addresses(input),
Err(AclError::IpNetworkError(_))
));
}
}
Loading
Loading