Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 158 additions & 3 deletions crates/defguard_core/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
appstate::AppState,
db::{
Group, Id, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User,
models::group::Permission,
models::{group::Permission, oauth2client::OAuth2Client},
},
enterprise::{db::models::api_tokens::ApiToken, is_enterprise_enabled},
error::WebError,
Expand Down Expand Up @@ -303,8 +303,80 @@ macro_rules! role {

role!(AdminRole, Permission::IsAdmin);

#[derive(Debug)]
pub(crate) struct UserClaims {
pub email: Option<String>,
pub family_name: Option<String>,
pub given_name: Option<String>,
pub name: Option<String>,
pub phone_number: Option<String>,
pub preferred_username: Option<String>,
pub sub: String,
}

fn get_available_scopes<'a>(
all_scopes: &'a [String],
requested_scopes: &'a [String],
) -> Vec<&'a str> {
let mut scopes = Vec::new();
for scope in requested_scopes {
if all_scopes.contains(scope) {
scopes.push(scope.as_str());
}
}
scopes
}

impl UserClaims {
pub fn from_user(
user: &User<Id>,
oauth_client: &OAuth2Client<Id>,
oauth_token: &OAuth2Token,
) -> Self {
let token_scopes = oauth_token
.scope
.split_whitespace()
.map(String::from)
.collect::<Vec<String>>();
let scopes = get_available_scopes(&oauth_client.scope, &token_scopes);
Self {
email: if scopes.contains(&"email") {
Some(user.email.clone())
} else {
None
},
family_name: if scopes.contains(&"profile") {
Some(user.last_name.clone())
} else {
None
},
given_name: if scopes.contains(&"profile") {
Some(user.first_name.clone())
} else {
None
},
name: if scopes.contains(&"profile") {
Some(user.name())
} else {
None
},
phone_number: if scopes.contains(&"phone") {
user.phone.clone()
} else {
None
},
preferred_username: if scopes.contains(&"profile") {
Some(user.username.clone())
} else {
None
},
sub: user.username.clone(),
}
}
}

// User authenticated by a valid access token
pub struct AccessUserInfo(pub(crate) User<Id>);
pub struct AccessUserInfo(pub(crate) UserClaims);

impl<S> FromRequestParts<S> for AccessUserInfo
where
Expand Down Expand Up @@ -339,7 +411,22 @@ where
if let Ok(Some(user)) =
User::find_by_id(&appstate.pool, authorized_app.user_id).await
{
return Ok(AccessUserInfo(user));
if let Some(client) = OAuth2Client::find_by_id(
&appstate.pool,
authorized_app.oauth2client_id,
)
.await?
{
return Ok(AccessUserInfo(UserClaims::from_user(
&user,
&client,
&oauth2token,
)));
} else {
return Err(WebError::Authorization(
"OAuth2 client not found".into(),
));
}
}
}
Ok(None) => {
Expand All @@ -363,3 +450,71 @@ where
Err(WebError::Authorization("Invalid session".into()))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_get_available_scopes() {
// All requested scopes are available
let all_scopes = vec![
"email".to_string(),
"profile".to_string(),
"phone".to_string(),
];
let requested_scopes = vec!["email".to_string(), "profile".to_string()];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, vec!["email", "profile"]);

// Some requested scopes are not available
let all_scopes = vec!["email".to_string(), "profile".to_string()];
let requested_scopes = vec![
"email".to_string(),
"phone".to_string(),
"profile".to_string(),
];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, vec!["email", "profile"]);

// No requested scopes
let all_scopes = vec!["email".to_string(), "profile".to_string()];
let requested_scopes = vec![];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, Vec::<&str>::new());

// No available scopes
let all_scopes = vec![];
let requested_scopes = vec!["email".to_string(), "profile".to_string()];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, Vec::<&str>::new());

// Both empty
let all_scopes = vec![];
let requested_scopes = vec![];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, Vec::<&str>::new());

// Duplicate requested scopes
let all_scopes = vec!["email".to_string(), "profile".to_string()];
let requested_scopes = vec![
"email".to_string(),
"email".to_string(),
"profile".to_string(),
];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, vec!["email", "email", "profile"]);

// Case sensitivity
let all_scopes = vec!["email".to_string(), "profile".to_string()];
let requested_scopes = vec!["Email".to_string(), "PROFILE".to_string()];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, Vec::<&str>::new());

// Single scope match
let all_scopes = vec!["email".to_string()];
let requested_scopes = vec!["email".to_string()];
let result = get_available_scopes(&all_scopes, &requested_scopes);
assert_eq!(result, vec!["email"]);
}
}
61 changes: 38 additions & 23 deletions crates/defguard_core/src/handlers/openid_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use time::Duration;
use super::{ApiResponse, ApiResult, SESSION_COOKIE_NAME};
use crate::{
appstate::AppState,
auth::{AccessUserInfo, SessionInfo},
auth::{AccessUserInfo, SessionInfo, UserClaims},
db::{
Id, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User,
models::{auth_code::AuthCode, oauth2client::OAuth2Client},
Expand All @@ -52,27 +52,41 @@ use crate::{
};

/// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
impl From<&User<Id>> for StandardClaims<CoreGenderClaim> {
fn from(user: &User<Id>) -> StandardClaims<CoreGenderClaim> {
let mut name = LocalizedClaim::new();
name.insert(None, EndUserName::new(user.name()));
let mut given_name = LocalizedClaim::new();
given_name.insert(None, EndUserGivenName::new(user.first_name.clone()));
let mut given_name = LocalizedClaim::new();
given_name.insert(None, EndUserGivenName::new(user.first_name.clone()));
let mut family_name = LocalizedClaim::new();
family_name.insert(None, EndUserFamilyName::new(user.last_name.clone()));
let email = EndUserEmail::new(user.email.clone());
let phone_number = user.phone.clone().map(EndUserPhoneNumber::new);
let preferred_username = EndUserUsername::new(user.username.clone());

StandardClaims::new(SubjectIdentifier::new(user.username.clone()))
.set_name(Some(name))
.set_given_name(Some(given_name))
.set_family_name(Some(family_name))
.set_email(Some(email))
.set_phone_number(phone_number)
.set_preferred_username(Some(preferred_username))
impl From<&UserClaims> for StandardClaims<CoreGenderClaim> {
fn from(user_claims: &UserClaims) -> StandardClaims<CoreGenderClaim> {
let mut claims = StandardClaims::new(SubjectIdentifier::new(user_claims.sub.clone()));

if let Some(name) = &user_claims.name {
let mut localized_claim = LocalizedClaim::new();
localized_claim.insert(None, EndUserName::new(name.clone()));
claims = claims.set_name(Some(localized_claim));
}

if let Some(given_name) = &user_claims.given_name {
let mut localized_claim = LocalizedClaim::new();
localized_claim.insert(None, EndUserGivenName::new(given_name.clone()));
claims = claims.set_given_name(Some(localized_claim));
}

if let Some(family_name) = &user_claims.family_name {
let mut localized_claim = LocalizedClaim::new();
localized_claim.insert(None, EndUserFamilyName::new(family_name.clone()));
claims = claims.set_family_name(Some(localized_claim));
}

if let Some(email) = &user_claims.email {
claims = claims.set_email(Some(EndUserEmail::new(email.clone())));
}

if let Some(phone_number) = &user_claims.phone_number {
claims = claims.set_phone_number(Some(EndUserPhoneNumber::new(phone_number.clone())));
}

if let Some(username) = &user_claims.preferred_username {
claims = claims.set_preferred_username(Some(EndUserUsername::new(username.clone())));
}

claims
}
}

Expand Down Expand Up @@ -830,10 +844,11 @@ pub async fn token(
GroupClaims { groups: None }
};
let config = server_config();
let user_claims = UserClaims::from_user(&user, &client, &token);
match form.authorization_code_flow(
&auth_code,
&token,
(&user).into(),
(&user_claims).into(),
&config.url,
client.client_secret,
config.openid_key(),
Expand Down
Loading