diff --git a/crates/defguard_core/src/auth/mod.rs b/crates/defguard_core/src/auth/mod.rs index bca0b2dbc3..b0fd990e20 100644 --- a/crates/defguard_core/src/auth/mod.rs +++ b/crates/defguard_core/src/auth/mod.rs @@ -24,7 +24,7 @@ use crate::{ appstate::AppState, db::{ Group, Id, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User, - models::group::Permission, + models::{group::Permission, oauth2client::OAuth2Client}, }, enterprise::{db::models::api_tokens::ApiToken, is_enterprise_enabled}, error::WebError, @@ -303,8 +303,80 @@ macro_rules! role { role!(AdminRole, Permission::IsAdmin); +#[derive(Debug)] +pub(crate) struct UserClaims { + pub email: Option, + pub family_name: Option, + pub given_name: Option, + pub name: Option, + pub phone_number: Option, + pub preferred_username: Option, + pub sub: String, +} + +fn get_available_scopes<'a>( + all_scopes: &'a [String], + requested_scopes: &'a [String], +) -> Vec<&'a str> { + let mut scopes = Vec::new(); + for scope in requested_scopes { + if all_scopes.contains(scope) { + scopes.push(scope.as_str()); + } + } + scopes +} + +impl UserClaims { + pub fn from_user( + user: &User, + oauth_client: &OAuth2Client, + oauth_token: &OAuth2Token, + ) -> Self { + let token_scopes = oauth_token + .scope + .split_whitespace() + .map(String::from) + .collect::>(); + let scopes = get_available_scopes(&oauth_client.scope, &token_scopes); + Self { + email: if scopes.contains(&"email") { + Some(user.email.clone()) + } else { + None + }, + family_name: if scopes.contains(&"profile") { + Some(user.last_name.clone()) + } else { + None + }, + given_name: if scopes.contains(&"profile") { + Some(user.first_name.clone()) + } else { + None + }, + name: if scopes.contains(&"profile") { + Some(user.name()) + } else { + None + }, + phone_number: if scopes.contains(&"phone") { + user.phone.clone() + } else { + None + }, + preferred_username: if scopes.contains(&"profile") { + Some(user.username.clone()) + } else { + None + }, + sub: user.username.clone(), + } + } +} + // User authenticated by a valid access token -pub struct AccessUserInfo(pub(crate) User); +pub struct AccessUserInfo(pub(crate) UserClaims); impl FromRequestParts for AccessUserInfo where @@ -339,7 +411,22 @@ where if let Ok(Some(user)) = User::find_by_id(&appstate.pool, authorized_app.user_id).await { - return Ok(AccessUserInfo(user)); + if let Some(client) = OAuth2Client::find_by_id( + &appstate.pool, + authorized_app.oauth2client_id, + ) + .await? + { + return Ok(AccessUserInfo(UserClaims::from_user( + &user, + &client, + &oauth2token, + ))); + } else { + return Err(WebError::Authorization( + "OAuth2 client not found".into(), + )); + } } } Ok(None) => { @@ -363,3 +450,71 @@ where Err(WebError::Authorization("Invalid session".into())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_available_scopes() { + // All requested scopes are available + let all_scopes = vec![ + "email".to_string(), + "profile".to_string(), + "phone".to_string(), + ]; + let requested_scopes = vec!["email".to_string(), "profile".to_string()]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, vec!["email", "profile"]); + + // Some requested scopes are not available + let all_scopes = vec!["email".to_string(), "profile".to_string()]; + let requested_scopes = vec![ + "email".to_string(), + "phone".to_string(), + "profile".to_string(), + ]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, vec!["email", "profile"]); + + // No requested scopes + let all_scopes = vec!["email".to_string(), "profile".to_string()]; + let requested_scopes = vec![]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, Vec::<&str>::new()); + + // No available scopes + let all_scopes = vec![]; + let requested_scopes = vec!["email".to_string(), "profile".to_string()]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, Vec::<&str>::new()); + + // Both empty + let all_scopes = vec![]; + let requested_scopes = vec![]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, Vec::<&str>::new()); + + // Duplicate requested scopes + let all_scopes = vec!["email".to_string(), "profile".to_string()]; + let requested_scopes = vec![ + "email".to_string(), + "email".to_string(), + "profile".to_string(), + ]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, vec!["email", "email", "profile"]); + + // Case sensitivity + let all_scopes = vec!["email".to_string(), "profile".to_string()]; + let requested_scopes = vec!["Email".to_string(), "PROFILE".to_string()]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, Vec::<&str>::new()); + + // Single scope match + let all_scopes = vec!["email".to_string()]; + let requested_scopes = vec!["email".to_string()]; + let result = get_available_scopes(&all_scopes, &requested_scopes); + assert_eq!(result, vec!["email"]); + } +} diff --git a/crates/defguard_core/src/handlers/openid_flow.rs b/crates/defguard_core/src/handlers/openid_flow.rs index eb6d0527f5..dce8b6d569 100644 --- a/crates/defguard_core/src/handlers/openid_flow.rs +++ b/crates/defguard_core/src/handlers/openid_flow.rs @@ -41,7 +41,7 @@ use time::Duration; use super::{ApiResponse, ApiResult, SESSION_COOKIE_NAME}; use crate::{ appstate::AppState, - auth::{AccessUserInfo, SessionInfo}, + auth::{AccessUserInfo, SessionInfo, UserClaims}, db::{ Id, OAuth2AuthorizedApp, OAuth2Token, Session, SessionState, User, models::{auth_code::AuthCode, oauth2client::OAuth2Client}, @@ -52,27 +52,41 @@ use crate::{ }; /// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims -impl From<&User> for StandardClaims { - fn from(user: &User) -> StandardClaims { - let mut name = LocalizedClaim::new(); - name.insert(None, EndUserName::new(user.name())); - let mut given_name = LocalizedClaim::new(); - given_name.insert(None, EndUserGivenName::new(user.first_name.clone())); - let mut given_name = LocalizedClaim::new(); - given_name.insert(None, EndUserGivenName::new(user.first_name.clone())); - let mut family_name = LocalizedClaim::new(); - family_name.insert(None, EndUserFamilyName::new(user.last_name.clone())); - let email = EndUserEmail::new(user.email.clone()); - let phone_number = user.phone.clone().map(EndUserPhoneNumber::new); - let preferred_username = EndUserUsername::new(user.username.clone()); - - StandardClaims::new(SubjectIdentifier::new(user.username.clone())) - .set_name(Some(name)) - .set_given_name(Some(given_name)) - .set_family_name(Some(family_name)) - .set_email(Some(email)) - .set_phone_number(phone_number) - .set_preferred_username(Some(preferred_username)) +impl From<&UserClaims> for StandardClaims { + fn from(user_claims: &UserClaims) -> StandardClaims { + let mut claims = StandardClaims::new(SubjectIdentifier::new(user_claims.sub.clone())); + + if let Some(name) = &user_claims.name { + let mut localized_claim = LocalizedClaim::new(); + localized_claim.insert(None, EndUserName::new(name.clone())); + claims = claims.set_name(Some(localized_claim)); + } + + if let Some(given_name) = &user_claims.given_name { + let mut localized_claim = LocalizedClaim::new(); + localized_claim.insert(None, EndUserGivenName::new(given_name.clone())); + claims = claims.set_given_name(Some(localized_claim)); + } + + if let Some(family_name) = &user_claims.family_name { + let mut localized_claim = LocalizedClaim::new(); + localized_claim.insert(None, EndUserFamilyName::new(family_name.clone())); + claims = claims.set_family_name(Some(localized_claim)); + } + + if let Some(email) = &user_claims.email { + claims = claims.set_email(Some(EndUserEmail::new(email.clone()))); + } + + if let Some(phone_number) = &user_claims.phone_number { + claims = claims.set_phone_number(Some(EndUserPhoneNumber::new(phone_number.clone()))); + } + + if let Some(username) = &user_claims.preferred_username { + claims = claims.set_preferred_username(Some(EndUserUsername::new(username.clone()))); + } + + claims } } @@ -830,10 +844,11 @@ pub async fn token( GroupClaims { groups: None } }; let config = server_config(); + let user_claims = UserClaims::from_user(&user, &client, &token); match form.authorization_code_flow( &auth_code, &token, - (&user).into(), + (&user_claims).into(), &config.url, client.client_secret, config.openid_key(), diff --git a/crates/defguard_core/tests/integration/api/openid.rs b/crates/defguard_core/tests/integration/api/openid.rs index 3a2b141d5a..4017100d92 100644 --- a/crates/defguard_core/tests/integration/api/openid.rs +++ b/crates/defguard_core/tests/integration/api/openid.rs @@ -4,7 +4,7 @@ use axum::http::header::ToStrError; use claims::assert_err; use defguard_core::{ db::{ - Id, + Id, User, models::{NewOpenIDClient, oauth2client::OAuth2Client}, }, handlers::Auth, @@ -610,6 +610,186 @@ async fn test_openid_authorization_code_with_pkce(_: PgPoolOptions, options: PgC .unwrap(); } +#[sqlx::test] +async fn dg25_22_test_respect_openid_scope_in_userinfo( + _: PgPoolOptions, + options: PgConnectOptions, +) { + let pool = setup_pool(options).await; + + let (client, state) = make_client_with_state(pool).await; + let mut config = state.config; + + let mut admin = User::find_by_username(&state.pool, "admin") + .await + .unwrap() + .unwrap(); + + admin.phone = Some("+123456789".into()); + admin.save(&state.pool).await.unwrap(); + + let mut rng = rand::thread_rng(); + config.openid_signing_key = RsaPrivateKey::new(&mut rng, 2048).ok(); + + let issuer_url = IssuerUrl::from_url(config.url.clone()); + + // discover OpenID service + let provider_metadata = + CoreProviderMetadata::discover_async(issuer_url, &|r| http_client(r, &client)) + .await + .unwrap(); + + // Create reusable closure for testing different scope configurations + let get_user_claims = |client_scopes: Vec, requested_scopes: Vec| { + let client = &client; + let provider_metadata = provider_metadata.clone(); + async move { + // Authenticate admin + let auth = Auth::new("admin", "pass123"); + let response = client.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + + // Create OAuth2 client with specified scopes + let oauth2client = NewOpenIDClient { + name: "Test client".into(), + redirect_uri: vec![FAKE_REDIRECT_URI.into()], + scope: client_scopes, + enabled: true, + }; + let response = client + .post("/api/v1/oauth") + .json(&oauth2client) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let oauth2client: OAuth2Client = response.json().await; + + // Store client_id for cleanup + let client_id_for_cleanup = oauth2client.client_id.clone(); + + // Create OpenID client + let client_id = ClientId::new(oauth2client.client_id); + let client_secret = ClientSecret::new(oauth2client.client_secret); + let core_client = CoreClient::from_provider_metadata( + provider_metadata, + client_id, + Some(client_secret), + ) + .set_redirect_uri(RedirectUrl::new(FAKE_REDIRECT_URI.into()).unwrap()); + + // Start Authorization Code Flow with PKCE + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let mut auth_request = core_client.authorize_url( + AuthenticationFlow::::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ); + + // Add requested scopes + for scope in requested_scopes { + auth_request = auth_request.add_scope(Scope::new(scope)); + } + + let (authorize_url, _csrf_state, nonce) = + auth_request.set_pkce_challenge(pkce_challenge).url(); + + // Obtain authorization code + let uri = format!( + "{}?allow=true&{}", + authorize_url.path(), + authorize_url.query().unwrap() + ); + let response = client.post(uri).send().await; + assert_eq!(response.status(), StatusCode::FOUND); + let location = response + .headers() + .get("Location") + .unwrap() + .to_str() + .unwrap(); + let (location, query) = location.split_once('?').unwrap(); + assert_eq!(location, FAKE_REDIRECT_URI); + let auth_response: AuthenticationResponse = serde_qs::from_str(query).unwrap(); + + // Exchange authorization code for token + let token_response = core_client + .exchange_code(AuthorizationCode::new(auth_response.code.into())) + .unwrap() + .set_pkce_verifier(pkce_verifier) + .request_async(&|r| http_client(r, client)) + .await + .unwrap(); + + // Verify id token + let id_token_verifier = core_client.id_token_verifier(); + let _id_token_claims = token_response + .extra_fields() + .id_token() + .expect("Server did not return an ID token") + .claims(&id_token_verifier, &nonce) + .unwrap(); + + // Get userinfo claims + let userinfo_claims: UserInfoClaims = + core_client + .user_info(token_response.access_token().clone(), None) + .expect("Missing info endpoint") + .request_async(&|r| http_client(r, client)) + .await + .unwrap(); + + // Clean up - delete the OAuth client + client + .delete(format!("/api/v1/oauth/{}", client_id_for_cleanup)) + .send() + .await; + + userinfo_claims + } + }; + + // Client has phone and email scopes, request phone and email + let claims = get_user_claims( + vec![ + "openid".to_string(), + "phone".to_string(), + "email".to_string(), + ], + vec!["email".to_string(), "phone".to_string()], + ) + .await; + + // Verify claims include both email and phone + assert!(claims.email().is_some()); + assert!(claims.phone_number().is_some()); + + // Client has phone and email scopes, but only request email + let claims = get_user_claims( + vec![ + "openid".to_string(), + "phone".to_string(), + "email".to_string(), + ], + vec!["email".to_string()], + ) + .await; + + // Verify claims include only email, not phone + assert!(claims.email().is_some()); + assert!(claims.phone_number().is_none()); + + // Client has only email scope, request phone + let claims = get_user_claims( + vec!["openid".to_string(), "email".to_string()], + vec!["email".to_string(), "phone".to_string()], + ) + .await; + + // Verify claims include only email since client doesn't have phone scope + assert!(claims.email().is_some()); + assert!(claims.phone_number().is_none()); +} + #[sqlx::test] async fn test_openid_flow_new_login_mail(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await;