diff --git a/src/api/v3/auth/token/mod.rs b/src/api/v3/auth/token/mod.rs index a0d7bf25..a2e11a1f 100644 --- a/src/api/v3/auth/token/mod.rs +++ b/src/api/v3/auth/token/mod.rs @@ -210,7 +210,7 @@ async fn post( tag="auth" )] #[tracing::instrument( - name = "api::token_get", + name = "api::v3::token_get", level = "debug", skip(state, headers, user_auth, policy) )] diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index 40b1bf9b..f529bf74 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -169,7 +169,7 @@ async fn groups( let groups: Vec = state .provider .get_identity_provider() - .list_groups_for_user(&state.db, &user_id) + .list_groups_of_user(&state.db, &user_id) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -477,7 +477,7 @@ mod tests { async fn test_groups() { let mut identity_mock = MockIdentityProvider::default(); identity_mock - .expect_list_groups_for_user() + .expect_list_groups_of_user() .withf(|_: &DatabaseConnection, uid: &str| uid == "foo") .returning(|_, _| { Ok(vec![Group { diff --git a/src/api/v4/auth/passkey/finish.rs b/src/api/v4/auth/passkey/finish.rs index ba69ab97..bfb898bd 100644 --- a/src/api/v4/auth/passkey/finish.rs +++ b/src/api/v4/auth/passkey/finish.rs @@ -99,7 +99,8 @@ pub(super) async fn finish( }) })??, ) - .methods(vec!["passkey".into()]) + // Unless Keystone support passkey auth method we use x509 (which it technically IS). + .methods(vec!["x509".into()]) .build() .map_err(AuthenticationError::from)?; authed_info.validate()?; diff --git a/src/api/v4/federation/common.rs b/src/api/v4/federation/common.rs new file mode 100644 index 00000000..c644b387 --- /dev/null +++ b/src/api/v4/federation/common.rs @@ -0,0 +1,213 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use serde_json::Value; + +use openidconnect::IdTokenClaims; +use openidconnect::core::CoreGenderClaim; + +use crate::api::common::{find_project_from_scope, get_domain}; +use crate::api::error::KeystoneApiError; +use crate::api::v4::federation::error::OidcError; +use crate::api::v4::federation::types::{AllOtherClaims, MappedUserData, MappedUserDataBuilder}; +use crate::auth::AuthzInfo; +use crate::federation::types::{ + Scope as ProviderScope, identity_provider::IdentityProvider as ProviderIdentityProvider, + mapping::Mapping as ProviderMapping, +}; +use crate::identity::IdentityApi; +use crate::keystone::ServiceState; + +/// Convert ProviderScope to AuthZ information +/// +/// # Arguments +/// * `state`: The service state +/// * `scope`: The scope to extract the AuthZ information from +/// +/// # Returns +/// * `Ok(AuthzInfo)`: The AuthZ information +/// * `Err(KeystoneApiError)`: An error if the scope is not valid +pub(super) async fn get_authz_info( + state: &ServiceState, + scope: Option<&ProviderScope>, +) -> Result { + let authz_info = match scope { + Some(ProviderScope::Project(scope)) => { + if let Some(project) = find_project_from_scope(state, &scope.into()).await? { + AuthzInfo::Project(project) + } else { + return Err(KeystoneApiError::Unauthorized); + } + } + Some(ProviderScope::Domain(scope)) => { + if let Ok(domain) = get_domain(state, scope.id.as_ref(), scope.name.as_ref()).await { + AuthzInfo::Domain(domain) + } else { + return Err(KeystoneApiError::Unauthorized); + } + } + Some(ProviderScope::System(_scope)) => todo!(), + None => AuthzInfo::Unscoped, + }; + authz_info.validate()?; + Ok(authz_info) +} + +/// Validate bound claims in the token +/// +/// # Arguments +/// +/// * `mapping` - The mapping to validate against +/// * `claims` - The claims to validate +/// * `claims_as_json` - The claims as json to validate +/// +/// # Returns +/// +/// * `Result<(), OidcError>` +pub(super) fn validate_bound_claims( + mapping: &ProviderMapping, + claims: &IdTokenClaims, + claims_as_json: &Value, +) -> Result<(), OidcError> { + if let Some(bound_subject) = &mapping.bound_subject { + if bound_subject != claims.subject().as_str() { + return Err(OidcError::BoundSubjectMismatch { + expected: bound_subject.to_string(), + found: claims.subject().as_str().into(), + }); + } + } + if let Some(bound_audiences) = &mapping.bound_audiences { + let mut bound_audiences_match: bool = false; + for claim_audience in claims.audiences() { + if bound_audiences.iter().any(|x| x == claim_audience.as_str()) { + bound_audiences_match = true; + } + } + if !bound_audiences_match { + return Err(OidcError::BoundAudiencesMismatch { + expected: bound_audiences.join(","), + found: claims + .audiences() + .iter() + .map(|x| x.as_str()) + .collect::>() + .join(","), + }); + } + } + if let Some(bound_claims) = &mapping.bound_claims { + if let Some(required_claims) = bound_claims.as_object() { + for (claim, value) in required_claims.iter() { + if !claims_as_json + .get(claim) + .map(|x| x == value) + .is_some_and(|val| val) + { + return Err(OidcError::BoundClaimsMismatch { + claim: claim.to_string(), + expected: value.to_string(), + found: claims_as_json + .get(claim) + .map(|x| x.to_string()) + .unwrap_or_default(), + }); + } + } + } + } + Ok(()) +} + +/// Map the user data using the referred mapping +/// +/// # Arguments +/// * `idp` - The identity provider +/// * `mapping` - The mapping to use +/// * `claims_as_json` - The claims as json +/// +/// # Returns +/// The mapped user data +pub(super) async fn map_user_data( + state: &ServiceState, + idp: &ProviderIdentityProvider, + mapping: &ProviderMapping, + claims_as_json: &Value, +) -> Result { + let mut builder = MappedUserDataBuilder::default(); + if let Some(token_user_id) = &mapping.token_user_id { + // TODO: How to check that the user belongs to the right domain) + if let Ok(Some(user)) = state + .provider + .get_identity_provider() + .get_user(&state.db, token_user_id) + .await + { + builder.unique_id(token_user_id.clone()); + builder.user_name(user.name.clone()); + } else { + return Err(OidcError::UserNotFound(token_user_id.clone()))?; + } + } else { + builder.unique_id( + claims_as_json + .get(&mapping.user_id_claim) + .and_then(|x| x.as_str()) + .ok_or_else(|| OidcError::UserIdClaimRequired(mapping.user_id_claim.clone()))? + .to_string(), + ); + + builder.user_name( + claims_as_json + .get(&mapping.user_name_claim) + .and_then(|x| x.as_str()) + .ok_or_else(|| OidcError::UserNameClaimRequired(mapping.user_name_claim.clone()))?, + ); + } + + builder.domain_id( + mapping + .domain_id + .as_ref() + .or(idp.domain_id.as_ref()) + .or(mapping + .domain_id_claim + .as_ref() + .and_then(|claim| { + claims_as_json + .get(claim) + .and_then(|x| x.as_str().map(|v| v.to_string())) + }) + .as_ref()) + .ok_or(OidcError::UserDomainUnbound)?, + ); + + if let Some(groups_claim) = &mapping.groups_claim { + if let Some(group_names_data) = &claims_as_json.get(groups_claim) { + builder.group_names( + group_names_data + .as_array() + .map(|names| { + names + .iter() + .map(|group| group.as_str().map(|v| v.to_string())) + .collect::>>() + }) + .ok_or(OidcError::GroupsClaimNotArrayOfStrings)?, + ); + } + } + + Ok(builder.build()?) +} diff --git a/src/api/v4/federation/error.rs b/src/api/v4/federation/error.rs index d374ad9b..0bf35f23 100644 --- a/src/api/v4/federation/error.rs +++ b/src/api/v4/federation/error.rs @@ -40,6 +40,9 @@ pub enum OidcError { #[error("mapping id or mapping name with idp id must be specified")] MappingIdOrNameWithIdp, + #[error("groups claim must be an array of strings")] + GroupsClaimNotArrayOfStrings, + #[error("request token error")] RequestToken { msg: String }, @@ -108,7 +111,7 @@ pub enum OidcError { }, /// Authentication expired. - #[error("Authentication expired")] + #[error("authentication expired")] AuthStateExpired, /// Cannot use OIDC attribute mapping for JWT login. @@ -161,6 +164,9 @@ impl From for KeystoneApiError { OidcError::MappingIdOrNameWithIdp => { KeystoneApiError::BadRequest("Federated authentication requires mapping being specified in the payload either with ID or name with identity provider id.".to_string()) } + OidcError::GroupsClaimNotArrayOfStrings => { + KeystoneApiError::BadRequest("Groups claim must be an array of strings representing group names.".to_string()) + } OidcError::RequestToken { msg } => { KeystoneApiError::BadRequest(format!("Error exchanging authorization code for the authorization token: {msg}")) } diff --git a/src/api/v4/federation/jwt.rs b/src/api/v4/federation/jwt.rs index 0c31cee5..bede3b05 100644 --- a/src/api/v4/federation/jwt.rs +++ b/src/api/v4/federation/jwt.rs @@ -22,9 +22,6 @@ use axum::{ http::header::AUTHORIZATION, response::IntoResponse, }; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::collections::HashMap; use std::str::FromStr; use tracing::warn; use utoipa_axum::{router::OpenApiRouter, routes}; @@ -34,26 +31,20 @@ use openidconnect::core::{ CoreJwsSigningAlgorithm, CoreProviderMetadata, }; use openidconnect::reqwest; -use openidconnect::{ - AdditionalClaims, Client, ClientId, IdToken, IdTokenClaims, IssuerUrl, JsonWebKeySet, - JsonWebKeySetUrl, Nonce, -}; +use openidconnect::{Client, ClientId, IdToken, IssuerUrl, JsonWebKeySet, JsonWebKeySetUrl, Nonce}; -use crate::api::common::find_project_from_scope; use crate::api::v4::auth::token::types::{ Token as ApiResponseToken, TokenResponse as KeystoneTokenResponse, }; use crate::api::v4::federation::error::OidcError; use crate::api::v4::federation::types::*; use crate::api::{Catalog, error::KeystoneApiError}; -use crate::auth::{AuthenticatedInfo, AuthenticationError, AuthzInfo}; +use crate::auth::{AuthenticatedInfo, AuthenticationError}; use crate::catalog::CatalogApi; use crate::federation::FederationApi; use crate::federation::types::{ MappingListParameters as ProviderMappingListParameters, MappingType as ProviderMappingType, Project as ProviderProject, Scope as ProviderScope, - identity_provider::IdentityProvider as ProviderIdentityProvider, - mapping::Mapping as ProviderMapping, }; use crate::identity::IdentityApi; use crate::identity::error::IdentityProviderError; @@ -61,14 +52,12 @@ use crate::identity::types::{FederationBuilder, FederationProtocol, UserCreateBu use crate::keystone::ServiceState; use crate::token::TokenApi; +use super::common::{get_authz_info, map_user_data, validate_bound_claims}; + pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new().routes(routes!(login)) } -#[derive(Debug, Deserialize, Serialize)] -struct AllOtherClaims(HashMap); -impl AdditionalClaims for AllOtherClaims {} - type FullIdToken = IdToken< AllOtherClaims, CoreGenderClaim, @@ -76,33 +65,6 @@ type FullIdToken = IdToken< CoreJwsSigningAlgorithm, >; -/// Prepare the proper scope. -/// -/// # Arguments -/// * `state`: The service state -/// * `scope`: The scope to extract the AuthZ information from -/// -/// # Returns -/// * `AuthzInfo`: The AuthZ information -/// * `KeystoneApiError`: An error if the scope is not valid -async fn get_authz_info( - state: &ServiceState, - scope: Option, -) -> Result { - let authz_info = match scope { - Some(ProviderScope::Project(scope)) => { - if let Some(project) = find_project_from_scope(state, &scope.into()).await? { - AuthzInfo::Project(project) - } else { - return Err(KeystoneApiError::Unauthorized); - } - } - _ => AuthzInfo::Unscoped, - }; - authz_info.validate()?; - Ok(authz_info) -} - /// Authentication using the JWT. /// /// This operation allows user to exchange the JWT issued by the trusted identity provider for the @@ -265,6 +227,7 @@ pub async fn login( .map_err(OidcError::from)?; let claims_as_json = serde_json::to_value(&claims)?; + tracing::debug!("Claims data {claims_as_json}"); validate_bound_claims(&mapping, &claims, &claims_as_json)?; let mapped_user_data = map_user_data(&state, &idp, &mapping, &claims_as_json).await?; @@ -319,12 +282,16 @@ pub async fn login( // TODO: detect scope from the mapping or claims let authz_info = get_authz_info( &state, - mapping.token_project_id.as_ref().map(|token_project_id| { - ProviderScope::Project(ProviderProject { - id: Some(token_project_id.to_string()), - ..Default::default() + mapping + .token_project_id + .as_ref() + .map(|token_project_id| { + ProviderScope::Project(ProviderProject { + id: Some(token_project_id.to_string()), + ..Default::default() + }) }) - }), + .as_ref(), ) .await?; @@ -363,135 +330,3 @@ pub async fn login( ) .into_response()) } - -/// Validate bound claims in the token -/// -/// # Arguments -/// -/// * `mapping` - The mapping to validate against -/// * `claims` - The claims to validate -/// * `claims_as_json` - The claims as json to validate -/// -/// # Returns -/// -/// * `Result<(), OidcError>` -fn validate_bound_claims( - mapping: &ProviderMapping, - claims: &IdTokenClaims, - claims_as_json: &Value, -) -> Result<(), OidcError> { - if let Some(bound_subject) = &mapping.bound_subject { - if bound_subject != claims.subject().as_str() { - return Err(OidcError::BoundSubjectMismatch { - expected: bound_subject.to_string(), - found: claims.subject().as_str().into(), - }); - } - } - if let Some(bound_audiences) = &mapping.bound_audiences { - let mut bound_audiences_match: bool = false; - for claim_audience in claims.audiences() { - if bound_audiences.iter().any(|x| x == claim_audience.as_str()) { - bound_audiences_match = true; - } - } - if !bound_audiences_match { - return Err(OidcError::BoundAudiencesMismatch { - expected: bound_audiences.join(","), - found: claims - .audiences() - .iter() - .map(|x| x.as_str()) - .collect::>() - .join(","), - }); - } - } - if let Some(bound_claims) = &mapping.bound_claims { - if let Some(required_claims) = bound_claims.as_object() { - for (claim, value) in required_claims.iter() { - if !claims_as_json - .get(claim) - .map(|x| x == value) - .is_some_and(|val| val) - { - return Err(OidcError::BoundClaimsMismatch { - claim: claim.to_string(), - expected: value.to_string(), - found: claims_as_json - .get(claim) - .map(|x| x.to_string()) - .unwrap_or_default(), - }); - } - } - } - } - Ok(()) -} - -/// Map the user data using the referred mapping -/// -/// # Arguments -/// * `idp` - The identity provider -/// * `mapping` - The mapping to use -/// * `claims_as_json` - The claims as json -/// -/// # Returns -/// The mapped user data -async fn map_user_data( - state: &ServiceState, - idp: &ProviderIdentityProvider, - mapping: &ProviderMapping, - claims_as_json: &Value, -) -> Result { - let mut builder = MappedUserDataBuilder::default(); - if let Some(token_user_id) = &mapping.token_user_id { - // TODO: How to check that the user belongs to the right domain) - if let Ok(Some(user)) = state - .provider - .get_identity_provider() - .get_user(&state.db, token_user_id) - .await - { - builder.unique_id(token_user_id.clone()); - builder.user_name(user.name.clone()); - } else { - return Err(OidcError::UserNotFound(token_user_id.clone()))?; - } - } else { - builder.unique_id( - claims_as_json - .get(&mapping.user_id_claim) - .and_then(|x| x.as_str()) - .ok_or_else(|| OidcError::UserIdClaimRequired(mapping.user_id_claim.clone()))? - .to_string(), - ); - - builder.user_name( - claims_as_json - .get(&mapping.user_name_claim) - .and_then(|x| x.as_str()) - .ok_or_else(|| OidcError::UserNameClaimRequired(mapping.user_name_claim.clone()))?, - ); - } - - builder.domain_id( - mapping - .domain_id - .as_ref() - .or(idp.domain_id.as_ref()) - .or(mapping - .domain_id_claim - .as_ref() - .and_then(|claim| { - claims_as_json - .get(claim) - .and_then(|x| x.as_str().map(|v| v.to_string())) - }) - .as_ref()) - .ok_or(OidcError::UserDomainUnbound)?, - ); - - Ok(builder.build()?) -} diff --git a/src/api/v4/federation/mod.rs b/src/api/v4/federation/mod.rs index 0cf179c1..cf830d50 100644 --- a/src/api/v4/federation/mod.rs +++ b/src/api/v4/federation/mod.rs @@ -22,6 +22,7 @@ use utoipa_axum::router::OpenApiRouter; use crate::keystone::ServiceState; pub mod auth; +mod common; pub mod error; pub mod identity_provider; pub mod jwt; diff --git a/src/api/v4/federation/oidc.rs b/src/api/v4/federation/oidc.rs index 3ceb33d1..aabe8fd2 100644 --- a/src/api/v4/federation/oidc.rs +++ b/src/api/v4/federation/oidc.rs @@ -16,77 +16,40 @@ use axum::{Json, debug_handler, extract::State, http::StatusCode, response::IntoResponse}; use chrono::Utc; use eyre::WrapErr; -use serde_json::Value; -use tracing::debug; +use std::collections::{HashMap, HashSet}; +use tracing::{debug, trace}; use url::Url; use utoipa_axum::{router::OpenApiRouter, routes}; -use openidconnect::core::{CoreGenderClaim, CoreProviderMetadata}; +use openidconnect::core::CoreProviderMetadata; use openidconnect::reqwest; use openidconnect::{ - AuthorizationCode, ClientId, ClientSecret, IdTokenClaims, IssuerUrl, Nonce, PkceCodeVerifier, - RedirectUrl, TokenResponse, + AuthorizationCode, ClientId, ClientSecret, IssuerUrl, Nonce, PkceCodeVerifier, RedirectUrl, + TokenResponse, }; -use crate::api::common::{find_project_from_scope, get_domain}; use crate::api::v4::auth::token::types::{ Token as ApiResponseToken, TokenResponse as KeystoneTokenResponse, }; use crate::api::v4::federation::error::OidcError; use crate::api::v4::federation::types::*; use crate::api::{Catalog, error::KeystoneApiError}; -use crate::auth::{AuthenticatedInfo, AuthenticationError, AuthzInfo}; +use crate::auth::{AuthenticatedInfo, AuthenticationError}; use crate::catalog::CatalogApi; use crate::federation::FederationApi; -use crate::federation::types::{ - Scope as ProviderScope, identity_provider::IdentityProvider as ProviderIdentityProvider, - mapping::Mapping as ProviderMapping, -}; use crate::identity::IdentityApi; use crate::identity::error::IdentityProviderError; use crate::identity::types::{FederationBuilder, FederationProtocol, UserCreateBuilder}; +use crate::identity::types::{Group, GroupCreate, GroupListParameters}; use crate::keystone::ServiceState; use crate::token::TokenApi; +use super::common::{get_authz_info, map_user_data, validate_bound_claims}; + pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new().routes(routes!(callback)) } -/// Extract AuthZ information from the saved scope -/// -/// # Arguments -/// * `state`: The service state -/// * `scope`: The scope to extract the AuthZ information from -/// -/// # Returns -/// * `AuthzInfo`: The AuthZ information -/// * `KeystoneApiError`: An error if the scope is not valid -async fn get_authz_info( - state: &ServiceState, - scope: Option<&ProviderScope>, -) -> Result { - let authz_info = match scope { - Some(ProviderScope::Project(scope)) => { - if let Some(project) = find_project_from_scope(state, &scope.into()).await? { - AuthzInfo::Project(project) - } else { - return Err(KeystoneApiError::Unauthorized); - } - } - Some(ProviderScope::Domain(scope)) => { - if let Ok(domain) = get_domain(state, scope.id.as_ref(), scope.name.as_ref()).await { - AuthzInfo::Domain(domain) - } else { - return Err(KeystoneApiError::Unauthorized); - } - } - Some(ProviderScope::System(_scope)) => todo!(), - None => AuthzInfo::Unscoped, - }; - authz_info.validate()?; - Ok(authz_info) -} - /// Authentication callback. /// /// This operation allows user to exchange the authorization code retrieved from the identity @@ -209,9 +172,11 @@ pub async fn callback( } let claims_as_json = serde_json::to_value(claims)?; + debug!("Claims data {claims_as_json}"); validate_bound_claims(&mapping, claims, &claims_as_json)?; - let mapped_user_data = map_user_data(&idp, &mapping, &claims_as_json)?; + let mapped_user_data = map_user_data(&state, &idp, &mapping, &claims_as_json).await?; + debug!("Mapped user is {mapped_user_data:?}"); let user = if let Some(existing_user) = state .provider @@ -250,19 +215,78 @@ pub async fn callback( ) .await? }; + + if let Some(necessary_group_names) = mapped_user_data.group_names { + let current_domain_groups: HashMap = HashMap::from_iter( + state + .provider + .get_identity_provider() + .list_groups( + &state.db, + &GroupListParameters { + domain_id: Some(user.domain_id.clone()), + ..Default::default() + }, + ) + .await? + .into_iter() + .map(|group| (group.name, group.id)), + ); + let mut group_ids: HashSet = HashSet::new(); + for group_name in necessary_group_names { + group_ids.insert( + if let Some(grp_id) = current_domain_groups.get(&group_name) { + grp_id.clone() + } else { + state + .provider + .get_identity_provider() + .create_group( + &state.db, + GroupCreate { + domain_id: user.domain_id.clone(), + name: group_name.clone(), + ..Default::default() + }, + ) + .await? + .id + }, + ); + } + if !group_ids.is_empty() { + state + .provider + .get_identity_provider() + .set_user_groups( + &state.db, + &user.id, + HashSet::from_iter(group_ids.iter().map(|i| i.as_str())), + ) + .await?; + } + } + let user_groups: Vec = Vec::from_iter( + state + .provider + .get_identity_provider() + .list_groups_of_user(&state.db, &user.id) + .await?, + ); + let authed_info = AuthenticatedInfo::builder() .user_id(user.id.clone()) .user(user.clone()) - .methods(vec!["oidc".into()]) + .methods(vec!["openid".into()]) .idp_id(idp.id.clone()) .protocol_id("oidc".to_string()) + .user_groups(user_groups) .build() .map_err(AuthenticationError::from)?; authed_info.validate()?; - // TODO: Persist group memberships - let authz_info = get_authz_info(&state, auth_state.scope.as_ref()).await?; + trace!("Granting the scope: {:?}", authz_info); let mut token = state .provider @@ -287,7 +311,7 @@ pub async fn callback( .into(); api_token.token.catalog = Some(catalog); - debug!("response is {:?}", api_token); + trace!("Token response is {:?}", api_token); Ok(( StatusCode::OK, [( @@ -298,118 +322,3 @@ pub async fn callback( ) .into_response()) } - -/// Validate bound claims in the token -/// -/// # Arguments -/// -/// * `mapping` - The mapping to validate against -/// * `claims` - The claims to validate -/// * `claims_as_json` - The claims as json to validate -/// -/// # Returns -/// -/// * `Result<(), OidcError>` -fn validate_bound_claims( - mapping: &ProviderMapping, - claims: &IdTokenClaims, - claims_as_json: &Value, -) -> Result<(), OidcError> { - if let Some(bound_subject) = &mapping.bound_subject { - if bound_subject != claims.subject().as_str() { - return Err(OidcError::BoundSubjectMismatch { - expected: bound_subject.to_string(), - found: claims.subject().as_str().into(), - }); - } - } - if let Some(bound_audiences) = &mapping.bound_audiences { - let mut bound_audiences_match: bool = false; - for claim_audience in claims.audiences() { - if bound_audiences.iter().any(|x| x == claim_audience.as_str()) { - bound_audiences_match = true; - } - } - if !bound_audiences_match { - return Err(OidcError::BoundAudiencesMismatch { - expected: bound_audiences.join(","), - found: claims - .audiences() - .iter() - .map(|x| x.as_str()) - .collect::>() - .join(","), - }); - } - } - if let Some(bound_claims) = &mapping.bound_claims { - if let Some(required_claims) = bound_claims.as_object() { - for (claim, value) in required_claims.iter() { - if !claims_as_json - .get(claim) - .map(|x| x == value) - .is_some_and(|val| val) - { - return Err(OidcError::BoundClaimsMismatch { - claim: claim.to_string(), - expected: value.to_string(), - found: claims_as_json - .get(claim) - .map(|x| x.to_string()) - .unwrap_or_default(), - }); - } - } - } - } - Ok(()) -} - -/// Map the user data using the referred mapping -/// -/// # Arguments -/// * `idp` - The identity provider -/// * `mapping` - The mapping to use -/// * `claims_as_json` - The claims as json -/// -/// # Returns -/// The mapped user data -fn map_user_data( - idp: &ProviderIdentityProvider, - mapping: &ProviderMapping, - claims_as_json: &Value, -) -> Result { - let mut builder = MappedUserDataBuilder::default(); - builder.unique_id( - claims_as_json - .get(&mapping.user_id_claim) - .and_then(|x| x.as_str()) - .ok_or_else(|| OidcError::UserIdClaimRequired(mapping.user_id_claim.clone()))?, - ); - - builder.user_name( - claims_as_json - .get(&mapping.user_name_claim) - .and_then(|x| x.as_str()) - .ok_or_else(|| OidcError::UserNameClaimRequired(mapping.user_name_claim.clone()))?, - ); - - builder.domain_id( - mapping - .domain_id - .as_ref() - .or(idp.domain_id.as_ref()) - .or(mapping - .domain_id_claim - .as_ref() - .and_then(|claim| { - claims_as_json - .get(claim) - .and_then(|x| x.as_str().map(|v| v.to_string())) - }) - .as_ref()) - .ok_or(OidcError::UserDomainUnbound)?, - ); - - Ok(builder.build()?) -} diff --git a/src/api/v4/federation/types.rs b/src/api/v4/federation/types.rs index 1e1f1b96..ff7b6297 100644 --- a/src/api/v4/federation/types.rs +++ b/src/api/v4/federation/types.rs @@ -78,10 +78,12 @@ impl AdditionalClaims for AllOtherClaims {} pub(super) struct ExtraFields(HashMap); impl ExtraTokenFields for ExtraFields {} -#[derive(Builder, Clone)] +#[derive(Builder, Debug, Clone)] #[builder(setter(into))] pub(super) struct MappedUserData { pub(super) unique_id: String, pub(super) user_name: String, pub(super) domain_id: String, + #[builder(default)] + pub(super) group_names: Option>, } diff --git a/src/api/v4/user/mod.rs b/src/api/v4/user/mod.rs index bc5540ab..6c6992a8 100644 --- a/src/api/v4/user/mod.rs +++ b/src/api/v4/user/mod.rs @@ -171,7 +171,7 @@ async fn groups( let groups: Vec = state .provider .get_identity_provider() - .list_groups_for_user(&state.db, &user_id) + .list_groups_of_user(&state.db, &user_id) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -479,7 +479,7 @@ mod tests { async fn test_groups() { let mut identity_mock = MockIdentityProvider::default(); identity_mock - .expect_list_groups_for_user() + .expect_list_groups_of_user() .withf(|_: &DatabaseConnection, uid: &str| uid == "foo") .returning(|_, _| { Ok(vec![Group { diff --git a/src/assignment/mod.rs b/src/assignment/mod.rs index 9336971a..c57bb370 100644 --- a/src/assignment/mod.rs +++ b/src/assignment/mod.rs @@ -1,7 +1,6 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software @@ -163,7 +162,7 @@ impl AssignmentApi for AssignmentProvider { if let Some(uid) = ¶ms.user_id { let users = provider .get_identity_provider() - .list_groups_for_user(db, uid) + .list_groups_of_user(db, uid) .await?; actors.extend(users.into_iter().map(|x| x.id)); }; diff --git a/src/auth/mod.rs b/src/auth/mod.rs index b823fec6..e36c89fd 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::{error, warn}; -use crate::identity::types as identity_provider_types; +use crate::identity::types::{Group, UserResponse}; use crate::resource::types::{Domain, Project}; #[derive(Error, Debug)] @@ -54,12 +54,16 @@ pub struct AuthenticatedInfo { /// Resolved user object #[builder(default)] - pub user: Option, + pub user: Option, /// Resolved user domain information #[builder(default)] pub user_domain: Option, + /// Resolved user object + #[builder(default)] + pub user_groups: Vec, + /// Authentication methods #[builder(default)] pub methods: Vec, diff --git a/src/federation/backends/sql/identity_provider/create.rs b/src/federation/backends/sql/identity_provider/create.rs index 990a1040..093dd60c 100644 --- a/src/federation/backends/sql/identity_provider/create.rs +++ b/src/federation/backends/sql/identity_provider/create.rs @@ -14,12 +14,13 @@ use sea_orm::DatabaseConnection; use sea_orm::entity::*; +use sea_orm::sea_query::OnConflict; use crate::config::Config; use crate::db::entity::{ federated_identity_provider as db_federated_identity_provider, federation_protocol as db_old_federation_protocol, - identity_provider as db_old_identity_provider, + identity_provider as db_old_identity_provider, mapping as db_old_mapping, }; use crate::federation::backends::error::FederationDatabaseError; use crate::federation::types::*; @@ -85,7 +86,7 @@ pub async fn create( // constraints working db_old_identity_provider::ActiveModel { id: Set(idp.id.clone()), - enabled: Set(false), + enabled: Set(true), description: Set(Some(idp.name.clone())), domain_id: Set(idp.domain_id.clone().unwrap_or("<>".into())), authorization_ttl: NotSet, @@ -96,7 +97,7 @@ pub async fn create( db_old_federation_protocol::ActiveModel { id: Set("oidc".into()), idp_id: Set(idp.id.clone()), - mapping_id: Set("<>".into()), + mapping_id: Set("dummy".into()), remote_id_attribute: NotSet, } .insert(db) @@ -105,18 +106,34 @@ pub async fn create( db_old_federation_protocol::ActiveModel { id: Set("jwt".into()), idp_id: Set(idp.id.clone()), - mapping_id: Set("<>".into()), + mapping_id: Set("dummy".into()), remote_id_attribute: NotSet, } .insert(db) .await?; + db_old_mapping::Entity::insert(db_old_mapping::ActiveModel { + id: Set("dummy".into()), + rules: Set(Some("\"[]\"".into())), + schema_version: Set("1.0".into()), + }) + .on_conflict( + OnConflict::column(db_old_mapping::Column::Id) + // Special handling for + // [mysql](https://docs.rs/sea-query/0.32.7/sea_query/query/struct.OnConflict.html#method.do_nothing_on) + .do_nothing_on([db_old_mapping::Column::Id]) + .to_owned(), + ) + .on_empty_do_nothing() + .exec(db) + .await?; + db_entry.try_into() } #[cfg(test)] mod tests { - use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; use serde_json::json; use crate::config::Config; @@ -132,6 +149,10 @@ mod tests { .append_query_results([vec![get_old_idp_mock("1")]]) .append_query_results([vec![get_old_proto_mock("1")]]) .append_query_results([vec![get_old_proto_mock("2")]]) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) .into_connection(); let config = Config::default(); @@ -181,17 +202,22 @@ mod tests { Transaction::from_sql_and_values( DatabaseBackend::Postgres, r#"INSERT INTO "identity_provider" ("id", "enabled", "description", "domain_id") VALUES ($1, $2, $3, $4) RETURNING "id", "enabled", "description", "domain_id", "authorization_ttl""#, - ["1".into(), false.into(), "idp".into(), "foo_domain".into(),] + ["1".into(), true.into(), "idp".into(), "foo_domain".into(),] ), Transaction::from_sql_and_values( DatabaseBackend::Postgres, r#"INSERT INTO "federation_protocol" ("id", "idp_id", "mapping_id") VALUES ($1, $2, $3) RETURNING "id", "idp_id", "mapping_id", "remote_id_attribute""#, - ["oidc".into(), "1".into(), "<>".into()] + ["oidc".into(), "1".into(), "dummy".into()] ), Transaction::from_sql_and_values( DatabaseBackend::Postgres, r#"INSERT INTO "federation_protocol" ("id", "idp_id", "mapping_id") VALUES ($1, $2, $3) RETURNING "id", "idp_id", "mapping_id", "remote_id_attribute""#, - ["jwt".into(), "1".into(), "<>".into()] + ["jwt".into(), "1".into(), "dummy".into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "mapping" ("id", "rules", "schema_version") VALUES ($1, $2, $3) ON CONFLICT ("id") DO NOTHING RETURNING "id""#, + ["dummy".into(), "\"[]\"".into(), "1.0".into()] ), ] ); diff --git a/src/federation/backends/sql/identity_provider/list.rs b/src/federation/backends/sql/identity_provider/list.rs index d718e1c7..7e5a11dd 100644 --- a/src/federation/backends/sql/identity_provider/list.rs +++ b/src/federation/backends/sql/identity_provider/list.rs @@ -36,8 +36,7 @@ pub async fn list( } if let Some(val) = ¶ms.domain_ids { - let filter = - db_federated_identity_provider::Column::DomainId.is_in(val.iter().flatten()); + let filter = db_federated_identity_provider::Column::DomainId.is_in(val.iter().flatten()); select = if val.contains(&None) { select.filter( Condition::any() diff --git a/src/identity/backends/error.rs b/src/identity/backends/error.rs index 5856386f..5a1c7cb5 100644 --- a/src/identity/backends/error.rs +++ b/src/identity/backends/error.rs @@ -48,15 +48,18 @@ pub enum IdentityDatabaseError { }, /// Conflict - #[error("{0}")] - Conflict(String), + #[error("{message}")] + Conflict { message: String, context: String }, /// SqlError - #[error("{0}")] - Sql(String), + #[error("{message}")] + Sql { message: String, context: String }, - #[error(transparent)] - Database { source: sea_orm::DbErr }, + #[error("Database error while {context}")] + Database { + source: sea_orm::DbErr, + context: String, + }, #[error("password hashing error")] PasswordHash { @@ -68,15 +71,32 @@ pub enum IdentityDatabaseError { UserIdOrNameWithDomain, } +/// Convert the DB error into the IdentityDatabaseError with the context information. +pub fn db_err(e: sea_orm::DbErr, context: &str) -> IdentityDatabaseError { + e.sql_err().map_or_else( + || IdentityDatabaseError::Database { + source: e, + context: context.to_string(), + }, + |err| match err { + SqlErr::UniqueConstraintViolation(descr) => IdentityDatabaseError::Conflict { + message: descr.to_string(), + context: context.to_string(), + }, + SqlErr::ForeignKeyConstraintViolation(descr) => IdentityDatabaseError::Conflict { + message: descr.to_string(), + context: context.to_string(), + }, + other => IdentityDatabaseError::Sql { + message: other.to_string(), + context: context.to_string(), + }, + }, + ) +} + impl From for IdentityDatabaseError { fn from(err: sea_orm::DbErr) -> Self { - err.sql_err().map_or_else( - || Self::Database { source: err }, - |err| match err { - SqlErr::UniqueConstraintViolation(descr) => Self::Conflict(descr), - SqlErr::ForeignKeyConstraintViolation(descr) => Self::Conflict(descr), - other => Self::Sql(other.to_string()), - }, - ) + db_err(err, "unknown") } } diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index 47bfb66b..9c5ec7e4 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -16,6 +16,7 @@ use async_trait::async_trait; use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; +use std::collections::HashSet; use webauthn_rs::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; mod common; @@ -26,6 +27,7 @@ mod passkey; mod passkey_state; mod password; mod user; +mod user_group; mod user_option; use super::super::types::*; @@ -38,7 +40,7 @@ use crate::db::entity::{ user as db_user, user_option as db_user_option, }; use crate::identity::IdentityProviderError; -use crate::identity::backends::error::IdentityDatabaseError; +use crate::identity::backends::error::{IdentityDatabaseError, db_err}; use crate::identity::password_hashing; #[derive(Clone, Debug, Default)] @@ -197,14 +199,68 @@ impl IdentityBackend for SqlBackend { Ok(group::delete(&self.config, db, group_id).await?) } - /// List groups a user is member of + /// List groups a user is member of. #[tracing::instrument(level = "debug", skip(self, db))] - async fn list_groups_for_user<'a>( + async fn list_groups_of_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(group::list_for_user(&self.config, db, user_id).await?) + Ok(user_group::list_user_groups(db, user_id).await?) + } + + /// Add the user into the group. + #[tracing::instrument(level = "debug", skip(self, db))] + async fn add_user_to_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError> { + Ok(user_group::add_user_to_group(db, user_id, group_id).await?) + } + + /// Add user group membership relations. + #[tracing::instrument(level = "debug", skip(self, db))] + async fn add_users_to_groups<'a>( + &self, + db: &DatabaseConnection, + memberships: Vec<(&'a str, &'a str)>, + ) -> Result<(), IdentityProviderError> { + Ok(user_group::add_users_to_groups(db, memberships).await?) + } + + /// Remove the user from the group. + #[tracing::instrument(level = "debug", skip(self, db))] + async fn remove_user_from_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError> { + Ok(user_group::remove_user_from_group(db, user_id, group_id).await?) + } + + /// Remove the user from multiple groups. + #[tracing::instrument(level = "debug", skip(self, db))] + async fn remove_user_from_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError> { + Ok(user_group::remove_user_from_groups(db, user_id, group_ids).await?) + } + + /// Set group memberships of the user. + #[tracing::instrument(level = "debug", skip(self, db))] + async fn set_user_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError> { + Ok(user_group::set_user_groups(db, user_id, group_ids).await?) } /// Create passkey diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs index b5c2eb96..9d96abc0 100644 --- a/src/identity/backends/sql/group.rs +++ b/src/identity/backends/sql/group.rs @@ -12,85 +12,19 @@ // // SPDX-License-Identifier: Apache-2.0 -use sea_orm::DatabaseConnection; -use sea_orm::entity::*; -use sea_orm::query::*; -use serde_json::Value; -use serde_json::json; +use crate::db::entity::group; +use crate::identity::types::Group; +use serde_json::{Value, json}; -use crate::db::entity::{ - group, - prelude::{Group as DbGroup, UserGroupMembership as DbUserGroupMembership}, - user_group_membership, -}; -use crate::identity::Config; -use crate::identity::backends::sql::IdentityDatabaseError; -use crate::identity::types::{Group, GroupCreate, GroupListParameters}; +mod create; +mod delete; +mod get; +mod list; -pub async fn list( - _conf: &Config, - db: &DatabaseConnection, - params: &GroupListParameters, -) -> Result, IdentityDatabaseError> { - // Prepare basic selects - let mut group_select = DbGroup::find(); - - if let Some(domain_id) = ¶ms.domain_id { - group_select = group_select.filter(group::Column::DomainId.eq(domain_id)); - } - if let Some(name) = ¶ms.name { - group_select = group_select.filter(group::Column::Name.eq(name)); - } - - let db_groups: Vec = group_select.all(db).await?; - let results: Vec = db_groups.into_iter().map(Into::into).collect(); - - Ok(results) -} - -pub async fn get>( - _conf: &Config, - db: &DatabaseConnection, - group_id: S, -) -> Result, IdentityDatabaseError> { - Ok(DbGroup::find_by_id(group_id.as_ref()) - .one(db) - .await? - .map(Into::into)) -} - -pub async fn create( - _conf: &Config, - db: &DatabaseConnection, - group: GroupCreate, -) -> Result { - let entry = group::ActiveModel { - id: Set(group.id.clone().unwrap_or_default()), - domain_id: Set(group.domain_id.clone()), - name: Set(group.name.clone()), - description: Set(group.description.clone()), - extra: Set(Some(serde_json::to_string(&group.extra)?)), - }; - - let db_entry: group::Model = entry.insert(db).await?; - - Ok(db_entry.into()) -} - -pub async fn delete>( - _conf: &Config, - db: &DatabaseConnection, - group_id: S, -) -> Result<(), IdentityDatabaseError> { - let res = DbGroup::delete_by_id(group_id.as_ref()).exec(db).await?; - if res.rows_affected == 1 { - Ok(()) - } else { - Err(IdentityDatabaseError::GroupNotFound( - group_id.as_ref().to_string(), - )) - } -} +pub use create::create; +pub use delete::delete; +pub use get::get; +pub use list::list; impl From for Group { fn from(value: group::Model) -> Self { @@ -106,40 +40,12 @@ impl From for Group { } } -pub async fn list_for_user>( - _conf: &Config, - db: &DatabaseConnection, - user_id: S, -) -> Result, IdentityDatabaseError> { - let groups: Vec<(user_group_membership::Model, Vec)> = - DbUserGroupMembership::find() - .filter(user_group_membership::Column::UserId.eq(user_id.as_ref())) - .find_with_related(DbGroup) - .all(db) - .await?; - - let results: Vec = groups - .into_iter() - .flat_map(|(_, x)| x.into_iter()) - .map(Into::into) - .collect(); - Ok(results) -} - #[cfg(test)] mod tests { #![allow(clippy::derivable_impls)] - - use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; - use serde_json::json; - - use crate::db::entity::group; - use crate::identity::Config; - use crate::identity::types::group::GroupListParametersBuilder; - use super::*; - fn get_group_mock>(id: S) -> group::Model { + pub(super) fn get_group_mock>(id: S) -> group::Model { group::Model { id: id.as_ref().to_string(), domain_id: "foo_domain".into(), @@ -148,187 +54,4 @@ mod tests { extra: Some("{\"foo\": \"bar\"}".into()), } } - - #[tokio::test] - async fn test_list() { - // Create MockDatabase with mock query results - let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([ - // First query result - select user itself - vec![get_group_mock("1")], - ]) - .into_connection(); - let config = Config::default(); - assert_eq!( - list(&config, &db, &GroupListParameters::default()) - .await - .unwrap(), - vec![Group { - id: "1".into(), - domain_id: "foo_domain".into(), - name: "group".into(), - description: Some("fake".into()), - extra: Some(json!({"foo": "bar"})) - }] - ); - - // Checking transaction log - assert_eq!( - db.into_transaction_log(), - [Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group""#, - //["1".into(), 1u64.into()] - [] - ),] - ); - } - - #[tokio::test] - async fn test_list_with_filters() { - let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([Vec::::new()]) - .into_connection(); - let config = Config::default(); - assert_eq!( - list( - &config, - &db, - &GroupListParametersBuilder::default() - .domain_id("d") - .name("n") - .build() - .unwrap() - ) - .await - .unwrap(), - vec![] - ); - - // Checking transaction log - assert_eq!( - db.into_transaction_log(), - [Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group" WHERE "group"."domain_id" = $1 AND "group"."name" = $2"#, - ["d".into(), "n".into()] - ),] - ); - } - - #[tokio::test] - async fn test_get() { - // Create MockDatabase with mock query results - let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([vec![get_group_mock("1")], vec![]]) - .into_connection(); - let config = Config::default(); - - assert_eq!( - get(&config, &db, "id").await.unwrap(), - Some(Group { - id: "1".into(), - domain_id: "foo_domain".into(), - name: "group".into(), - description: Some("fake".into()), - extra: Some(json!({"foo": "bar"})) - }) - ); - assert!(get(&config, &db, "missing").await.unwrap().is_none()); - - // Checking transaction log - assert_eq!( - db.into_transaction_log(), - [ - Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group" WHERE "group"."id" = $1 LIMIT $2"#, - ["id".into(), 1u64.into()] - ), - Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group" WHERE "group"."id" = $1 LIMIT $2"#, - ["missing".into(), 1u64.into()] - ), - ] - ); - } - - #[tokio::test] - async fn test_create() { - // Create MockDatabase with mock query results - let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([vec![get_group_mock("1")], vec![]]) - .into_connection(); - let config = Config::default(); - - let req = GroupCreate { - id: Some("1".into()), - domain_id: "foo_domain".into(), - name: "group".into(), - description: Some("fake".into()), - extra: Some(json!({"foo": "bar"})), - }; - assert_eq!( - create(&config, &db, req).await.unwrap(), - get_group_mock("1").into() - ); - // Checking transaction log - assert_eq!( - db.into_transaction_log(), - [Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"INSERT INTO "group" ("id", "domain_id", "name", "description", "extra") VALUES ($1, $2, $3, $4, $5) RETURNING "id", "domain_id", "name", "description", "extra""#, - [ - "1".into(), - "foo_domain".into(), - "group".into(), - "fake".into(), - "{\"foo\":\"bar\"}".into() - ] - ),] - ); - } - - #[tokio::test] - async fn test_delete() { - // Create MockDatabase with mock query results - let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_exec_results([MockExecResult { - rows_affected: 1, - ..Default::default() - }]) - .into_connection(); - let config = Config::default(); - - delete(&config, &db, "id").await.unwrap(); - // Checking transaction log - assert_eq!( - db.into_transaction_log(), - [Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"DELETE FROM "group" WHERE "group"."id" = $1"#, - ["id".into()] - ),] - ); - } - - #[tokio::test] - async fn test_list_for_user() { - let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([vec![], vec![get_group_mock("1"), get_group_mock("2")]]) - .into_connection(); - let config = Config::default(); - assert_eq!(list_for_user(&config, &db, "foo").await.unwrap(), vec![]); - - // Checking transaction log - assert_eq!( - db.into_transaction_log(), - [Transaction::from_sql_and_values( - DatabaseBackend::Postgres, - r#"SELECT "user_group_membership"."user_id" AS "A_user_id", "user_group_membership"."group_id" AS "A_group_id", "group"."id" AS "B_id", "group"."domain_id" AS "B_domain_id", "group"."name" AS "B_name", "group"."description" AS "B_description", "group"."extra" AS "B_extra" FROM "user_group_membership" LEFT JOIN "group" ON "user_group_membership"."group_id" = "group"."id" WHERE "user_group_membership"."user_id" = $1 ORDER BY "user_group_membership"."user_id" ASC, "user_group_membership"."group_id" ASC"#, - ["foo".into()] - ),] - ); - } } diff --git a/src/identity/backends/sql/group/create.rs b/src/identity/backends/sql/group/create.rs new file mode 100644 index 00000000..58a8c0ec --- /dev/null +++ b/src/identity/backends/sql/group/create.rs @@ -0,0 +1,127 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use serde_json::json; + +use crate::db::entity::group; +use crate::identity::Config; +use crate::identity::backends::sql::IdentityDatabaseError; +use crate::identity::types::{Group, GroupCreate}; + +pub async fn create( + _conf: &Config, + db: &DatabaseConnection, + group: GroupCreate, +) -> Result { + let entry = group::ActiveModel { + id: Set(group.id.clone().unwrap_or_default()), + domain_id: Set(group.domain_id.clone()), + name: Set(group.name.clone()), + description: Set(group.description.clone()), + extra: Set(Some(serde_json::to_string( + &group.extra.as_ref().or(Some(&json!({}))), + )?)), + }; + + let db_entry: group::Model = entry.insert(db).await?; + + Ok(db_entry.into()) +} + +#[cfg(test)] +mod tests { + #![allow(clippy::derivable_impls)] + + use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; + use serde_json::json; + + use crate::identity::Config; + + use super::super::tests::get_group_mock; + use super::*; + + #[tokio::test] + async fn test_create() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_group_mock("1")], vec![]]) + .into_connection(); + let config = Config::default(); + + let req = GroupCreate { + id: Some("1".into()), + domain_id: "foo_domain".into(), + name: "group".into(), + description: Some("fake".into()), + extra: Some(json!({"foo": "bar"})), + }; + assert_eq!( + create(&config, &db, req).await.unwrap(), + get_group_mock("1").into() + ); + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "group" ("id", "domain_id", "name", "description", "extra") VALUES ($1, $2, $3, $4, $5) RETURNING "id", "domain_id", "name", "description", "extra""#, + [ + "1".into(), + "foo_domain".into(), + "group".into(), + "fake".into(), + "{\"foo\":\"bar\"}".into() + ] + ),] + ); + } + + #[tokio::test] + async fn test_create_empty_extra() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_group_mock("1")], vec![]]) + .into_connection(); + let config = Config::default(); + + let req = GroupCreate { + id: Some("1".into()), + domain_id: "foo_domain".into(), + name: "group".into(), + description: Some("fake".into()), + extra: None, + }; + assert_eq!( + create(&config, &db, req).await.unwrap(), + get_group_mock("1").into() + ); + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "group" ("id", "domain_id", "name", "description", "extra") VALUES ($1, $2, $3, $4, $5) RETURNING "id", "domain_id", "name", "description", "extra""#, + [ + "1".into(), + "foo_domain".into(), + "group".into(), + "fake".into(), + "{}".into() + ] + ),] + ); + } +} diff --git a/src/identity/backends/sql/group/delete.rs b/src/identity/backends/sql/group/delete.rs new file mode 100644 index 00000000..835b80fc --- /dev/null +++ b/src/identity/backends/sql/group/delete.rs @@ -0,0 +1,69 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; + +use crate::db::entity::prelude::Group as DbGroup; +use crate::identity::Config; +use crate::identity::backends::sql::IdentityDatabaseError; + +pub async fn delete>( + _conf: &Config, + db: &DatabaseConnection, + group_id: S, +) -> Result<(), IdentityDatabaseError> { + let res = DbGroup::delete_by_id(group_id.as_ref()).exec(db).await?; + if res.rows_affected == 1 { + Ok(()) + } else { + Err(IdentityDatabaseError::GroupNotFound( + group_id.as_ref().to_string(), + )) + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::derivable_impls)] + + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + + use crate::identity::Config; + + use super::*; + + #[tokio::test] + async fn test_delete() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + let config = Config::default(); + + delete(&config, &db, "id").await.unwrap(); + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"DELETE FROM "group" WHERE "group"."id" = $1"#, + ["id".into()] + ),] + ); + } +} diff --git a/src/identity/backends/sql/group/get.rs b/src/identity/backends/sql/group/get.rs new file mode 100644 index 00000000..5ca2edad --- /dev/null +++ b/src/identity/backends/sql/group/get.rs @@ -0,0 +1,83 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; + +use crate::db::entity::prelude::Group as DbGroup; +use crate::identity::Config; +use crate::identity::backends::sql::IdentityDatabaseError; +use crate::identity::types::Group; + +pub async fn get>( + _conf: &Config, + db: &DatabaseConnection, + group_id: S, +) -> Result, IdentityDatabaseError> { + Ok(DbGroup::find_by_id(group_id.as_ref()) + .one(db) + .await? + .map(Into::into)) +} + +#[cfg(test)] +mod tests { + #![allow(clippy::derivable_impls)] + + use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; + use serde_json::json; + + use crate::identity::Config; + + use super::super::tests::get_group_mock; + use super::*; + + #[tokio::test] + async fn test_get() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_group_mock("1")], vec![]]) + .into_connection(); + let config = Config::default(); + + assert_eq!( + get(&config, &db, "id").await.unwrap(), + Some(Group { + id: "1".into(), + domain_id: "foo_domain".into(), + name: "group".into(), + description: Some("fake".into()), + extra: Some(json!({"foo": "bar"})) + }) + ); + assert!(get(&config, &db, "missing").await.unwrap().is_none()); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [ + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group" WHERE "group"."id" = $1 LIMIT $2"#, + ["id".into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group" WHERE "group"."id" = $1 LIMIT $2"#, + ["missing".into(), 1u64.into()] + ), + ] + ); + } +} diff --git a/src/identity/backends/sql/group/list.rs b/src/identity/backends/sql/group/list.rs new file mode 100644 index 00000000..306e1c78 --- /dev/null +++ b/src/identity/backends/sql/group/list.rs @@ -0,0 +1,125 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use sea_orm::query::*; + +use crate::db::entity::{group, prelude::Group as DbGroup}; +use crate::identity::Config; +use crate::identity::backends::sql::IdentityDatabaseError; +use crate::identity::types::{Group, GroupListParameters}; + +pub async fn list( + _conf: &Config, + db: &DatabaseConnection, + params: &GroupListParameters, +) -> Result, IdentityDatabaseError> { + // Prepare basic selects + let mut group_select = DbGroup::find(); + + if let Some(domain_id) = ¶ms.domain_id { + group_select = group_select.filter(group::Column::DomainId.eq(domain_id)); + } + if let Some(name) = ¶ms.name { + group_select = group_select.filter(group::Column::Name.eq(name)); + } + + let db_groups: Vec = group_select.all(db).await?; + let results: Vec = db_groups.into_iter().map(Into::into).collect(); + + Ok(results) +} + +#[cfg(test)] +mod tests { + #![allow(clippy::derivable_impls)] + + use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; + use serde_json::json; + + use crate::db::entity::group; + use crate::identity::Config; + use crate::identity::types::group::GroupListParametersBuilder; + + use super::super::tests::get_group_mock; + use super::*; + + #[tokio::test] + async fn test_list() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([ + // First query result - select user itself + vec![get_group_mock("1")], + ]) + .into_connection(); + let config = Config::default(); + assert_eq!( + list(&config, &db, &GroupListParameters::default()) + .await + .unwrap(), + vec![Group { + id: "1".into(), + domain_id: "foo_domain".into(), + name: "group".into(), + description: Some("fake".into()), + extra: Some(json!({"foo": "bar"})) + }] + ); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group""#, + //["1".into(), 1u64.into()] + [] + ),] + ); + } + + #[tokio::test] + async fn test_list_with_filters() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([Vec::::new()]) + .into_connection(); + let config = Config::default(); + assert_eq!( + list( + &config, + &db, + &GroupListParametersBuilder::default() + .domain_id("d") + .name("n") + .build() + .unwrap() + ) + .await + .unwrap(), + vec![] + ); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "group"."id", "group"."domain_id", "group"."name", "group"."description", "group"."extra" FROM "group" WHERE "group"."domain_id" = $1 AND "group"."name" = $2"#, + ["d".into(), "n".into()] + ),] + ); + } +} diff --git a/src/identity/backends/sql/user_group.rs b/src/identity/backends/sql/user_group.rs new file mode 100644 index 00000000..b0065c7e --- /dev/null +++ b/src/identity/backends/sql/user_group.rs @@ -0,0 +1,245 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use sea_orm::query::*; +use std::collections::BTreeSet; + +use crate::db::entity::{prelude::UserGroupMembership, user_group_membership}; +use crate::identity::backends::sql::{IdentityDatabaseError, db_err}; + +mod add; +mod list; +mod remove; +pub use add::{add_user_to_group, add_users_to_groups}; +pub use list::list_user_groups; +pub use remove::{remove_user_from_group, remove_user_from_groups}; + +/// Set user group memberships. +/// +/// Add user to the groups it should be in and remove from the groups where the user is currently +/// member of, but should not be. This is only incremental operation and is not deleting group +/// membership where the user should stay. +pub async fn set_user_groups( + db: &DatabaseConnection, + user_id: U, + group_ids: I, +) -> Result<(), IdentityDatabaseError> +where + I: IntoIterator, + U: AsRef, + G: AsRef, +{ + // Use BTreeSet to keep order for helping tests + let expected_groups: BTreeSet = + BTreeSet::from_iter(group_ids.into_iter().map(|group| group.as_ref().into())); + let current_groups: BTreeSet = BTreeSet::from_iter( + UserGroupMembership::find() + .filter(user_group_membership::Column::UserId.eq(user_id.as_ref())) + .all(db) + .await + .map_err(|e| db_err(e, "selecting group memberships of the user"))? + .into_iter() + .map(|item| item.group_id), + ); + + let groups_to_remove: BTreeSet = current_groups + .iter() + .filter(|&item| !expected_groups.contains(item)) + .cloned() + .collect(); + + let groups_to_add: BTreeSet = expected_groups + .iter() + .filter(|&item| !current_groups.contains(item)) + .cloned() + .collect(); + + if !groups_to_remove.is_empty() { + remove_user_from_groups(db, user_id.as_ref(), groups_to_remove).await?; + } + if !groups_to_add.is_empty() { + add_users_to_groups( + db, + groups_to_add + .into_iter() + .map(|group| (user_id.as_ref(), group.clone())), + ) + .await?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + + use super::*; + + fn get_data_mock, G: AsRef>( + user_id: U, + group_id: G, + ) -> user_group_membership::Model { + user_group_membership::Model { + user_id: user_id.as_ref().to_string(), + group_id: group_id.as_ref().to_string(), + } + } + + #[tokio::test] + async fn test_add_and_remove() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![ + get_data_mock("u1", "g1"), + get_data_mock("u1", "g2"), + get_data_mock("u1", "g3"), + get_data_mock("u1", "g4"), + ]]) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + + set_user_groups(&db, "u1", vec!["g2", "g4", "g5", "g0"]) + .await + .unwrap(); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [ + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "user_group_membership"."user_id", "user_group_membership"."group_id" FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1"#, + ["u1".into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"DELETE FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1 AND "user_group_membership"."group_id" IN ($2, $3)"#, + ["u1".into(), "g1".into(), "g3".into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "user_group_membership" ("user_id", "group_id") VALUES ($1, $2), ($3, $4) RETURNING "user_id", "group_id""#, + ["u1".into(), "g0".into(), "u1".into(), "g5".into()] + ), + ] + ); + } + + #[tokio::test] + async fn test_only_add() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![ + get_data_mock("u1", "g1"), + get_data_mock("u1", "g2"), + get_data_mock("u1", "g3"), + get_data_mock("u1", "g4"), + ]]) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + + set_user_groups(&db, "u1", vec!["g1", "g2", "g3", "g4", "g5"]) + .await + .unwrap(); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [ + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "user_group_membership"."user_id", "user_group_membership"."group_id" FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1"#, + ["u1".into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "user_group_membership" ("user_id", "group_id") VALUES ($1, $2) RETURNING "user_id", "group_id""#, + ["u1".into(), "g5".into()] + ), + ] + ); + } + + #[tokio::test] + async fn test_only_delete() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![ + get_data_mock("u1", "g1"), + get_data_mock("u1", "g2"), + get_data_mock("u1", "g3"), + get_data_mock("u1", "g4"), + ]]) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + + set_user_groups(&db, "u1", vec!["g2", "g4"]).await.unwrap(); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [ + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "user_group_membership"."user_id", "user_group_membership"."group_id" FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1"#, + ["u1".into()] + ), + Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"DELETE FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1 AND "user_group_membership"."group_id" IN ($2, $3)"#, + ["u1".into(), "g1".into(), "g3".into()] + ), + ] + ); + } + + #[tokio::test] + async fn test_no_change() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![ + get_data_mock("u1", "g1"), + get_data_mock("u1", "g2"), + get_data_mock("u1", "g3"), + get_data_mock("u1", "g4"), + ]]) + .into_connection(); + + set_user_groups(&db, "u1", vec!["g1", "g2", "g3", "g4"]) + .await + .unwrap(); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "user_group_membership"."user_id", "user_group_membership"."group_id" FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1"#, + ["u1".into()] + ),] + ); + } +} diff --git a/src/identity/backends/sql/user_group/add.rs b/src/identity/backends/sql/user_group/add.rs new file mode 100644 index 00000000..33338bc4 --- /dev/null +++ b/src/identity/backends/sql/user_group/add.rs @@ -0,0 +1,129 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; + +use crate::db::entity::{prelude::UserGroupMembership, user_group_membership}; +use crate::identity::backends::sql::{IdentityDatabaseError, db_err}; + +/// Add the user to the single group. +pub async fn add_user_to_group, G: AsRef>( + db: &DatabaseConnection, + user_id: U, + group_id: G, +) -> Result<(), IdentityDatabaseError> { + let entry = user_group_membership::ActiveModel { + user_id: Set(user_id.as_ref().into()), + group_id: Set(group_id.as_ref().into()), + }; + + entry + .insert(db) + .await + .map_err(|e| db_err(e, "adding user to single group"))?; + + Ok(()) +} + +/// Add group user relations as speified by the tuples (user_id, group_id) iterator. +pub async fn add_users_to_groups( + db: &DatabaseConnection, + iter: I, +) -> Result<(), IdentityDatabaseError> +where + I: IntoIterator, + U: AsRef, + G: AsRef, +{ + UserGroupMembership::insert_many(iter.into_iter().map(|(u, g)| { + user_group_membership::ActiveModel { + user_id: Set(u.as_ref().into()), + group_id: Set(g.as_ref().into()), + } + })) + .on_empty_do_nothing() + .exec(db) + .await + .map_err(|e| db_err(e, "adding user to groups"))?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + + use super::*; + + fn get_mock, G: AsRef>( + user_id: U, + group_id: G, + ) -> user_group_membership::Model { + user_group_membership::Model { + user_id: user_id.as_ref().into(), + group_id: group_id.as_ref().into(), + } + } + + #[tokio::test] + async fn test_create() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![get_mock("u1", "g1")]]) + .into_connection(); + + assert!(add_user_to_group(&db, "u1", "g1").await.is_ok()); + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "user_group_membership" ("user_id", "group_id") VALUES ($1, $2) RETURNING "user_id", "group_id""#, + ["u1".into(), "g1".into(),] + ),] + ); + } + #[tokio::test] + async fn test_bulk() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + + add_users_to_groups(&db, vec![("u1", "g1"), ("u1", "g2"), ("u2", "g2")]) + .await + .unwrap(); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"INSERT INTO "user_group_membership" ("user_id", "group_id") VALUES ($1, $2), ($3, $4), ($5, $6) RETURNING "user_id", "group_id""#, + [ + "u1".into(), + "g1".into(), + "u1".into(), + "g2".into(), + "u2".into(), + "g2".into() + ] + ),] + ); + } +} diff --git a/src/identity/backends/sql/user_group/list.rs b/src/identity/backends/sql/user_group/list.rs new file mode 100644 index 00000000..c9cf18ec --- /dev/null +++ b/src/identity/backends/sql/user_group/list.rs @@ -0,0 +1,81 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use sea_orm::query::*; + +use crate::db::entity::{ + group, + prelude::{Group as DbGroup, UserGroupMembership as DbUserGroupMembership}, + user_group_membership, +}; +use crate::identity::backends::sql::{IdentityDatabaseError, db_err}; +use crate::identity::types::Group; + +/// List all groups the user is member of. +pub async fn list_user_groups>( + db: &DatabaseConnection, + user_id: S, +) -> Result, IdentityDatabaseError> { + let groups: Vec<(user_group_membership::Model, Vec)> = + DbUserGroupMembership::find() + .filter(user_group_membership::Column::UserId.eq(user_id.as_ref())) + .find_with_related(DbGroup) + .all(db) + .await + .map_err(|e| db_err(e, "listing groups the user is currently in"))?; + + let results: Vec = groups + .into_iter() + .flat_map(|(_, x)| x.into_iter()) + .map(Into::into) + .collect(); + Ok(results) +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; + + use super::*; + + fn get_group_mock>(id: S) -> group::Model { + group::Model { + id: id.as_ref().to_string(), + domain_id: "foo_domain".into(), + name: "group".into(), + description: Some("fake".into()), + extra: Some("{\"foo\": \"bar\"}".into()), + } + } + + #[tokio::test] + async fn test_list() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![], vec![get_group_mock("1"), get_group_mock("2")]]) + .into_connection(); + assert_eq!(list_user_groups(&db, "foo").await.unwrap(), vec![]); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "user_group_membership"."user_id" AS "A_user_id", "user_group_membership"."group_id" AS "A_group_id", "group"."id" AS "B_id", "group"."domain_id" AS "B_domain_id", "group"."name" AS "B_name", "group"."description" AS "B_description", "group"."extra" AS "B_extra" FROM "user_group_membership" LEFT JOIN "group" ON "user_group_membership"."group_id" = "group"."id" WHERE "user_group_membership"."user_id" = $1 ORDER BY "user_group_membership"."user_id" ASC, "user_group_membership"."group_id" ASC"#, + ["foo".into()] + ),] + ); + } +} diff --git a/src/identity/backends/sql/user_group/remove.rs b/src/identity/backends/sql/user_group/remove.rs new file mode 100644 index 00000000..a8eeb258 --- /dev/null +++ b/src/identity/backends/sql/user_group/remove.rs @@ -0,0 +1,114 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use sea_orm::query::*; + +use crate::db::entity::{prelude::UserGroupMembership, user_group_membership}; +use crate::identity::backends::sql::{IdentityDatabaseError, db_err}; + +/// Remove the user from the group. +pub async fn remove_user_from_group, G: AsRef>( + db: &DatabaseConnection, + user_id: U, + group_id: G, +) -> Result<(), IdentityDatabaseError> { + UserGroupMembership::delete_by_id((user_id.as_ref().into(), group_id.as_ref().into())) + .exec(db) + .await + .map_err(|e| db_err(e, "Deleting user<->group membership relation"))?; + + Ok(()) +} + +/// Remove the user from multiple groups. +pub async fn remove_user_from_groups( + db: &DatabaseConnection, + user_id: U, + group_ids: I, +) -> Result<(), IdentityDatabaseError> +where + I: IntoIterator, + U: AsRef, + G: AsRef, +{ + UserGroupMembership::delete_many() + .filter( + Condition::all() + .add(user_group_membership::Column::UserId.eq(user_id.as_ref())) + .add( + user_group_membership::Column::GroupId + .is_in(group_ids.into_iter().map(|grp| grp.as_ref().to_string())), + ), + ) + .exec(db) + .await + .map_err(|e| db_err(e, "Deleting user<->group membership relations"))?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use sea_orm::{DatabaseBackend, MockDatabase, MockExecResult, Transaction}; + + use super::*; + + #[tokio::test] + async fn test_remove_single() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + + remove_user_from_group(&db, "u1", "g1").await.unwrap(); + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"DELETE FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1 AND "user_group_membership"."group_id" = $2"#, + ["u1".into(), "g1".into(),] + ),] + ); + } + + #[tokio::test] + async fn test_remove_from_groups() { + // Create MockDatabase with mock query results + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_exec_results([MockExecResult { + rows_affected: 1, + ..Default::default() + }]) + .into_connection(); + + remove_user_from_groups(&db, "u1", vec!["g1", "g2", "g3"]) + .await + .unwrap(); + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"DELETE FROM "user_group_membership" WHERE "user_group_membership"."user_id" = $1 AND "user_group_membership"."group_id" IN ($2, $3, $4)"#, + ["u1".into(), "g1".into(), "g2".into(), "g3".into()] + ),] + ); + } +} diff --git a/src/identity/error.rs b/src/identity/error.rs index 189a8971..168fe444 100644 --- a/src/identity/error.rs +++ b/src/identity/error.rs @@ -98,7 +98,7 @@ pub enum IdentityProviderError { impl From for IdentityProviderError { fn from(source: IdentityDatabaseError) -> Self { match source { - IdentityDatabaseError::Conflict(x) => Self::Conflict(x), + IdentityDatabaseError::Conflict { message, .. } => Self::Conflict(message), IdentityDatabaseError::UserNotFound(x) => Self::UserNotFound(x), IdentityDatabaseError::GroupNotFound(x) => Self::GroupNotFound(x), IdentityDatabaseError::Serde { source } => Self::Serde { source }, diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 83fd1cb3..fe95b541 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -16,6 +16,7 @@ use async_trait::async_trait; #[cfg(test)] use mockall::mock; use sea_orm::DatabaseConnection; +use std::collections::HashSet; use uuid::Uuid; use webauthn_rs::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; @@ -105,19 +106,59 @@ pub trait IdentityApi: Send + Sync + Clone { group_id: &'a str, ) -> Result<(), IdentityProviderError>; - async fn list_groups_for_user<'a>( + /// List groups the user is a member of. + async fn list_groups_of_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, ) -> Result, IdentityProviderError>; + /// Add the user to the single group. + async fn add_user_to_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + /// Add user group memberships as specified by (uid, gid) tuples. + async fn add_users_to_groups<'a>( + &self, + db: &DatabaseConnection, + memberships: Vec<(&'a str, &'a str)>, + ) -> Result<(), IdentityProviderError>; + + /// Remove the user from the single group. + async fn remove_user_from_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + /// Remove the user from specified groups. + async fn remove_user_from_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError>; + + /// Set group memberships of the user. + async fn set_user_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError>; + async fn list_user_passkeys<'a>( &self, db: &DatabaseConnection, user_id: &'a str, ) -> Result, IdentityProviderError>; - /// Create passkey + /// Create passkey. async fn create_user_passkey<'a>( &self, db: &DatabaseConnection, @@ -236,12 +277,46 @@ mock! { group_id: &'a str, ) -> Result<(), IdentityProviderError>; - async fn list_groups_for_user<'a>( + async fn list_groups_of_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, ) -> Result, IdentityProviderError>; + async fn add_user_to_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + async fn add_users_to_groups<'a>( + &self, + db: &DatabaseConnection, + memberships: Vec<(&'a str, &'a str)> + ) -> Result<(), IdentityProviderError>; + + async fn remove_user_from_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + async fn remove_user_from_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError>; + + async fn set_user_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError>; + async fn list_user_passkeys<'a>( &self, db: &DatabaseConnection, @@ -459,14 +534,73 @@ impl IdentityApi for IdentityProvider { self.backend_driver.delete_group(db, group_id).await } - /// List groups a user is a member of + /// List groups a user is a member of. #[tracing::instrument(level = "info", skip(self, db))] - async fn list_groups_for_user<'a>( + async fn list_groups_of_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, ) -> Result, IdentityProviderError> { - self.backend_driver.list_groups_for_user(db, user_id).await + self.backend_driver.list_groups_of_user(db, user_id).await + } + + #[tracing::instrument(level = "info", skip(self, db))] + async fn add_user_to_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError> { + self.backend_driver + .add_user_to_group(db, user_id, group_id) + .await + } + + #[tracing::instrument(level = "info", skip(self, db))] + async fn add_users_to_groups<'a>( + &self, + db: &DatabaseConnection, + memberships: Vec<(&'a str, &'a str)>, + ) -> Result<(), IdentityProviderError> { + self.backend_driver + .add_users_to_groups(db, memberships) + .await + } + + #[tracing::instrument(level = "info", skip(self, db))] + async fn remove_user_from_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError> { + self.backend_driver + .remove_user_from_group(db, user_id, group_id) + .await + } + + #[tracing::instrument(level = "info", skip(self, db))] + async fn remove_user_from_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError> { + self.backend_driver + .remove_user_from_groups(db, user_id, group_ids) + .await + } + + #[tracing::instrument(level = "debug", skip(self, db))] + async fn set_user_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError> { + self.backend_driver + .set_user_groups(db, user_id, group_ids) + .await } /// List user passkeys diff --git a/src/identity/types.rs b/src/identity/types.rs index da701d72..997e298a 100644 --- a/src/identity/types.rs +++ b/src/identity/types.rs @@ -12,6 +12,8 @@ // // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashSet; + pub mod group; pub mod user; @@ -108,13 +110,52 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { group_id: &'a str, ) -> Result<(), IdentityProviderError>; - /// List groups a user is member of - async fn list_groups_for_user<'a>( + /// List groups a user is member of. + async fn list_groups_of_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, ) -> Result, IdentityProviderError>; + /// Add the user to the group. + async fn add_user_to_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + /// Add user group membership relations. + async fn add_users_to_groups<'a>( + &self, + db: &DatabaseConnection, + memberships: Vec<(&'a str, &'a str)>, + ) -> Result<(), IdentityProviderError>; + + /// Remove the user from the group + async fn remove_user_from_group<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + /// Remove the user from multiple groups. + async fn remove_user_from_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError>; + + /// Set group memberships for the user. + async fn set_user_groups<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + group_ids: HashSet<&'a str>, + ) -> Result<(), IdentityProviderError>; + /// List user passkeys async fn list_user_passkeys<'a>( &self, diff --git a/src/token/error.rs b/src/token/error.rs index 9d6d3414..2e732abd 100644 --- a/src/token/error.rs +++ b/src/token/error.rs @@ -182,4 +182,7 @@ pub enum TokenProviderError { #[error("user cannot be found: {0}")] UserNotFound(String), + + #[error("unsupported authentication methods in token payload")] + UnsupportedAuthMethods, } diff --git a/src/token/federation_domain_scoped.rs b/src/token/federation_domain_scoped.rs index 12cd216f..0b0bb202 100644 --- a/src/token/federation_domain_scoped.rs +++ b/src/token/federation_domain_scoped.rs @@ -99,7 +99,7 @@ impl MsgPackToken for FederationDomainScopePayload { fernet_utils::write_uuid(wd, &self.domain_id)?; fernet_utils::write_list_of_uuids(wd, self.group_ids.iter())?; fernet_utils::write_uuid(wd, &self.idp_id)?; - fernet_utils::write_uuid(wd, &self.protocol_id)?; + fernet_utils::write_str(wd, &self.protocol_id)?; fernet_utils::write_time(wd, self.expires_at)?; fernet_utils::write_audit_ids(wd, self.audit_ids.clone())?; @@ -119,7 +119,7 @@ impl MsgPackToken for FederationDomainScopePayload { let domain_id = fernet_utils::read_uuid(rd)?; let group_ids = fernet_utils::read_list_of_uuids(rd)?; let idp_id = fernet_utils::read_uuid(rd)?; - let protocol_id = fernet_utils::read_uuid(rd)?; + let protocol_id = fernet_utils::read_str(rd)?; let expires_at = fernet_utils::read_time(rd)?; let audit_ids: Vec = fernet_utils::read_audit_ids(rd)?.into_iter().collect(); Ok(Self { diff --git a/src/token/federation_project_scoped.rs b/src/token/federation_project_scoped.rs index b53c67e4..92a35a59 100644 --- a/src/token/federation_project_scoped.rs +++ b/src/token/federation_project_scoped.rs @@ -99,7 +99,7 @@ impl MsgPackToken for FederationProjectScopePayload { fernet_utils::write_uuid(wd, &self.project_id)?; fernet_utils::write_list_of_uuids(wd, self.group_ids.iter())?; fernet_utils::write_uuid(wd, &self.idp_id)?; - fernet_utils::write_uuid(wd, &self.protocol_id)?; + fernet_utils::write_str(wd, &self.protocol_id)?; fernet_utils::write_time(wd, self.expires_at)?; fernet_utils::write_audit_ids(wd, self.audit_ids.clone())?; @@ -112,14 +112,13 @@ impl MsgPackToken for FederationProjectScopePayload { ) -> Result { // Order of reading is important let user_id = fernet_utils::read_uuid(rd)?; - println!("u: {user_id:?}"); let methods: Vec = fernet::decode_auth_methods(read_pfix(rd)?.into(), auth_map)? .into_iter() .collect(); let project_id = fernet_utils::read_uuid(rd)?; let group_ids = fernet_utils::read_list_of_uuids(rd)?; let idp_id = fernet_utils::read_uuid(rd)?; - let protocol_id = fernet_utils::read_uuid(rd)?; + let protocol_id = fernet_utils::read_str(rd)?; let expires_at = fernet_utils::read_time(rd)?; let audit_ids: Vec = fernet_utils::read_audit_ids(rd)?.into_iter().collect(); Ok(Self { diff --git a/src/token/federation_unscoped.rs b/src/token/federation_unscoped.rs index 925fd635..3763715d 100644 --- a/src/token/federation_unscoped.rs +++ b/src/token/federation_unscoped.rs @@ -91,7 +91,7 @@ impl MsgPackToken for FederationUnscopedPayload { .map_err(|x| TokenProviderError::RmpEncode(x.to_string()))?; fernet_utils::write_list_of_uuids(wd, self.group_ids.iter())?; fernet_utils::write_uuid(wd, &self.idp_id)?; - fernet_utils::write_uuid(wd, &self.protocol_id)?; + fernet_utils::write_str(wd, &self.protocol_id)?; fernet_utils::write_time(wd, self.expires_at)?; fernet_utils::write_audit_ids(wd, self.audit_ids.clone())?; @@ -109,7 +109,7 @@ impl MsgPackToken for FederationUnscopedPayload { .collect(); let group_ids = fernet_utils::read_list_of_uuids(rd)?; let idp_id = fernet_utils::read_uuid(rd)?; - let protocol_id = fernet_utils::read_uuid(rd)?; + let protocol_id = fernet_utils::read_str(rd)?; let expires_at = fernet_utils::read_time(rd)?; let audit_ids: Vec = fernet_utils::read_audit_ids(rd)?.into_iter().collect(); Ok(Self { diff --git a/src/token/fernet.rs b/src/token/fernet.rs index 5cca3905..2154d9c3 100644 --- a/src/token/fernet.rs +++ b/src/token/fernet.rs @@ -110,6 +110,11 @@ pub(crate) fn encode_auth_methods>( let res = auth_map .iter() .fold(0, |acc, (k, v)| acc + if me.contains(v) { *k } else { 0 }); + + // TODO: Improve unit tests to ensure unsupporte auth method immediately raises error. + if res == 0 { + return Err(TokenProviderError::UnsupportedAuthMethods); + } Ok(res) } diff --git a/src/token/fernet_utils.rs b/src/token/fernet_utils.rs index 420dd461..05d3441d 100644 --- a/src/token/fernet_utils.rs +++ b/src/token/fernet_utils.rs @@ -14,7 +14,11 @@ use base64::{Engine as _, engine::general_purpose::URL_SAFE}; use chrono::{DateTime, Utc}; -use rmp::{Marker, decode::*, encode::*}; +use rmp::{ + Marker, + decode::*, + encode::{self, *}, +}; use std::collections::BTreeMap; use std::fs; use std::io; @@ -98,11 +102,34 @@ pub fn read_bin_data(len: u32, rd: &mut R) -> Result, io::Error Ok(buf) } -/// Read string data +/// Read string data. pub fn read_str_data(len: u32, rd: &mut R) -> Result { Ok(String::from_utf8_lossy(&read_bin_data(len, rd)?).into_owned()) } +/// Write string. +pub fn write_str(wd: &mut W, data: &str) -> Result<(), TokenProviderError> { + encode::write_str(wd, data).map_err(|x| TokenProviderError::RmpEncode(x.to_string()))?; + Ok(()) +} + +/// Read bytes as string. +pub fn read_str(rd: &mut R) -> Result { + match read_marker(rd).map_err(ValueReadError::from)? { + Marker::Bin8 => { + Ok( + String::from_utf8_lossy(&read_bin_data(read_pfix(rd)?.into(), rd)?).to_string(), + ) + } + Marker::FixStr(len) => { + Ok(read_str_data(len.into(), rd)?) + } + other => { + Err(TokenProviderError::InvalidTokenUuidMarker(other)) + } + } +} + /// Read the UUID from the payload /// It is represented as an Array[bool, bytes] where first bool indicates whether following bytes /// are UUID or just bytes that should be treated as a string (for cases where ID is not a valid diff --git a/src/token/mod.rs b/src/token/mod.rs index cfcc18af..dd7454fb 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -202,7 +202,14 @@ impl TokenProvider { ) .idp_id(idp_id) .protocol_id(protocol_id) - .group_ids(vec![]) + .group_ids( + authentication_info + .user_groups + .clone() + .iter() + .map(|grp| grp.id.clone()) + .collect::>(), + ) .project_id(project.id.clone()) .project(project.clone()) .build()?, @@ -237,6 +244,14 @@ impl TokenProvider { ) .idp_id(idp_id) .protocol_id(protocol_id) + .group_ids( + authentication_info + .user_groups + .clone() + .iter() + .map(|grp| grp.id.clone()) + .collect::>(), + ) .domain_id(domain.id.clone()) .domain(domain.clone()) .build()?, @@ -360,6 +375,7 @@ impl TokenApi for TokenProvider { window_seconds: Option, ) -> Result { let token = self.backend_driver.decode(credential)?; + tracing::debug!("Token is {:?}", token); if Local::now().to_utc() > token .expires_at() @@ -373,6 +389,7 @@ impl TokenApi for TokenProvider { Ok(token) } + #[tracing::instrument(level = "debug", skip(self))] fn issue_token( &self, authentication_info: AuthenticatedInfo, @@ -509,6 +526,62 @@ impl TokenApi for TokenProvider { return Err(TokenProviderError::ActorHasNoRolesOnTarget); } } + Token::FederationProjectScope(data) => { + data.roles = Some( + provider + .get_assignment_provider() + .list_role_assignments( + db, + provider, + &RoleAssignmentListParametersBuilder::default() + .user_id(&data.user_id) + .project_id(&data.project_id) + .include_names(true) + .effective(true) + .build() + .map_err(AssignmentProviderError::from)?, + ) + .await? + .into_iter() + .map(|x| Role { + id: x.role_id.clone(), + name: x.role_name.clone().unwrap_or_default(), + ..Default::default() + }) + .collect(), + ); + if data.roles.as_ref().is_none_or(|roles| roles.is_empty()) { + return Err(TokenProviderError::ActorHasNoRolesOnTarget); + } + } + Token::FederationDomainScope(data) => { + data.roles = Some( + provider + .get_assignment_provider() + .list_role_assignments( + db, + provider, + &RoleAssignmentListParametersBuilder::default() + .user_id(&data.user_id) + .domain_id(&data.domain_id) + .include_names(true) + .effective(true) + .build() + .map_err(AssignmentProviderError::from)?, + ) + .await? + .into_iter() + .map(|x| Role { + id: x.role_id.clone(), + name: x.role_name.clone().unwrap_or_default(), + ..Default::default() + }) + .collect(), + ); + if data.roles.as_ref().is_none_or(|roles| roles.is_empty()) { + return Err(TokenProviderError::ActorHasNoRolesOnTarget); + } + } _ => {} }