Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 18 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 42 additions & 21 deletions crates/defguard_core/src/db/models/auth_code.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
use chrono::Utc;
use model_derive::Model;
use sqlx::{Error as SqlxError, PgPool, query_as};
use sqlx::{PgExecutor, query_as};

use crate::{
db::{Id, NoId},
random::gen_alphanumeric,
};

#[derive(Model, Clone)]
#[derive(Model)]
#[table(authorization_code)]
pub struct AuthCode<I = NoId> {
pub(crate) struct AuthCode<I = NoId> {
#[allow(dead_code)]
id: I,
pub user_id: Id,
pub client_id: String,
pub code: String,
pub redirect_uri: String,
pub scope: String,
pub auth_time: i64,
pub nonce: Option<String>,
pub code_challenge: Option<String>,
pub(crate) user_id: Id,
pub(crate) client_id: String,
pub(crate) code: String,
pub(crate) redirect_uri: String,
pub(crate) scope: String,
pub(crate) auth_time: i64,
pub(crate) nonce: Option<String>,
pub(crate) code_challenge: Option<String>,
}

impl AuthCode {
#[must_use]
pub fn new(
pub(crate) fn new(
user_id: Id,
client_id: String,
redirect_uri: String,
Expand All @@ -46,21 +47,41 @@ impl AuthCode {
}
}

impl From<AuthCode<Id>> for AuthCode<NoId> {
fn from(value: AuthCode<Id>) -> Self {
Self {
id: NoId,
user_id: value.user_id,
client_id: value.client_id,
code: value.code,
redirect_uri: value.redirect_uri,
scope: value.scope,
auth_time: value.auth_time,
nonce: value.nonce,
code_challenge: value.code_challenge,
}
}
}

impl AuthCode<Id> {
/// Find by code.
pub async fn find_code(pool: &PgPool, code: &str) -> Result<Option<Self>, SqlxError> {
/// If found, delete `AuthCode` from the database right away, so it can't be reused.
pub(crate) async fn find_code<'e, E>(
executor: E,
code: &str,
) -> Result<Option<AuthCode<NoId>>, sqlx::Error>
where
E: PgExecutor<'e>,
{
query_as!(
Self,
"SELECT id, user_id, client_id, code, redirect_uri, scope, auth_time, nonce, \
code_challenge FROM authorization_code WHERE code = $1",
"DELETE FROM authorization_code WHERE code = $1 \
RETURNING id, user_id, client_id, code, redirect_uri, scope, auth_time, nonce, \
code_challenge",
code
)
.fetch_optional(pool)
.fetch_optional(executor)
.await
}

// Remove a used authorization_code
pub async fn consume(self, pool: &PgPool) -> Result<(), SqlxError> {
self.delete(pool).await
.map(|inner_option| inner_option.map(Into::into))
}
}
48 changes: 33 additions & 15 deletions crates/defguard_core/src/db/models/oauth2client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use model_derive::Model;
use reqwest::Url;
use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as};

use super::NewOpenIDClient;
Expand Down Expand Up @@ -103,23 +102,17 @@ impl OAuth2Client<Id> {
.await
}

/// Checks if `url` matches client config (ignoring trailing slashes)
/// Checks if `url` matches client config (ignoring trailing slashes).
pub(crate) fn contains_redirect_url(&self, url: &str) -> bool {
let parsed_redirect_uris: Vec<String> = self
.redirect_uri
.iter()
.map(|uri| uri.trim_end_matches('/').into())
.collect();
let url_trimmed = url.trim_end_matches('/');

// extract origin from url
let Ok(url) = Url::parse(url) else {
return false;
};
let url = url.origin().ascii_serialization();
for redirect in &self.redirect_uri {
if url_trimmed == redirect.trim_end_matches('/') {
return true;
}
}

!url.split(' ')
.map(|uri| uri.trim_end_matches('/'))
.all(|uri| !parsed_redirect_uris.iter().any(|u| u == uri))
false
}
}

Expand All @@ -140,3 +133,28 @@ impl From<OAuth2Client<Id>> for OAuth2ClientSafe {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_contains_redirect_url() {
let oauth2client = OAuth2Client {
id: 1,
client_id: String::new(),
client_secret: String::new(),
redirect_uri: vec![
String::from("http://localhost/"),
String::from("http://safe.net/"),
],
scope: Vec::new(),
name: String::new(),
enabled: true,
};
assert!(oauth2client.contains_redirect_url("http://safe.net"));
assert!(oauth2client.contains_redirect_url("http://localhost"));
assert!(!oauth2client.contains_redirect_url("http://safe.net/api"));
assert!(!oauth2client.contains_redirect_url("http://nonexistent:8000"));
}
}
Loading