From a50e1129c0fda833b26a8490d11320f6cdf1d6dc Mon Sep 17 00:00:00 2001 From: gtema Date: Thu, 20 Mar 2025 17:11:34 +0000 Subject: [PATCH] feat: Implement password auth --- src/api/common.rs | 134 +++++++++ src/api/error.rs | 37 ++- src/api/mod.rs | 7 +- src/api/v3/auth/token/common.rs | 150 +++++++---- src/api/v3/auth/token/mod.rs | 299 ++++++++++++++++++++- src/api/v3/auth/token/types.rs | 136 +++++++++- src/api/v3/user/mod.rs | 15 +- src/api/v3/user/passkey/mod.rs | 21 +- src/api/v3/user/types.rs | 22 +- src/assignment/backends/sql/role.rs | 6 +- src/bin/keystone.rs | 8 +- src/identity/backends/error.rs | 7 +- src/identity/backends/sql.rs | 113 +++++--- src/identity/backends/sql/common.rs | 16 +- src/identity/backends/sql/group.rs | 25 +- src/identity/backends/sql/local_user.rs | 57 +++- src/identity/backends/sql/passkey.rs | 12 +- src/identity/backends/sql/passkey_state.rs | 33 ++- src/identity/backends/sql/user.rs | 17 +- src/identity/backends/sql/user_option.rs | 17 +- src/identity/error.rs | 33 ++- src/identity/mod.rs | 71 ++++- src/identity/password_hashing.rs | 16 ++ src/identity/types.rs | 17 +- src/identity/types/user.rs | 34 ++- src/resource/backends/sql.rs | 63 ++++- src/resource/error.rs | 10 +- src/resource/mod.rs | 74 ++++- src/resource/types.rs | 15 ++ src/token/domain_scoped.rs | 1 + src/token/error.rs | 14 + src/token/mod.rs | 77 ++++-- src/token/project_scoped.rs | 1 + 33 files changed, 1302 insertions(+), 256 deletions(-) create mode 100644 src/api/common.rs diff --git a/src/api/common.rs b/src/api/common.rs new file mode 100644 index 00000000..8f6cab7f --- /dev/null +++ b/src/api/common.rs @@ -0,0 +1,134 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +//! Common API helpers +//! +use crate::api::error::KeystoneApiError; +use crate::keystone::ServiceState; +use crate::resource::{ResourceApi, types::Domain}; + +pub async fn get_domain, N: AsRef>( + state: &ServiceState, + id: Option, + name: Option, +) -> Result { + let domain = if let Some(did) = &id { + state + .provider + .get_resource_provider() + .get_domain(&state.db, did.as_ref()) + .await? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "domain".into(), + identifier: did.as_ref().to_string(), + })? + } else if let Some(name) = &name { + state + .provider + .get_resource_provider() + .find_domain_by_name(&state.db, name.as_ref()) + .await? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "domain".into(), + identifier: name.as_ref().to_string(), + })? + } else { + return Err(KeystoneApiError::DomainIdOrName); + }; + Ok(domain) +} + +#[cfg(test)] +mod tests { + use sea_orm::DatabaseConnection; + use std::sync::Arc; + + use super::*; + + use crate::assignment::MockAssignmentProvider; + use crate::config::Config; + use crate::identity::MockIdentityProvider; + use crate::keystone::Service; + use crate::provider::ProviderBuilder; + use crate::resource::{MockResourceProvider, types::Domain}; + use crate::token::MockTokenProvider; + + #[tokio::test] + async fn test_get_domain() { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + + let mut resource_mock = MockResourceProvider::default(); + resource_mock + .expect_get_domain() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "domain_id") + .returning(|_, _| { + Ok(Some(Domain { + id: "domain_id".into(), + name: "domain_name".into(), + ..Default::default() + })) + }); + resource_mock + .expect_find_domain_by_name() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "domain_name") + .returning(|_, _| { + Ok(Some(Domain { + id: "domain_id".into(), + name: "domain_name".into(), + ..Default::default() + })) + }); + let identity_mock = MockIdentityProvider::default(); + let token_mock = MockTokenProvider::default(); + let assignment_mock = MockAssignmentProvider::default(); + let provider = ProviderBuilder::default() + .config(config.clone()) + .assignment(assignment_mock) + .identity(identity_mock) + .resource(resource_mock) + .token(token_mock) + .build() + .unwrap(); + + let state = Arc::new(Service::new(config, db, provider).unwrap()); + + assert_eq!( + "domain_id", + get_domain(&state, Some("domain_id"), None::<&str>) + .await + .unwrap() + .id + ); + assert_eq!( + "domain_id", + get_domain(&state, None::<&str>, Some("domain_name")) + .await + .unwrap() + .id + ); + assert_eq!( + "domain_id", + get_domain(&state, Some("domain_id"), Some("other_domain_name")) + .await + .unwrap() + .id + ); + match get_domain(&state, None::<&str>, None::<&str>).await { + Err(KeystoneApiError::DomainIdOrName) => {} + _ => { + panic!("wrong result"); + } + } + } +} diff --git a/src/api/error.rs b/src/api/error.rs index 01a1b0f5..ddfac631 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -19,6 +19,7 @@ use axum::{ }; use serde_json::json; use thiserror::Error; +use tracing::error; use crate::assignment::error::AssignmentProviderError; use crate::identity::error::IdentityProviderError; @@ -38,7 +39,7 @@ pub enum KeystoneApiError { }, #[error("missing authorization")] - Unauthorized(String), + Unauthorized, #[error("missing x-subject-token header")] SubjectTokenMissing, @@ -99,10 +100,20 @@ pub enum KeystoneApiError { #[from] source: serde_json::Error, }, + + #[error("domain id or name must be present")] + DomainIdOrName, + + #[error("project id or name must be present")] + ProjectIdOrName, + + #[error("project domain must be present")] + ProjectDomain, } impl IntoResponse for KeystoneApiError { fn into_response(self) -> Response { + error!("Error happened during request processing: {:?}", self); match self { KeystoneApiError::Conflict(_) => ( StatusCode::CONFLICT, @@ -113,22 +124,18 @@ impl IntoResponse for KeystoneApiError { Json(json!({"error": {"code": StatusCode::NOT_FOUND.as_u16(), "message": self.to_string()}})), ) .into_response(), - KeystoneApiError::Unauthorized(_) => { + KeystoneApiError::Unauthorized => { (StatusCode::UNAUTHORIZED, Json(json!({"error": {"code": StatusCode::UNAUTHORIZED.as_u16(), "message": self.to_string()}})), ).into_response() } - KeystoneApiError::InternalError(_) => { - (StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": {"code": StatusCode::INTERNAL_SERVER_ERROR.as_u16(), "message": self.to_string()}})), - ).into_response() - } - KeystoneApiError::IdentityError { .. } | KeystoneApiError::ResourceError { .. } | KeystoneApiError::AssignmentError { .. } | KeystoneApiError::TokenError{..} => { + KeystoneApiError::InternalError(_) | KeystoneApiError::IdentityError { .. } | KeystoneApiError::ResourceError { .. } | KeystoneApiError::AssignmentError { .. } | KeystoneApiError::TokenError{..} => { (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": {"code": StatusCode::INTERNAL_SERVER_ERROR.as_u16(), "message": self.to_string()}})), ).into_response() } - KeystoneApiError::SubjectTokenMissing | KeystoneApiError::InvalidHeader | KeystoneApiError::InvalidToken | KeystoneApiError::Token{..} | KeystoneApiError::WebAuthN{..} | KeystoneApiError::Uuid {..} | KeystoneApiError::Serde {..} => { + _ => { + // KeystoneApiError::SubjectTokenMissing | KeystoneApiError::InvalidHeader | KeystoneApiError::InvalidToken | KeystoneApiError::Token{..} | KeystoneApiError::WebAuthN{..} | KeystoneApiError::Uuid {..} | KeystoneApiError::Serde {..} | KeystoneApiError::DomainIdOrName | KeystoneApiError::ProjectIdOrName | KeystoneApiError::ProjectDomain => (StatusCode::BAD_REQUEST, Json(json!({"error": {"code": StatusCode::BAD_REQUEST.as_u16(), "message": self.to_string()}})), ).into_response() @@ -184,11 +191,23 @@ pub enum TokenError { #[from] source: crate::api::v3::auth::token::types::UserBuilderError, }, + #[error("error building token user data: {}", source)] ProjectBuilder { #[from] source: crate::api::v3::auth::token::types::ProjectBuilderError, }, + + #[error(transparent)] + UserPasswordAuthBuilder { + #[from] + source: crate::identity::types::user::UserPasswordAuthRequestBuilderError, + }, + #[error(transparent)] + DomainBuilder { + #[from] + source: crate::identity::types::user::DomainBuilderError, + }, } #[derive(Error, Debug)] diff --git a/src/api/mod.rs b/src/api/mod.rs index 1957e728..db6a6041 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -13,7 +13,6 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ - extract::OriginalUri, http::{HeaderMap, header}, response::IntoResponse, }; @@ -24,6 +23,7 @@ use crate::api::error::KeystoneApiError; use crate::keystone::ServiceState; pub mod auth; +pub(crate) mod common; pub mod error; pub mod types; pub mod v3; @@ -50,10 +50,7 @@ pub fn openapi_router() -> OpenApiRouter { ), tag = "version" )] -async fn version( - headers: HeaderMap, - OriginalUri(uri): OriginalUri, -) -> Result { +async fn version(headers: HeaderMap) -> Result { let host = headers .get(header::HOST) .and_then(|header| header.to_str().ok()) diff --git a/src/api/v3/auth/token/common.rs b/src/api/v3/auth/token/common.rs index 976e7e5d..8f0da495 100644 --- a/src/api/v3/auth/token/common.rs +++ b/src/api/v3/auth/token/common.rs @@ -12,20 +12,102 @@ // // SPDX-License-Identifier: Apache-2.0 +use crate::api::common; use crate::api::error::{KeystoneApiError, TokenError}; use crate::api::v3::auth::token::types::{ProjectBuilder, Token, TokenBuilder, UserBuilder}; use crate::api::v3::role::types::Role; use crate::assignment::AssignmentApi; use crate::assignment::types::RoleAssignmentListParametersBuilder; -use crate::identity::IdentityApi; +use crate::identity::{IdentityApi, types::UserResponse}; use crate::keystone::ServiceState; -use crate::resource::ResourceApi; +use crate::resource::{ + ResourceApi, + types::{Domain, Project}, +}; use crate::token::Token as ProviderToken; impl Token { - pub async fn from_provider_token( + // TODO: Join both methods + pub async fn from_user_auth( + state: &ServiceState, token: &ProviderToken, + user: &UserResponse, + project: Option<&Project>, + domain: Option<&Domain>, + ) -> Result { + let mut response = TokenBuilder::default(); + response.audit_ids(token.audit_ids().clone()); + response.methods(token.methods().clone()); + response.expires_at(*token.expires_at()); + + let user_domain = common::get_domain(state, Some(&user.domain_id), None::<&str>).await?; + + let mut user_response: UserBuilder = UserBuilder::default(); + user_response.id(user.id.clone()); + user_response.name(user.name.clone()); + user_response.password_expires_at(user.password_expires_at); + user_response.domain(user_domain.clone()); + response.user(user_response.build().map_err(TokenError::from)?); + + match token { + ProviderToken::Unscoped(_token) => { + // Nothing to do + } + ProviderToken::DomainScope(_token) => { + response.domain(domain.ok_or(KeystoneApiError::InternalError( + "domain scope missing".to_string(), + ))?); + } + ProviderToken::ProjectScope(token) => { + let project = project.ok_or(KeystoneApiError::InternalError( + "domain scope missing".to_string(), + ))?; + + let mut project_response = ProjectBuilder::default(); + project_response.id(project.id.clone()); + project_response.name(project.name.clone()); + if project.domain_id == user.domain_id { + project_response.domain(user_domain.clone().into()); + } else { + let project_domain = + common::get_domain(state, Some(&project.domain_id), None::<&str>).await?; + project_response.domain(project_domain.clone().into()); + } + response.project(project_response.build().map_err(TokenError::from)?); + + let token_roles = state + .provider + .get_assignment_provider() + .list_role_assignments( + &state.db, + &state.provider, + &RoleAssignmentListParametersBuilder::default() + .user_id(user.id.clone()) + .project_id(&token.project_id) + .build()?, + ) + .await?; + response.roles( + token_roles + .into_iter() + .map(|x| Role { + id: x.role_id.clone(), + name: x.role_name.clone().unwrap_or_default(), + ..Default::default() + }) + .collect::>(), + ); + } + ProviderToken::ApplicationCredential(_token) => { + todo!(); + } + } + Ok(response.build().map_err(TokenError::from)?) + } + + pub async fn from_provider_token( state: &ServiceState, + token: &ProviderToken, ) -> Result { let mut response = TokenBuilder::default(); response.audit_ids(token.audit_ids().clone()); @@ -43,16 +125,7 @@ impl Token { identifier: token.user_id().clone(), })?; - let user_domain = state - .provider - .get_resource_provider() - .get_domain(&state.db, &user.domain_id) - .await - .map_err(KeystoneApiError::resource)? - .ok_or_else(|| KeystoneApiError::NotFound { - resource: "domain".into(), - identifier: user.domain_id.clone(), - })?; + let user_domain = common::get_domain(state, Some(&user.domain_id), None::<&str>).await?; let mut user_response: UserBuilder = UserBuilder::default(); user_response.id(user.id.clone()); @@ -69,16 +142,8 @@ impl Token { if token.domain_id == user.domain_id { response.domain(user_domain.clone()); } else { - let domain = state - .provider - .get_resource_provider() - .get_domain(&state.db, &token.domain_id) - .await - .map_err(KeystoneApiError::resource)? - .ok_or_else(|| KeystoneApiError::NotFound { - resource: "domain".into(), - identifier: token.domain_id.clone(), - })?; + let domain = + common::get_domain(state, Some(&token.domain_id), None::<&str>).await?; response.domain(domain.clone()); } } @@ -100,16 +165,8 @@ impl Token { if project.domain_id == user.domain_id { project_response.domain(user_domain.clone().into()); } else { - let project_domain = state - .provider - .get_resource_provider() - .get_domain(&state.db, &project.domain_id) - .await - .map_err(KeystoneApiError::resource)? - .ok_or_else(|| KeystoneApiError::NotFound { - resource: "domain".into(), - identifier: user.domain_id.clone(), - })?; + let project_domain = + common::get_domain(state, Some(&project.domain_id), None::<&str>).await?; project_response.domain(project_domain.clone().into()); } response.project(project_response.build().map_err(TokenError::from)?); @@ -157,7 +214,7 @@ mod tests { types::{Assignment, AssignmentType, RoleAssignmentListParameters}, }; use crate::config::Config; - use crate::identity::{MockIdentityProvider, types::User}; + use crate::identity::{MockIdentityProvider, types::UserResponse}; use crate::keystone::Service; use crate::provider::ProviderBuilder; use crate::resource::{ @@ -178,7 +235,7 @@ mod tests { .expect_get_user() .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| { - Ok(Some(User { + Ok(Some(UserResponse { id: "bar".into(), domain_id: "user_domain_id".into(), ..Default::default() @@ -209,16 +266,16 @@ mod tests { let state = Arc::new(Service::new(config, db, provider).unwrap()); let api_token = Token::from_provider_token( + &state, &ProviderToken::Unscoped(UnscopedToken { user_id: "bar".into(), ..Default::default() }), - &state, ) .await .unwrap(); assert_eq!("bar", api_token.user.id); - assert_eq!("user_domain_id", api_token.user.domain.id); + assert_eq!(Some("user_domain_id"), api_token.user.domain.id.as_deref()); assert!(api_token.project.is_none()); assert!(api_token.domain.is_none()); } @@ -232,7 +289,7 @@ mod tests { .expect_get_user() .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| { - Ok(Some(User { + Ok(Some(UserResponse { id: "bar".into(), domain_id: "user_domain_id".into(), ..Default::default() @@ -262,19 +319,22 @@ mod tests { let state = Arc::new(Service::new(config, db, provider).unwrap()); let api_token = Token::from_provider_token( + &state, &ProviderToken::DomainScope(DomainScopeToken { user_id: "bar".into(), domain_id: "domain_id".into(), ..Default::default() }), - &state, ) .await .unwrap(); assert_eq!("bar", api_token.user.id); - assert_eq!("user_domain_id", api_token.user.domain.id); - assert_eq!("domain_id", api_token.domain.expect("domain scope").id); + assert_eq!(Some("user_domain_id"), api_token.user.domain.id.as_deref()); + assert_eq!( + Some("domain_id"), + api_token.domain.expect("domain scope").id.as_deref() + ); assert!(api_token.project.is_none()); } @@ -287,7 +347,7 @@ mod tests { .expect_get_user() .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| { - Ok(Some(User { + Ok(Some(UserResponse { id: "bar".into(), domain_id: "user_domain_id".into(), ..Default::default() @@ -338,20 +398,20 @@ mod tests { let state = Arc::new(Service::new(config, db, provider).unwrap()); let api_token = Token::from_provider_token( + &state, &ProviderToken::ProjectScope(ProjectScopeToken { user_id: "bar".into(), project_id: "project_id".into(), ..Default::default() }), - &state, ) .await .unwrap(); assert_eq!("bar", api_token.user.id); - assert_eq!("user_domain_id", api_token.user.domain.id); + assert_eq!(Some("user_domain_id"), api_token.user.domain.id.as_deref()); let project = api_token.project.expect("project_scope"); - assert_eq!("project_domain_id", project.domain.id); + assert_eq!(Some("project_domain_id"), project.domain.id.as_deref()); assert_eq!("project_id", project.id); assert!(api_token.domain.is_none()); assert_eq!( diff --git a/src/api/v3/auth/token/mod.rs b/src/api/v3/auth/token/mod.rs index 17f82951..0f1b82ed 100644 --- a/src/api/v3/auth/token/mod.rs +++ b/src/api/v3/auth/token/mod.rs @@ -12,20 +12,158 @@ // // SPDX-License-Identifier: Apache-2.0 -use axum::{extract::State, http::HeaderMap, response::IntoResponse}; +use axum::{Json, extract::State, http::HeaderMap, http::StatusCode, response::IntoResponse}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE}; use utoipa_axum::{router::OpenApiRouter, routes}; +use uuid::Uuid; -use crate::api::auth::Auth; -use crate::api::error::KeystoneApiError; +use crate::api::{auth::Auth, common::get_domain, error::KeystoneApiError}; +use crate::identity::IdentityApi; +use crate::identity::types::UserResponse; use crate::keystone::ServiceState; +use crate::resource::{ + ResourceApi, + types::{Domain, Project}, +}; use crate::token::TokenApi; -use types::{Token as ApiResponseToken, TokenResponse}; +use types::{AuthRequest, Scope, Token as ApiResponseToken, TokenResponse}; mod common; pub mod types; pub(super) fn openapi_router() -> OpenApiRouter { - OpenApiRouter::new().routes(routes!(show)) + OpenApiRouter::new().routes(routes!(show, post)) +} + +/// Authenticate user issuing a new token +#[utoipa::path( + post, + path = "/", + description = "Issue token", + params(), + responses( + (status = OK, description = "Token object", body = TokenResponse), + ), + tag="auth" +)] +#[tracing::instrument(name = "api::token_post", level = "debug", skip(state, req))] +async fn post( + State(state): State, + Json(req): Json, +) -> Result { + let mut methods: Vec = Vec::new(); + let mut user: Option = None; + let mut project: Option = None; + let mut domain: Option = None; + + match req.auth.scope { + Some(Scope::Project(scope)) => { + project = if let Some(pid) = &scope.id { + state + .provider + .get_resource_provider() + .get_project(&state.db, pid) + .await? + } else if let Some(name) = &scope.name { + if let Some(domain) = scope.domain { + let domain_id = match domain.id { + Some(id) => id.clone(), + None => { + state + .provider + .get_resource_provider() + .find_domain_by_name( + &state.db, + &domain + .name + .clone() + .ok_or(KeystoneApiError::DomainIdOrName)?, + ) + .await? + .ok_or(KeystoneApiError::NotFound { + resource: "domain".to_string(), + identifier: domain + .name + .clone() + .ok_or(KeystoneApiError::DomainIdOrName)?, + })? + .id + } + }; + state + .provider + .get_resource_provider() + .get_project_by_name(&state.db, name, &domain_id) + .await? + } else { + return Err(KeystoneApiError::ProjectDomain); + } + } else { + return Err(KeystoneApiError::ProjectIdOrName); + }; + if !project.as_ref().is_some_and(|target| target.enabled) { + return Err(KeystoneApiError::Unauthorized); + } + } + Some(Scope::Domain(scope)) => { + domain = Some(get_domain(&state, scope.id.as_ref(), scope.name.as_ref()).await?); + if !domain.as_ref().is_some_and(|target| target.enabled) { + return Err(KeystoneApiError::Unauthorized); + } + } + None => {} + } + + for method in req.auth.identity.methods.iter() { + if method == "password" { + if let Some(password_auth) = &req.auth.identity.password { + let req = password_auth.user.clone().try_into()?; + user = Some( + state + .provider + .get_identity_provider() + .authenticate_by_password(&state.db, &state.provider, req) + .await?, + ); + methods.push(method.clone()); + } + } + } + + if let Some(authed_user) = &user { + let token = state.provider.get_token_provider().issue_token( + authed_user.id.clone(), + methods, + Vec::::from([URL_SAFE + .encode(Uuid::new_v4().as_bytes()) + .trim_end_matches('=') + .to_string()]), + project.as_ref(), + domain.as_ref(), + )?; + + let api_token = TokenResponse { + token: ApiResponseToken::from_user_auth( + &state, + &token, + authed_user, + project.as_ref(), + domain.as_ref(), + ) + .await?, + }; + return Ok(( + StatusCode::OK, + [( + "X-Subject-Token", + state.provider.get_token_provider().encode_token(&token)?, + )], + Json(api_token), + ) + .into_response()); + } + + return Err(KeystoneApiError::Unauthorized); } /// Validate token @@ -39,9 +177,13 @@ pub(super) fn openapi_router() -> OpenApiRouter { ), tag="auth" )] -#[tracing::instrument(name = "api::token_get", level = "debug", skip(state))] +#[tracing::instrument( + name = "api::token_get", + level = "debug", + skip(state, headers, _user_auth) +)] async fn show( - Auth(user_auth): Auth, + Auth(_user_auth): Auth, headers: HeaderMap, State(state): State, ) -> Result { @@ -59,7 +201,7 @@ async fn show( .await .map_err(|_| KeystoneApiError::InvalidToken)?; - let response_token = ApiResponseToken::from_provider_token(&token, &state).await?; + let response_token = ApiResponseToken::from_provider_token(&state, &token).await?; Ok(TokenResponse { token: response_token, @@ -70,10 +212,11 @@ async fn show( mod tests { use axum::{ body::Body, - http::{Request, StatusCode}, + http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` use sea_orm::DatabaseConnection; + use serde_json::json; use std::sync::Arc; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -82,10 +225,13 @@ mod tests { use crate::api::v3::auth::token::types::TokenResponse; use crate::assignment::MockAssignmentProvider; use crate::config::Config; - use crate::identity::{MockIdentityProvider, types::User}; + use crate::identity::{MockIdentityProvider, types::UserResponse}; use crate::keystone::Service; use crate::provider::ProviderBuilder; - use crate::resource::{MockResourceProvider, types::Domain}; + use crate::resource::{ + MockResourceProvider, + types::{Domain, Project}, + }; use crate::tests::api::get_mocked_state_unauthed; use crate::token::*; @@ -96,7 +242,7 @@ mod tests { let assignment_mock = MockAssignmentProvider::default(); let mut identity_mock = MockIdentityProvider::default(); identity_mock.expect_get_user().returning(|_, id: &'_ str| { - Ok(Some(User { + Ok(Some(UserResponse { id: id.to_string(), domain_id: "user_domain_id".into(), ..Default::default() @@ -185,4 +331,133 @@ mod tests { assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } + + #[tokio::test] + async fn test_post() { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let mut assignment_mock = MockAssignmentProvider::default(); + assignment_mock + .expect_list_role_assignments() + .returning(|_, _, _| Ok(Vec::new())); + + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_authenticate_by_password() + .returning(|_, _, _| { + Ok(UserResponse { + id: "uid".to_string(), + domain_id: "user_domain_id".into(), + ..Default::default() + }) + }); + + let mut resource_mock = MockResourceProvider::default(); + resource_mock + .expect_get_project() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "pid") + .returning(|_, _| { + Ok(Some(Project { + id: "pid".into(), + domain_id: "pdid".into(), + enabled: true, + ..Default::default() + })) + }); + resource_mock + .expect_get_domain() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .returning(|_, _| { + Ok(Some(Domain { + id: "user_domain_id".into(), + enabled: true, + ..Default::default() + })) + }); + resource_mock + .expect_get_domain() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "pdid") + .returning(|_, _| { + Ok(Some(Domain { + id: "pdid".into(), + enabled: true, + ..Default::default() + })) + }); + let mut token_mock = MockTokenProvider::default(); + token_mock.expect_issue_token().returning(|_, _, _, _, _| { + Ok(Token::ProjectScope(ProjectScopeToken { + user_id: "bar".into(), + methods: Vec::from(["password".to_string()]), + ..Default::default() + })) + }); + + token_mock + .expect_encode_token() + .returning(|_| Ok("token".to_string())); + + let provider = ProviderBuilder::default() + .config(config.clone()) + .assignment(assignment_mock) + .identity(identity_mock) + .resource(resource_mock) + .token(token_mock) + .build() + .unwrap(); + + let state = Arc::new(Service::new(config, db, provider).unwrap()); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state.clone()); + + let response = api + .as_service() + .oneshot( + Request::builder() + .uri("/") + .method("POST") + .header(header::CONTENT_TYPE, "application/json") + .body(Body::from( + serde_json::to_vec(&json!({ + "auth": { + "identity": { + "methods": ["password"], + "password": { + "user": { + "id": "uid", + "name": "uname", + "domain": { + "id": "udid", + "name": "udname" + }, + "password": "pass", + }, + }, + }, + "scope": { + "project": { + "id": "pid", + "name": "pname", + "domain": { + "id": "pdid", + "name": "pdname" + } + } + } + } + })) + .unwrap(), + )) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = response.into_body().collect().await.unwrap().to_bytes(); + let res: TokenResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!(vec!["password"], res.token.methods); + } } diff --git a/src/api/v3/auth/token/types.rs b/src/api/v3/auth/token/types.rs index 0f542223..41ef2be2 100644 --- a/src/api/v3/auth/token/types.rs +++ b/src/api/v3/auth/token/types.rs @@ -24,6 +24,7 @@ use utoipa::ToSchema; use crate::api::error::TokenError; use crate::api::v3::role::types::Role; +use crate::identity::types as identity_types; use crate::resource::types as resource_provider_types; use crate::token::Token as BackendToken; @@ -54,7 +55,7 @@ pub struct Token { pub expires_at: DateTime, /// A user object. - #[builder(default)] + //#[builder(default)] pub user: User, /// A project object including the id, name and domain object representing the project the @@ -88,6 +89,115 @@ impl IntoResponse for TokenResponse { } } +/// An authentication request. +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +pub struct AuthRequest { + /// An identity object. + pub auth: AuthRequestInner, +} + +/// An authentication request. +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +pub struct AuthRequestInner { + /// An identity object. + pub identity: Identity, + + /// The authorization scope, including the system (Since v3.10), a project, or a domain (Since + /// v3.4). If multiple scopes are specified in the same request (e.g. project and domain or + /// domain and system) an HTTP 400 Bad Request will be returned, as a token cannot be + /// simultaneously scoped to multiple authorization targets. An ID is sufficient to uniquely + /// identify a project but if a project is specified by name, then the domain of the project + /// must also be specified in order to uniquely identify the project by name. A domain scope + /// may be specified by either the domain’s ID or name with equivalent results. + pub scope: Option, +} + +/// An identity object. +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +pub struct Identity { + /// The authentication method. For password authentication, specify password. + pub methods: Vec, + + /// The password object, contains the authentication information. + pub password: Option, +} + +/// The password object, contains the authentication information. +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +#[builder(setter(strip_option, into))] +pub struct PasswordAuth { + /// A user object. + #[builder(default)] + pub user: UserPassword, +} + +/// User password information +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +pub struct UserPassword { + /// User ID + pub id: Option, + /// User Name + pub name: Option, + /// User domain + pub domain: Option, + /// User password expiry date + pub password: String, +} + +impl TryFrom for identity_types::UserPasswordAuthRequest { + type Error = TokenError; + + fn try_from(value: UserPassword) -> Result { + let mut upa = identity_types::UserPasswordAuthRequestBuilder::default(); + if let Some(id) = &value.id { + upa.id(id); + } + if let Some(name) = &value.name { + upa.name(name); + } + if let Some(domain) = &value.domain { + let mut domain_builder = identity_types::DomainBuilder::default(); + if let Some(id) = &domain.id { + domain_builder.id(id); + } + if let Some(name) = &domain.name { + domain_builder.name(name); + } + upa.domain(domain_builder.build()?); + } + upa.password(value.password.clone()); + Ok(upa.build()?) + } +} + +/// The authorization scope, including the system (Since v3.10), a project, or a domain (Since +/// v3.4). If multiple scopes are specified in the same request (e.g. project and domain or domain +/// and system) an HTTP 400 Bad Request will be returned, as a token cannot be simultaneously +/// scoped to multiple authorization targets. An ID is sufficient to uniquely identify a project +/// but if a project is specified by name, then the domain of the project must also be specified in +/// order to uniquely identify the project by name. A domain scope may be specified by either the +/// domain’s ID or name with equivalent results. +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] +pub enum Scope { + /// Project scope + #[serde(rename = "project")] + Project(ProjectScope), + /// Domain scope + #[serde(rename = "domain")] + Domain(Domain), +} + +/// Project scope information +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +pub struct ProjectScope { + /// Project ID + pub id: Option, + /// Project Name + pub name: Option, + /// project domain + pub domain: Option, +} + /// Project information #[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct Project { @@ -95,7 +205,6 @@ pub struct Project { pub id: String, /// Project Name pub name: String, - /// project domain pub domain: Domain, } @@ -107,10 +216,12 @@ pub struct User { /// User ID pub id: String, /// User Name - pub name: String, + #[builder(default)] + pub name: Option, /// User domain pub domain: Domain, /// User password expiry date + #[builder(default)] pub password_expires_at: Option>, } @@ -119,16 +230,27 @@ pub struct User { #[builder(setter(into))] pub struct Domain { /// Domain ID - pub id: String, + #[builder(default)] + pub id: Option, /// Domain Name - pub name: String, + #[builder(default)] + pub name: Option, } impl From for Domain { fn from(value: resource_provider_types::Domain) -> Self { Self { - id: value.id.clone(), - name: value.name.clone(), + id: Some(value.id.clone()), + name: Some(value.name.clone()), + } + } +} + +impl From<&resource_provider_types::Domain> for Domain { + fn from(value: &resource_provider_types::Domain) -> Self { + Self { + id: Some(value.id.clone()), + name: Some(value.name.clone()), } } } diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index 13326cb0..bc5540ab 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -196,12 +196,13 @@ mod tests { use super::openapi_router; use crate::api::v3::group::types::{Group as ApiGroup, GroupList}; use crate::api::v3::user::types::{ - User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, UserResponse, + User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, + UserResponse as ApiUserResponse, }; use crate::identity::{ MockIdentityProvider, error::IdentityProviderError, - types::{Group, User, UserCreate, UserListParameters}, + types::{Group, UserCreate, UserListParameters, UserResponse}, }; use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed}; @@ -213,7 +214,7 @@ mod tests { .expect_list_users() .withf(|_: &DatabaseConnection, _: &UserListParameters| true) .returning(|_, _| { - Ok(vec![User { + Ok(vec![UserResponse { id: "1".into(), name: "2".into(), ..Default::default() @@ -317,7 +318,7 @@ mod tests { req.domain_id == "domain" && req.name == "name" }) .returning(|_, req| { - Ok(User { + Ok(UserResponse { id: "bar".into(), domain_id: req.domain_id, name: req.name, @@ -356,7 +357,7 @@ mod tests { assert_eq!(response.status(), StatusCode::CREATED); let body = response.into_body().collect().await.unwrap().to_bytes(); - let created_user: UserResponse = serde_json::from_slice(&body).unwrap(); + let created_user: ApiUserResponse = serde_json::from_slice(&body).unwrap(); assert_eq!(created_user.user.name, user.user.name); } @@ -372,7 +373,7 @@ mod tests { .expect_get_user() .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| { - Ok(Some(User { + Ok(Some(UserResponse { id: "bar".into(), ..Default::default() })) @@ -413,7 +414,7 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let res: UserResponse = serde_json::from_slice(&body).unwrap(); + let res: ApiUserResponse = serde_json::from_slice(&body).unwrap(); assert_eq!( ApiUser { id: "bar".into(), diff --git a/src/api/v3/user/passkey/mod.rs b/src/api/v3/user/passkey/mod.rs index 71a7e3ef..fd055868 100644 --- a/src/api/v3/user/passkey/mod.rs +++ b/src/api/v3/user/passkey/mod.rs @@ -18,6 +18,7 @@ use axum::{ http::StatusCode, response::IntoResponse, }; +use base64::{Engine as _, engine::general_purpose::URL_SAFE}; use serde_json::Value; use tracing::debug; use utoipa_axum::{router::OpenApiRouter, routes}; @@ -61,6 +62,7 @@ async fn register_start( .get_identity_provider() .delete_user_passkey_registration_state(&state.db, &user_id) .await?; + // TODO: user names let res = match state.webauthn.start_passkey_registration( Uuid::parse_str(&user_id)?, "foo", @@ -95,7 +97,7 @@ async fn register_start( #[tracing::instrument( name = "api::user_passkey_register_finish", level = "debug", - skip(state) + skip(state, reg) )] async fn register_finish( Path(user_id): Path, @@ -189,7 +191,11 @@ async fn login_start( responses(), tag = "passkey" )] -#[tracing::instrument(name = "api::user_passkey_login_finish", level = "debug", skip(state))] +#[tracing::instrument( + name = "api::user_passkey_login_finish", + level = "debug", + skip(state, reg) +)] async fn login_finish( Path(user_id): Path, State(state): State, @@ -221,14 +227,19 @@ async fn login_finish( let token = state.provider.get_token_provider().issue_token( user_id, vec!["passkey".into()], - Vec::::new(), + Vec::::from([URL_SAFE + .encode(Uuid::new_v4().as_bytes()) + .trim_end_matches('=') + .to_string()]), + None, + None, )?; - let api_token = ApiToken::try_from(&token)?; + let api_token = ApiToken::from_provider_token(&state, &token).await?; Ok(( StatusCode::OK, [( - "X-Auth-Token", + "X-Subject-Token", state.provider.get_token_provider().encode_token(&token)?, )], Json(api_token), diff --git a/src/api/v3/user/types.rs b/src/api/v3/user/types.rs index ca153876..d9fb1867 100644 --- a/src/api/v3/user/types.rs +++ b/src/api/v3/user/types.rs @@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use utoipa::{IntoParams, ToSchema}; -use crate::identity::types; +use crate::identity::types as identity_types; #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct User { @@ -134,8 +134,8 @@ pub struct UserOptions { pub multi_factor_auth_enabled: Option, } -impl From for UserOptions { - fn from(value: types::UserOptions) -> Self { +impl From for UserOptions { + fn from(value: identity_types::UserOptions) -> Self { Self { ignore_change_password_upon_first_use: value.ignore_change_password_upon_first_use, ignore_password_expiry: value.ignore_password_expiry, @@ -148,7 +148,7 @@ impl From for UserOptions { } } -impl From for types::UserOptions { +impl From for identity_types::UserOptions { fn from(value: UserOptions) -> Self { Self { ignore_change_password_upon_first_use: value.ignore_change_password_upon_first_use, @@ -168,8 +168,8 @@ pub struct UserCreateRequest { pub user: UserCreate, } -impl From for User { - fn from(value: types::User) -> Self { +impl From for User { + fn from(value: identity_types::UserResponse) -> Self { let opts: UserOptions = value.options.clone().into(); // We only want to see user options if there is at least 1 option set let opts = if opts.ignore_change_password_upon_first_use.is_some() @@ -197,7 +197,7 @@ impl From for User { } } -impl From for types::UserCreate { +impl From for identity_types::UserCreate { fn from(value: UserCreateRequest) -> Self { let user = value.user; Self { @@ -220,7 +220,7 @@ impl IntoResponse for UserResponse { } } -impl IntoResponse for types::User { +impl IntoResponse for identity_types::UserResponse { fn into_response(self) -> Response { ( StatusCode::OK, @@ -239,8 +239,8 @@ pub struct UserList { pub users: Vec, } -impl From> for UserList { - fn from(value: Vec) -> Self { +impl From> for UserList { + fn from(value: Vec) -> Self { let objects: Vec = value.into_iter().map(User::from).collect(); Self { users: objects } } @@ -260,7 +260,7 @@ pub struct UserListParameters { pub name: Option, } -impl From for types::UserListParameters { +impl From for identity_types::UserListParameters { fn from(value: UserListParameters) -> Self { Self { domain_id: value.domain_id, diff --git a/src/assignment/backends/sql/role.rs b/src/assignment/backends/sql/role.rs index 3fa09655..4a50f441 100644 --- a/src/assignment/backends/sql/role.rs +++ b/src/assignment/backends/sql/role.rs @@ -24,12 +24,12 @@ use crate::db::entity::{prelude::Role as DbRole, role as db_role}; static NULL_DOMAIN_ID: &str = "<>"; -pub async fn get( +pub async fn get>( _conf: &Config, db: &DatabaseConnection, - id: &str, + id: I, ) -> Result, AssignmentDatabaseError> { - let role_select = DbRole::find_by_id(id); + let role_select = DbRole::find_by_id(id.as_ref()); let entry: Option = role_select.one(db).await?; entry.map(TryInto::try_into).transpose() diff --git a/src/bin/keystone.rs b/src/bin/keystone.rs index 889c6b85..26a7926a 100644 --- a/src/bin/keystone.rs +++ b/src/bin/keystone.rs @@ -105,7 +105,13 @@ async fn main() -> Result<(), Report> { .split_for_parts(); let x_request_id = HeaderName::from_static("x-openstack-request-id"); - let sensitive_headers: Arc<[_]> = vec![header::AUTHORIZATION, header::COOKIE].into(); + let sensitive_headers: Arc<[_]> = vec![ + header::AUTHORIZATION, + header::COOKIE, + header::HeaderName::from_static("x-auth-token"), + header::HeaderName::from_static("x-subject-token"), + ] + .into(); let middleware = ServiceBuilder::new() // Inject x-request-id header into processing diff --git a/src/identity/backends/error.rs b/src/identity/backends/error.rs index e88e7705..b8b95ede 100644 --- a/src/identity/backends/error.rs +++ b/src/identity/backends/error.rs @@ -34,10 +34,10 @@ pub enum IdentityDatabaseError { source: serde_json::Error, }, - #[error("building user data")] + #[error("building user response data")] UserBuilderError { #[from] - source: UserBuilderError, + source: UserResponseBuilderError, }, #[error("database data")] @@ -51,4 +51,7 @@ pub enum IdentityDatabaseError { #[from] source: IdentityProviderPasswordHashError, }, + + #[error("either user id or user name with user domain id or name must be given")] + UserIdOrNameWithDomain, } diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index 3cdc322b..2d1f3074 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -54,13 +54,56 @@ impl IdentityBackend for SqlBackend { self.config = config; } + /// Authenticate a user by a password + async fn authenticate_by_password( + &self, + db: &DatabaseConnection, + auth: UserPasswordAuthRequest, + ) -> Result { + let user_with_passwords = local_user::load_local_user_with_passwords( + db, + auth.id, + auth.name, + auth.domain.and_then(|x| x.id), + ) + .await?; + if let Some((local_user, password)) = user_with_passwords { + let passwords: Vec = password.into_iter().collect(); + if let Some(latest_password) = passwords.first() { + if let Some(expected_hash) = &latest_password.password_hash { + let user_opts = user_option::get(db, local_user.user_id.clone()).await?; + + if password_hashing::verify_password( + &self.config, + auth.password, + expected_hash, + )? { + if let Some(user) = user::get(db, &local_user.user_id).await? { + let user_builder = common::get_local_user_builder( + &self.config, + &user, + local_user, + Some(passwords), + user_opts, + ); + return Ok(user_builder.build()?); + } + } else { + return Err(IdentityProviderError::WrongUsernamePassword); + } + } + } + } + return Err(IdentityProviderError::WrongUsernamePassword); + } + /// Fetch users from the database #[tracing::instrument(level = "debug", skip(self, db))] async fn list_users( &self, db: &DatabaseConnection, params: &UserListParameters, - ) -> Result, IdentityProviderError> { + ) -> Result, IdentityProviderError> { Ok(list_users(&self.config, db, params).await?) } @@ -70,7 +113,7 @@ impl IdentityBackend for SqlBackend { &self, db: &DatabaseConnection, user_id: &'a str, - ) -> Result, IdentityProviderError> { + ) -> Result, IdentityProviderError> { Ok(get_user(&self.config, db, user_id).await?) } @@ -80,7 +123,7 @@ impl IdentityBackend for SqlBackend { &self, db: &DatabaseConnection, user: UserCreate, - ) -> Result { + ) -> Result { Ok(create_user(&self.config, db, user).await?) } @@ -240,7 +283,7 @@ async fn list_users( conf: &Config, db: &DatabaseConnection, params: &UserListParameters, -) -> Result, IdentityDatabaseError> { +) -> Result, IdentityDatabaseError> { // Prepare basic selects let mut user_select = DbUser::find(); let mut local_user_select = LocalUser::find(); @@ -277,7 +320,7 @@ async fn list_users( ) .await?; - let mut results: Vec = Vec::new(); + let mut results: Vec = Vec::new(); for (u, (o, (l, (p, (n, f))))) in db_users.into_iter().zip( user_opts.into_iter().zip( local_users.into_iter().zip( @@ -290,7 +333,7 @@ async fn list_users( if l.is_none() && n.is_none() && f.is_empty() { continue; } - let user_builder: UserBuilder = if let Some(local) = l { + let user_builder: UserResponseBuilder = if let Some(local) = l { common::get_local_user_builder(conf, &u, local, p.map(|x| x.into_iter()), o) } else if let Some(nonlocal) = n { common::get_nonlocal_user_builder(&u, nonlocal, o) @@ -323,7 +366,7 @@ pub async fn get_user( conf: &Config, db: &DatabaseConnection, user_id: &str, -) -> Result, IdentityDatabaseError> { +) -> Result, IdentityDatabaseError> { let user_select = DbUser::find_by_id(user_id); let user_entry: Option = user_select.one(db).await?; @@ -331,29 +374,35 @@ pub async fn get_user( if let Some(user) = &user_entry { let user_opts: Vec = user.find_related(UserOption).all(db).await?; - let user_builder: UserBuilder = - match local_user::load_local_user_with_passwords(db, &user_id).await? { - Some(local_user_with_passwords) => common::get_local_user_builder( - conf, - user, - local_user_with_passwords.0, - Some(local_user_with_passwords.1), - user_opts, - ), - _ => match user.find_related(NonlocalUser).one(db).await? { - Some(nonlocal_user) => { - common::get_nonlocal_user_builder(user, nonlocal_user, user_opts) - } - _ => { - let federated_user = user.find_related(FederatedUser).all(db).await?; - if !federated_user.is_empty() { - common::get_federated_user_builder(user, federated_user, user_opts) - } else { - return Err(IdentityDatabaseError::MalformedUser(user_id.to_string()))?; - } + let user_builder: UserResponseBuilder = match local_user::load_local_user_with_passwords( + db, + Some(&user_id), + None::<&str>, + None::<&str>, + ) + .await? + { + Some(local_user_with_passwords) => common::get_local_user_builder( + conf, + user, + local_user_with_passwords.0, + Some(local_user_with_passwords.1), + user_opts, + ), + _ => match user.find_related(NonlocalUser).one(db).await? { + Some(nonlocal_user) => { + common::get_nonlocal_user_builder(user, nonlocal_user, user_opts) + } + _ => { + let federated_user = user.find_related(FederatedUser).all(db).await?; + if !federated_user.is_empty() { + common::get_federated_user_builder(user, federated_user, user_opts) + } else { + return Err(IdentityDatabaseError::MalformedUser(user_id.to_string()))?; } - }, - }; + } + }, + }; return Ok(Some(user_builder.build()?)); } @@ -365,7 +414,7 @@ async fn create_user( conf: &Config, db: &DatabaseConnection, user: UserCreate, -) -> Result { +) -> Result { let main_user = user::create(conf, db, &user).await?; if let Some(_federated) = &user.federated { } else { @@ -470,7 +519,7 @@ mod tests { let config = Config::default(); assert_eq!( get_user(&config, &db, "1").await.unwrap().unwrap(), - User { + UserResponse { id: "1".into(), domain_id: "foo_domain".into(), name: "Apple Cake".to_owned(), @@ -499,7 +548,7 @@ mod tests { ), Transaction::from_sql_and_values( DatabaseBackend::Postgres, - r#"SELECT "local_user"."id" AS "A_id", "local_user"."user_id" AS "A_user_id", "local_user"."domain_id" AS "A_domain_id", "local_user"."name" AS "A_name", "local_user"."failed_auth_count" AS "A_failed_auth_count", "local_user"."failed_auth_at" AS "A_failed_auth_at", "password"."id" AS "B_id", "password"."local_user_id" AS "B_local_user_id", "password"."self_service" AS "B_self_service", "password"."created_at" AS "B_created_at", "password"."expires_at" AS "B_expires_at", "password"."password_hash" AS "B_password_hash", "password"."created_at_int" AS "B_created_at_int", "password"."expires_at_int" AS "B_expires_at_int" FROM "local_user" LEFT JOIN "password" ON "local_user"."id" = "password"."local_user_id" WHERE "local_user"."user_id" = $1 ORDER BY "local_user"."id" ASC"#, + r#"SELECT "local_user"."id" AS "A_id", "local_user"."user_id" AS "A_user_id", "local_user"."domain_id" AS "A_domain_id", "local_user"."name" AS "A_name", "local_user"."failed_auth_count" AS "A_failed_auth_count", "local_user"."failed_auth_at" AS "A_failed_auth_at", "password"."id" AS "B_id", "password"."local_user_id" AS "B_local_user_id", "password"."self_service" AS "B_self_service", "password"."created_at" AS "B_created_at", "password"."expires_at" AS "B_expires_at", "password"."password_hash" AS "B_password_hash", "password"."created_at_int" AS "B_created_at_int", "password"."expires_at_int" AS "B_expires_at_int" FROM "local_user" LEFT JOIN "password" ON "local_user"."id" = "password"."local_user_id" WHERE "local_user"."user_id" = $1 ORDER BY "local_user"."id" ASC, "password"."created_at_int" DESC"#, ["1".into()] ), ] diff --git a/src/identity/backends/sql/common.rs b/src/identity/backends/sql/common.rs index 6bbadd9d..962e36e4 100644 --- a/src/identity/backends/sql/common.rs +++ b/src/identity/backends/sql/common.rs @@ -28,8 +28,8 @@ use crate::identity::types::*; pub fn get_user_builder>( user: &user::Model, opts: O, -) -> UserBuilder { - let mut user_builder: UserBuilder = UserBuilder::default(); +) -> UserResponseBuilder { + let mut user_builder: UserResponseBuilder = UserResponseBuilder::default(); user_builder.id(user.id.clone()); user_builder.domain_id(user.domain_id.clone()); // TODO: default enabled logic @@ -52,8 +52,8 @@ pub fn get_local_user_builder< data: local_user::Model, passwords: Option

, opts: O, -) -> UserBuilder { - let mut user_builder: UserBuilder = get_user_builder(user, opts); +) -> UserResponseBuilder { + let mut user_builder: UserResponseBuilder = get_user_builder(user, opts); user_builder.name(data.name.clone()); if let Some(password_expires_days) = conf.security_compliance.password_expires_days { if let Some(pass) = passwords { @@ -79,8 +79,8 @@ pub fn get_nonlocal_user_builder>( user: &user::Model, data: nonlocal_user::Model, opts: O, -) -> UserBuilder { - let mut user_builder: UserBuilder = get_user_builder(user, opts); +) -> UserResponseBuilder { + let mut user_builder: UserResponseBuilder = get_user_builder(user, opts); user_builder.name(data.name.clone()); user_builder } @@ -92,8 +92,8 @@ pub fn get_federated_user_builder< user: &user::Model, data: F, opts: O, -) -> UserBuilder { - let mut user_builder: UserBuilder = get_user_builder(user, opts); +) -> UserResponseBuilder { + let mut user_builder: UserResponseBuilder = get_user_builder(user, opts); if let Some(first) = data.into_iter().next() { if let Some(name) = first.display_name { user_builder.name(name.clone()); diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs index 09ff4b82..b5c2eb96 100644 --- a/src/identity/backends/sql/group.rs +++ b/src/identity/backends/sql/group.rs @@ -48,12 +48,15 @@ pub async fn list( Ok(results) } -pub async fn get( +pub async fn get>( _conf: &Config, db: &DatabaseConnection, - group_id: &str, + group_id: S, ) -> Result, IdentityDatabaseError> { - Ok(DbGroup::find_by_id(group_id).one(db).await?.map(Into::into)) + Ok(DbGroup::find_by_id(group_id.as_ref()) + .one(db) + .await? + .map(Into::into)) } pub async fn create( @@ -74,16 +77,18 @@ pub async fn create( Ok(db_entry.into()) } -pub async fn delete( +pub async fn delete>( _conf: &Config, db: &DatabaseConnection, - group_id: &str, + group_id: S, ) -> Result<(), IdentityDatabaseError> { - let res = DbGroup::delete_by_id(group_id).exec(db).await?; + let res = DbGroup::delete_by_id(group_id.as_ref()).exec(db).await?; if res.rows_affected == 1 { Ok(()) } else { - Err(IdentityDatabaseError::GroupNotFound(group_id.to_string())) + Err(IdentityDatabaseError::GroupNotFound( + group_id.as_ref().to_string(), + )) } } @@ -101,14 +106,14 @@ impl From for Group { } } -pub async fn list_for_user( +pub async fn list_for_user>( _conf: &Config, db: &DatabaseConnection, - user_id: &str, + user_id: S, ) -> Result, IdentityDatabaseError> { let groups: Vec<(user_group_membership::Model, Vec)> = DbUserGroupMembership::find() - .filter(user_group_membership::Column::UserId.eq(user_id)) + .filter(user_group_membership::Column::UserId.eq(user_id.as_ref())) .find_with_related(DbGroup) .all(db) .await?; diff --git a/src/identity/backends/sql/local_user.rs b/src/identity/backends/sql/local_user.rs index b1f3cbc6..7be0e37b 100644 --- a/src/identity/backends/sql/local_user.rs +++ b/src/identity/backends/sql/local_user.rs @@ -26,19 +26,34 @@ use crate::identity::backends::error::IdentityDatabaseError; use crate::identity::types::UserCreate; /// Load local user record with passwords from database -pub async fn load_local_user_with_passwords>( +pub async fn load_local_user_with_passwords, S2: AsRef, S3: AsRef>( db: &DatabaseConnection, - user_id: S, + user_id: Option, + name: Option, + domain_id: Option, ) -> Result< - Option<( - local_user::Model, - impl IntoIterator + use, - )>, + Option<(local_user::Model, impl IntoIterator)>, IdentityDatabaseError, > { - let results: Vec<(local_user::Model, Vec)> = LocalUser::find() - .filter(local_user::Column::UserId.eq(user_id.as_ref())) + let mut select = LocalUser::find(); + if let Some(user_id) = user_id { + select = select.filter(local_user::Column::UserId.eq(user_id.as_ref())) + } else { + select = select + .filter( + local_user::Column::Name.eq(name + .ok_or(IdentityDatabaseError::UserIdOrNameWithDomain)? + .as_ref()), + ) + .filter( + local_user::Column::DomainId.eq(domain_id + .ok_or(IdentityDatabaseError::UserIdOrNameWithDomain)? + .as_ref()), + ); + } + let results: Vec<(local_user::Model, Vec)> = select .find_with_related(Password) + .order_by(password::Column::CreatedAtInt, Order::Desc) .all(db) .await?; Ok(results.first().cloned()) @@ -74,7 +89,7 @@ pub async fn load_local_users_passwords>>( passwords.into_iter().for_each(|item| { let vec = hashmap .get_mut(&item.local_user_id) - .expect("Failed finding key on passwords hashmap"); + .expect("failed to find key on passwords hashmap"); vec.push(item); }); @@ -116,3 +131,27 @@ pub async fn create( Ok(db_user) } + +pub async fn get_by_name_and_domain, D: AsRef>( + _conf: &Config, + db: &DatabaseConnection, + name: N, + domain_id: D, +) -> Result, IdentityDatabaseError> { + Ok(LocalUser::find() + .filter(local_user::Column::Name.eq(name.as_ref())) + .filter(local_user::Column::DomainId.eq(domain_id.as_ref())) + .one(db) + .await?) +} + +pub async fn get_by_user_id>( + _conf: &Config, + db: &DatabaseConnection, + user_id: U, +) -> Result, IdentityDatabaseError> { + Ok(LocalUser::find() + .filter(local_user::Column::UserId.eq(user_id.as_ref())) + .one(db) + .await?) +} diff --git a/src/identity/backends/sql/passkey.rs b/src/identity/backends/sql/passkey.rs index 080fbe5d..e5cfd0fe 100644 --- a/src/identity/backends/sql/passkey.rs +++ b/src/identity/backends/sql/passkey.rs @@ -21,15 +21,15 @@ use webauthn_rs::prelude::Passkey; use crate::db::entity::{prelude::WebauthnCredential as DbPasskey, webauthn_credential}; use crate::identity::backends::error::IdentityDatabaseError; -pub(super) async fn create( +pub(super) async fn create>( db: &DatabaseConnection, - user_id: &str, + user_id: U, passkey: Passkey, ) -> Result<(), IdentityDatabaseError> { let now = Local::now().naive_utc(); let entry = webauthn_credential::ActiveModel { id: NotSet, - user_id: Set(user_id.to_string()), + user_id: Set(user_id.as_ref().to_string()), credential_id: Set(passkey.cred_id().escape_ascii().to_string()), passkey: Set(serde_json::to_string(&passkey)?), r#type: Set("cross-platform".to_string()), @@ -42,12 +42,12 @@ pub(super) async fn create( Ok(()) } -pub async fn list( +pub async fn list>( db: &DatabaseConnection, - user_id: &str, + user_id: U, ) -> Result, IdentityDatabaseError> { let res: Result, _> = DbPasskey::find() - .filter(webauthn_credential::Column::UserId.eq(user_id)) + .filter(webauthn_credential::Column::UserId.eq(user_id.as_ref())) .all(db) .await? .into_iter() diff --git a/src/identity/backends/sql/passkey_state.rs b/src/identity/backends/sql/passkey_state.rs index 7eb6a568..bd06ed86 100644 --- a/src/identity/backends/sql/passkey_state.rs +++ b/src/identity/backends/sql/passkey_state.rs @@ -21,14 +21,14 @@ use webauthn_rs::prelude::{PasskeyAuthentication, PasskeyRegistration}; use crate::db::entity::{prelude::WebauthnState as DbPasskeyState, webauthn_state}; use crate::identity::backends::error::IdentityDatabaseError; -pub(super) async fn create_register( +pub(super) async fn create_register>( db: &DatabaseConnection, - user_id: &str, + user_id: U, state: PasskeyRegistration, ) -> Result<(), IdentityDatabaseError> { let now = Local::now().naive_utc(); let entry = webauthn_state::ActiveModel { - user_id: Set(user_id.to_string()), + user_id: Set(user_id.as_ref().to_string()), state: Set(serde_json::to_string(&state)?), r#type: Set("register".into()), created_at: Set(now), @@ -37,14 +37,14 @@ pub(super) async fn create_register( Ok(()) } -pub(super) async fn create_auth( +pub(super) async fn create_auth>( db: &DatabaseConnection, - user_id: &str, + user_id: U, state: PasskeyAuthentication, ) -> Result<(), IdentityDatabaseError> { let now = Local::now().naive_utc(); let entry = webauthn_state::ActiveModel { - user_id: Set(user_id.to_string()), + user_id: Set(user_id.as_ref().to_string()), state: Set(serde_json::to_string(&state)?), r#type: Set("auth".into()), created_at: Set(now), @@ -53,11 +53,11 @@ pub(super) async fn create_auth( Ok(()) } -pub async fn get_register( +pub async fn get_register>( db: &DatabaseConnection, - user_id: &str, + user_id: U, ) -> Result, IdentityDatabaseError> { - match DbPasskeyState::find_by_id(user_id) + match DbPasskeyState::find_by_id(user_id.as_ref()) .filter(webauthn_state::Column::Type.eq("register")) .one(db) .await? @@ -67,11 +67,11 @@ pub async fn get_register( } } -pub async fn get_auth( +pub async fn get_auth>( db: &DatabaseConnection, - user_id: &str, + user_id: U, ) -> Result, IdentityDatabaseError> { - match DbPasskeyState::find_by_id(user_id) + match DbPasskeyState::find_by_id(user_id.as_ref()) .filter(webauthn_state::Column::Type.eq("auth")) .one(db) .await? @@ -81,8 +81,13 @@ pub async fn get_auth( } } -pub async fn delete(db: &DatabaseConnection, user_id: &str) -> Result<(), IdentityDatabaseError> { - DbPasskeyState::delete_by_id(user_id).exec(db).await?; +pub async fn delete>( + db: &DatabaseConnection, + user_id: U, +) -> Result<(), IdentityDatabaseError> { + DbPasskeyState::delete_by_id(user_id.as_ref()) + .exec(db) + .await?; Ok(()) } diff --git a/src/identity/backends/sql/user.rs b/src/identity/backends/sql/user.rs index fbab9187..40db792d 100644 --- a/src/identity/backends/sql/user.rs +++ b/src/identity/backends/sql/user.rs @@ -21,6 +21,13 @@ use crate::db::entity::{prelude::User as DbUser, user}; use crate::identity::backends::error::IdentityDatabaseError; use crate::identity::types::UserCreate; +pub async fn get>( + db: &DatabaseConnection, + user_id: U, +) -> Result, IdentityDatabaseError> { + Ok(DbUser::find_by_id(user_id.as_ref()).one(db).await?) +} + pub(super) async fn create( conf: &Config, db: &DatabaseConnection, @@ -55,16 +62,18 @@ pub(super) async fn create( Ok(db_user) } -pub async fn delete( +pub async fn delete>( _conf: &Config, db: &DatabaseConnection, - user_id: &str, + user_id: U, ) -> Result<(), IdentityDatabaseError> { - let res = DbUser::delete_by_id(user_id).exec(db).await?; + let res = DbUser::delete_by_id(user_id.as_ref()).exec(db).await?; if res.rows_affected == 1 { Ok(()) } else { - Err(IdentityDatabaseError::UserNotFound(user_id.to_string())) + Err(IdentityDatabaseError::UserNotFound( + user_id.as_ref().to_string(), + )) } } diff --git a/src/identity/backends/sql/user_option.rs b/src/identity/backends/sql/user_option.rs index eec0ce68..0122cd44 100644 --- a/src/identity/backends/sql/user_option.rs +++ b/src/identity/backends/sql/user_option.rs @@ -12,9 +12,24 @@ // // SPDX-License-Identifier: Apache-2.0 -use crate::db::entity::user_option; +use sea_orm::DatabaseConnection; +use sea_orm::entity::*; +use sea_orm::query::*; + +use crate::db::entity::{prelude::UserOption as DbUserOptions, user_option}; +use crate::identity::backends::sql::IdentityDatabaseError; use crate::identity::types::*; +pub async fn get>( + db: &DatabaseConnection, + user_id: S, +) -> Result, IdentityDatabaseError> { + Ok(DbUserOptions::find() + .filter(user_option::Column::UserId.eq(user_id.as_ref())) + .all(db) + .await?) +} + impl FromIterator for UserOptions { fn from_iter>(iter: I) -> Self { let mut user_opts: UserOptions = UserOptions::default(); diff --git a/src/identity/error.rs b/src/identity/error.rs index 770ab0c1..1dfb9697 100644 --- a/src/identity/error.rs +++ b/src/identity/error.rs @@ -15,7 +15,8 @@ use thiserror::Error; use crate::identity::backends::error::*; -use crate::identity::types::UserBuilderError; +use crate::identity::types::{DomainBuilderError, UserResponseBuilderError}; +use crate::resource::error::ResourceProviderError; #[derive(Error, Debug)] pub enum IdentityProviderError { @@ -37,23 +38,41 @@ pub enum IdentityProviderError { GroupNotFound(String), /// Identity provider error - #[error("identity provider error")] - IdentityDatabaseError { + #[error(transparent)] + IdentityDatabase { #[from] source: IdentityDatabaseError, }, - #[error("building user data")] - UserBuilderError { + #[error(transparent)] + UserBuilder { + #[from] + source: UserResponseBuilderError, + }, + + #[error(transparent)] + DomainBuilder { #[from] - source: UserBuilderError, + source: DomainBuilderError, }, + #[error("either user id or user name with user domain id or name must be given")] + UserIdOrNameWithDomain, + #[error("password hashing error")] PasswordHash { #[from] source: IdentityProviderPasswordHashError, }, + + #[error(transparent)] + ResourceProvider { + #[from] + source: ResourceProviderError, + }, + + #[error("wrong username or password")] + WrongUsernamePassword, } impl IdentityProviderError { @@ -61,7 +80,7 @@ impl IdentityProviderError { match source { IdentityDatabaseError::UserNotFound(x) => Self::UserNotFound(x), IdentityDatabaseError::GroupNotFound(x) => Self::GroupNotFound(x), - _ => Self::IdentityDatabaseError { source }, + _ => Self::IdentityDatabase { source }, } } } diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 825bdd81..97ca4136 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -28,10 +28,12 @@ use crate::config::Config; use crate::identity::backends::sql::SqlBackend; use crate::identity::error::IdentityProviderError; use crate::identity::types::{ - IdentityBackend, - {Group, GroupCreate, GroupListParameters, User, UserCreate, UserListParameters}, + Group, GroupCreate, GroupListParameters, IdentityBackend, UserCreate, UserListParameters, + UserPasswordAuthRequest, UserResponse, }; use crate::plugin_manager::PluginManager; +use crate::provider::Provider; +use crate::resource::{ResourceApi, error::ResourceProviderError}; #[derive(Clone, Debug)] pub struct IdentityProvider { @@ -40,23 +42,30 @@ pub struct IdentityProvider { #[async_trait] pub trait IdentityApi: Send + Sync + Clone { + async fn authenticate_by_password( + &self, + db: &DatabaseConnection, + provider: &Provider, + auth: UserPasswordAuthRequest, + ) -> Result; + async fn list_users( &self, db: &DatabaseConnection, params: &UserListParameters, - ) -> Result, IdentityProviderError>; + ) -> Result, IdentityProviderError>; async fn get_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, - ) -> Result, IdentityProviderError>; + ) -> Result, IdentityProviderError>; async fn create_user( &self, db: &DatabaseConnection, user: UserCreate, - ) -> Result; + ) -> Result; async fn delete_user<'a>( &self, @@ -157,23 +166,30 @@ mock! { #[async_trait] impl IdentityApi for IdentityProvider { + async fn authenticate_by_password( + &self, + db: &DatabaseConnection, + provider: &Provider, + auth: UserPasswordAuthRequest, + ) -> Result; + async fn list_users( &self, db: &DatabaseConnection, params: &UserListParameters, - ) -> Result, IdentityProviderError>; + ) -> Result, IdentityProviderError>; async fn get_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, - ) -> Result, IdentityProviderError>; + ) -> Result, IdentityProviderError>; async fn create_user( &self, db: &DatabaseConnection, user: UserCreate, - ) -> Result; + ) -> Result; async fn delete_user<'a>( &self, @@ -295,13 +311,46 @@ impl IdentityProvider { #[async_trait] impl IdentityApi for IdentityProvider { + /// Authenticate user with the password auth method + #[tracing::instrument(level = "info", skip(self, db, provider, auth))] + async fn authenticate_by_password( + &self, + db: &DatabaseConnection, + provider: &Provider, + auth: UserPasswordAuthRequest, + ) -> Result { + let mut auth = auth; + if auth.id.is_none() { + if auth.name.is_none() { + return Err(IdentityProviderError::UserIdOrNameWithDomain); + } + + if let Some(ref mut domain) = auth.domain { + if let Some(dname) = &domain.name { + let d = provider + .get_resource_provider() + .find_domain_by_name(db, dname) + .await? + .ok_or(ResourceProviderError::DomainNotFound(dname.clone()))?; + domain.id = Some(d.id); + } else if domain.id.is_none() { + return Err(IdentityProviderError::UserIdOrNameWithDomain); + } + } else { + return Err(IdentityProviderError::UserIdOrNameWithDomain); + } + } + + self.backend_driver.authenticate_by_password(db, auth).await + } + /// List users #[tracing::instrument(level = "info", skip(self, db))] async fn list_users( &self, db: &DatabaseConnection, params: &UserListParameters, - ) -> Result, IdentityProviderError> { + ) -> Result, IdentityProviderError> { self.backend_driver.list_users(db, params).await } @@ -311,7 +360,7 @@ impl IdentityApi for IdentityProvider { &self, db: &DatabaseConnection, user_id: &'a str, - ) -> Result, IdentityProviderError> { + ) -> Result, IdentityProviderError> { self.backend_driver.get_user(db, user_id).await } @@ -321,7 +370,7 @@ impl IdentityApi for IdentityProvider { &self, db: &DatabaseConnection, user: UserCreate, - ) -> Result { + ) -> Result { let mut mod_user = user; mod_user.id = Uuid::new_v4().into(); if mod_user.enabled.is_none() { diff --git a/src/identity/password_hashing.rs b/src/identity/password_hashing.rs index 008b329e..76925143 100644 --- a/src/identity/password_hashing.rs +++ b/src/identity/password_hashing.rs @@ -42,6 +42,22 @@ pub fn hash_password>( } } +pub fn verify_password, H: AsRef>( + conf: &Config, + password: P, + hash: H, +) -> Result { + match conf.identity.password_hashing_algorithm { + PasswordHashingAlgo::Bcrypt => { + let password_bytes = verify_length_and_trunc_password( + password.as_ref(), + max(conf.identity.max_password_length, 72), + ); + Ok(bcrypt::verify(password_bytes, hash.as_ref())?) + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/identity/types.rs b/src/identity/types.rs index f84234f0..f97dd18d 100644 --- a/src/identity/types.rs +++ b/src/identity/types.rs @@ -25,7 +25,9 @@ use crate::identity::IdentityProviderError; pub use crate::identity::types::group::{Group, GroupCreate, GroupListParameters}; pub use crate::identity::types::user::{ - User, UserBuilder, UserBuilderError, UserCreate, UserListParameters, UserOptions, + DomainBuilder, DomainBuilderError, UserCreate, UserListParameters, UserOptions, + UserPasswordAuthRequest, UserPasswordAuthRequestBuilder, UserResponse, UserResponseBuilder, + UserResponseBuilderError, }; #[async_trait] @@ -33,26 +35,33 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Set config fn set_config(&mut self, config: Config); + /// Authenticate a user by a password + async fn authenticate_by_password( + &self, + db: &DatabaseConnection, + auth: UserPasswordAuthRequest, + ) -> Result; + /// List Users async fn list_users( &self, db: &DatabaseConnection, params: &UserListParameters, - ) -> Result, IdentityProviderError>; + ) -> Result, IdentityProviderError>; /// Get single user by ID async fn get_user<'a>( &self, db: &DatabaseConnection, user_id: &'a str, - ) -> Result, IdentityProviderError>; + ) -> Result, IdentityProviderError>; /// Create user async fn create_user( &self, db: &DatabaseConnection, user: UserCreate, - ) -> Result; + ) -> Result; /// Delete user async fn delete_user<'a>( diff --git a/src/identity/types/user.rs b/src/identity/types/user.rs index 72e2e4f3..2a3b8585 100644 --- a/src/identity/types/user.rs +++ b/src/identity/types/user.rs @@ -19,7 +19,7 @@ use serde_json::Value; #[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[builder(setter(strip_option, into))] -pub struct User { +pub struct UserResponse { /// The user ID. pub id: String, /// The user name. Must be unique within the owning domain. @@ -102,7 +102,7 @@ pub struct UserUpdate { pub federated: Option>, } -impl UserBuilder { +impl UserResponseBuilder { pub fn get_options(&self) -> Option<&UserOptions> { self.options.as_ref() } @@ -147,3 +147,33 @@ pub struct UserListParameters { /// Filter users by the name attribute pub name: Option, } + +/// User password information +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[builder(setter(strip_option, into))] +pub struct UserPasswordAuthRequest { + /// User ID + #[builder(default)] + pub id: Option, + /// User Name + #[builder(default)] + pub name: Option, + /// User domain + #[builder(default)] + pub domain: Option, + /// User password expiry date + #[builder(default)] + pub password: String, +} + +/// Domain information +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[builder(setter(strip_option, into))] +pub struct Domain { + /// Domain ID + #[builder(default)] + pub id: Option, + /// Domain Name + #[builder(default)] + pub name: Option, +} diff --git a/src/resource/backends/sql.rs b/src/resource/backends/sql.rs index 6120fa59..97258c4c 100644 --- a/src/resource/backends/sql.rs +++ b/src/resource/backends/sql.rs @@ -39,17 +39,24 @@ impl ResourceBackend for SqlBackend { } /// Get single domain by ID - #[tracing::instrument(level = "debug", skip(self, db))] async fn get_domain<'a>( &self, db: &DatabaseConnection, domain_id: &'a str, ) -> Result, ResourceProviderError> { - Ok(get_domain(&self.config, db, domain_id).await?) + Ok(get_domain_by_id(&self.config, db, domain_id).await?) + } + + /// Get single domain by Name + async fn get_domain_by_name<'a>( + &self, + db: &DatabaseConnection, + domain_name: &'a str, + ) -> Result, ResourceProviderError> { + Ok(get_domain_by_name(&self.config, db, domain_name).await?) } /// Get single project by ID - #[tracing::instrument(level = "debug", skip(self, db))] async fn get_project<'a>( &self, db: &DatabaseConnection, @@ -57,27 +64,65 @@ impl ResourceBackend for SqlBackend { ) -> Result, ResourceProviderError> { Ok(get_project(&self.config, db, project_id).await?) } + + /// Get single project by Name and Domain ID + async fn get_project_by_name<'a>( + &self, + db: &DatabaseConnection, + name: &'a str, + domain_id: &'a str, + ) -> Result, ResourceProviderError> { + Ok(get_project_by_name(&self.config, db, name, domain_id).await?) + } } -pub async fn get_domain( +pub async fn get_domain_by_id>( _conf: &Config, db: &DatabaseConnection, - domain_id: &str, + domain_id: I, ) -> Result, ResourceDatabaseError> { let domain_select = - DbProject::find_by_id(domain_id).filter(db_project::Column::IsDomain.eq(true)); + DbProject::find_by_id(domain_id.as_ref()).filter(db_project::Column::IsDomain.eq(true)); let domain_entry: Option = domain_select.one(db).await?; domain_entry.map(TryInto::try_into).transpose() } -pub async fn get_project( +pub async fn get_domain_by_name>( _conf: &Config, db: &DatabaseConnection, - domain_id: &str, + domain_name: N, +) -> Result, ResourceDatabaseError> { + let domain_select = DbProject::find() + .filter(db_project::Column::IsDomain.eq(true)) + .filter(db_project::Column::Name.eq(domain_name.as_ref())); + + let domain_entry: Option = domain_select.one(db).await?; + domain_entry.map(TryInto::try_into).transpose() +} + +pub async fn get_project>( + _conf: &Config, + db: &DatabaseConnection, + domain_id: I, ) -> Result, ResourceDatabaseError> { let project_select = - DbProject::find_by_id(domain_id).filter(db_project::Column::IsDomain.eq(false)); + DbProject::find_by_id(domain_id.as_ref()).filter(db_project::Column::IsDomain.eq(false)); + + let project_entry: Option = project_select.one(db).await?; + project_entry.map(TryInto::try_into).transpose() +} + +pub async fn get_project_by_name, D: AsRef>( + _conf: &Config, + db: &DatabaseConnection, + name: N, + domain_id: D, +) -> Result, ResourceDatabaseError> { + let project_select = DbProject::find() + .filter(db_project::Column::IsDomain.eq(false)) + .filter(db_project::Column::Name.eq(name.as_ref())) + .filter(db_project::Column::DomainId.eq(domain_id.as_ref())); let project_entry: Option = project_select.one(db).await?; project_entry.map(TryInto::try_into).transpose() diff --git a/src/resource/error.rs b/src/resource/error.rs index 1a2b8dea..53b8607c 100644 --- a/src/resource/error.rs +++ b/src/resource/error.rs @@ -34,14 +34,14 @@ pub enum ResourceProviderError { DomainNotFound(String), /// Identity provider error - #[error("resource provider error")] - ResourceDatabaseError { + #[error(transparent)] + ResourceDatabase { #[from] source: ResourceDatabaseError, }, - #[error("building domain data")] - DomainBuilderError { + #[error(transparent)] + DomainBuilder { #[from] source: DomainBuilderError, }, @@ -51,7 +51,7 @@ impl ResourceProviderError { pub fn database(source: ResourceDatabaseError) -> Self { match source { ResourceDatabaseError::DomainNotFound(x) => Self::DomainNotFound(x), - _ => Self::ResourceDatabaseError { source }, + _ => Self::ResourceDatabase { source }, } } } diff --git a/src/resource/mod.rs b/src/resource/mod.rs index 3ae573e4..945676fd 100644 --- a/src/resource/mod.rs +++ b/src/resource/mod.rs @@ -40,11 +40,24 @@ pub trait ResourceApi: Send + Sync + Clone { domain_id: &'a str, ) -> Result, ResourceProviderError>; + async fn find_domain_by_name<'a>( + &self, + db: &DatabaseConnection, + domain_name: &'a str, + ) -> Result, ResourceProviderError>; + async fn get_project<'a>( &self, db: &DatabaseConnection, project_id: &'a str, ) -> Result, ResourceProviderError>; + + async fn get_project_by_name<'a>( + &self, + db: &DatabaseConnection, + name: &'a str, + domain_id: &'a str, + ) -> Result, ResourceProviderError>; } #[cfg(test)] @@ -55,17 +68,31 @@ mock! { #[async_trait] impl ResourceApi for ResourceProvider { - async fn get_domain<'a>( - &self, - db: &DatabaseConnection, - domain_id: &'a str, - ) -> Result, ResourceProviderError>; - - async fn get_project<'a>( - &self, - db: &DatabaseConnection, - project_id: &'a str, - ) -> Result, ResourceProviderError>; + async fn get_domain<'a>( + &self, + db: &DatabaseConnection, + domain_id: &'a str, + ) -> Result, ResourceProviderError>; + + async fn find_domain_by_name<'a>( + &self, + db: &DatabaseConnection, + domain_name: &'a str, + ) -> Result, ResourceProviderError>; + + async fn get_project<'a>( + &self, + db: &DatabaseConnection, + project_id: &'a str, + ) -> Result, ResourceProviderError>; + + async fn get_project_by_name<'a>( + &self, + db: &DatabaseConnection, + name: &'a str, + domain_id: &'a str, + ) -> Result, ResourceProviderError>; + } impl Clone for ResourceProvider { @@ -109,6 +136,18 @@ impl ResourceApi for ResourceProvider { self.backend_driver.get_domain(db, domain_id).await } + /// Get single domain by its name + #[tracing::instrument(level = "info", skip(self, db))] + async fn find_domain_by_name<'a>( + &self, + db: &DatabaseConnection, + domain_name: &'a str, + ) -> Result, ResourceProviderError> { + self.backend_driver + .get_domain_by_name(db, domain_name) + .await + } + /// Get single project #[tracing::instrument(level = "info", skip(self, db))] async fn get_project<'a>( @@ -118,4 +157,17 @@ impl ResourceApi for ResourceProvider { ) -> Result, ResourceProviderError> { self.backend_driver.get_project(db, project_id).await } + + /// Get single project by Name and Domain ID + #[tracing::instrument(level = "info", skip(self, db))] + async fn get_project_by_name<'a>( + &self, + db: &DatabaseConnection, + name: &'a str, + domain_id: &'a str, + ) -> Result, ResourceProviderError> { + self.backend_driver + .get_project_by_name(db, name, domain_id) + .await + } } diff --git a/src/resource/types.rs b/src/resource/types.rs index 4f419571..64c61217 100644 --- a/src/resource/types.rs +++ b/src/resource/types.rs @@ -37,12 +37,27 @@ pub trait ResourceBackend: DynClone + Send + Sync + std::fmt::Debug { domain_id: &'a str, ) -> Result, ResourceProviderError>; + /// Get single domain by Name + async fn get_domain_by_name<'a>( + &self, + db: &DatabaseConnection, + domain_name: &'a str, + ) -> Result, ResourceProviderError>; + /// Get single project by ID async fn get_project<'a>( &self, db: &DatabaseConnection, project_id: &'a str, ) -> Result, ResourceProviderError>; + + /// Get single project by Name and Domain ID + async fn get_project_by_name<'a>( + &self, + db: &DatabaseConnection, + name: &'a str, + domain_id: &'a str, + ) -> Result, ResourceProviderError>; } dyn_clone::clone_trait_object!(ResourceBackend); diff --git a/src/token/domain_scoped.rs b/src/token/domain_scoped.rs index c50c30a1..3e3050fd 100644 --- a/src/token/domain_scoped.rs +++ b/src/token/domain_scoped.rs @@ -26,6 +26,7 @@ use crate::token::{ }; #[derive(Builder, Clone, Debug, Default, PartialEq)] +#[builder(setter(strip_option, into))] pub struct DomainScopeToken { pub user_id: String, #[builder(default, setter(name = _methods))] diff --git a/src/token/error.rs b/src/token/error.rs index ba53ec20..a1f86772 100644 --- a/src/token/error.rs +++ b/src/token/error.rs @@ -101,4 +101,18 @@ pub enum TokenProviderError { #[from] source: crate::token::unscoped::UnscopedTokenBuilderError, }, + + #[error(transparent)] + ProjectScopeBuilder { + /// The source of the error. + #[from] + source: crate::token::project_scoped::ProjectScopeTokenBuilderError, + }, + + #[error(transparent)] + DomainScopeBuilder { + /// The source of the error. + #[from] + source: crate::token::domain_scoped::DomainScopeTokenBuilderError, + }, } diff --git a/src/token/mod.rs b/src/token/mod.rs index d6aacbce..464eb4fe 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -27,12 +27,13 @@ pub mod types; pub mod unscoped; use crate::config::{Config, TokenProvider as TokenProviderType}; +use crate::resource::types::{Domain, Project}; pub use error::TokenProviderError; use types::TokenBackend; pub use application_credential::ApplicationCredentialToken; -pub use domain_scoped::DomainScopeToken; -pub use project_scoped::ProjectScopeToken; +pub use domain_scoped::{DomainScopeToken, DomainScopeTokenBuilder}; +pub use project_scoped::{ProjectScopeToken, ProjectScopeTokenBuilder}; pub use types::Token; pub use unscoped::{UnscopedToken, UnscopedTokenBuilder}; @@ -68,6 +69,8 @@ pub trait TokenApi: Send + Sync + Clone { user_id: U, methods: Vec, audit_ids: Vec, + project: Option<&Project>, + domain: Option<&Domain>, ) -> Result where U: AsRef; @@ -78,7 +81,7 @@ pub trait TokenApi: Send + Sync + Clone { #[async_trait] impl TokenApi for TokenProvider { /// Validate token - #[tracing::instrument(level = "info", skip(self))] + #[tracing::instrument(level = "info", skip(self, credential))] async fn validate_token<'a>( &self, credential: &'a str, @@ -101,23 +104,63 @@ impl TokenApi for TokenProvider { user_id: U, methods: Vec, audit_ids: Vec, + project: Option<&Project>, + domain: Option<&Domain>, ) -> Result where U: AsRef, { - let token = Token::Unscoped( - UnscopedTokenBuilder::default() - .user_id(user_id.as_ref()) - .methods(methods.into_iter()) - .audit_ids(audit_ids.into_iter()) - .expires_at( - Local::now() - .to_utc() - .checked_add_signed(TimeDelta::seconds(self.config.token.expiration as i64)) - .ok_or(TokenProviderError::ExpiryCalculation)?, - ) - .build()?, - ); + let token = if let Some(project) = &project { + Token::ProjectScope( + ProjectScopeTokenBuilder::default() + .user_id(user_id.as_ref()) + .methods(methods.into_iter()) + .audit_ids(audit_ids.into_iter()) + .expires_at( + Local::now() + .to_utc() + .checked_add_signed(TimeDelta::seconds( + self.config.token.expiration as i64, + )) + .ok_or(TokenProviderError::ExpiryCalculation)?, + ) + .project_id(project.id.clone()) + .build()?, + ) + } else if let Some(domain) = &domain { + Token::DomainScope( + DomainScopeTokenBuilder::default() + .user_id(user_id.as_ref()) + .methods(methods.into_iter()) + .audit_ids(audit_ids.into_iter()) + .expires_at( + Local::now() + .to_utc() + .checked_add_signed(TimeDelta::seconds( + self.config.token.expiration as i64, + )) + .ok_or(TokenProviderError::ExpiryCalculation)?, + ) + .domain_id(domain.id.clone()) + .build()?, + ) + } else { + Token::Unscoped( + UnscopedTokenBuilder::default() + .user_id(user_id.as_ref()) + .methods(methods.into_iter()) + .audit_ids(audit_ids.into_iter()) + .expires_at( + Local::now() + .to_utc() + .checked_add_signed(TimeDelta::seconds( + self.config.token.expiration as i64, + )) + .ok_or(TokenProviderError::ExpiryCalculation)?, + ) + .build()?, + ) + }; Ok(token) } @@ -147,6 +190,8 @@ mock! { user_id: U, methods: Vec, audit_ids: Vec, + project: Option<&Project>, + domain: Option<&Domain>, ) -> Result where U: AsRef; diff --git a/src/token/project_scoped.rs b/src/token/project_scoped.rs index 99bb53bc..c0b69f57 100644 --- a/src/token/project_scoped.rs +++ b/src/token/project_scoped.rs @@ -26,6 +26,7 @@ use crate::token::{ }; #[derive(Builder, Clone, Debug, Default, PartialEq)] +#[builder(setter(strip_option, into))] pub struct ProjectScopeToken { pub user_id: String, #[builder(default, setter(name = _methods))]