diff --git a/benches/fernet_token.rs b/benches/fernet_token.rs index 795bd129..d5f1649f 100644 --- a/benches/fernet_token.rs +++ b/benches/fernet_token.rs @@ -9,7 +9,7 @@ use openstack_keystone::token::fernet::FernetTokenProvider; use openstack_keystone::token::types::TokenBackend; fn decode(backend: &FernetTokenProvider, token: &str) { - backend.decrypt(token.into()).unwrap(); + backend.decrypt(token).unwrap(); } fn bench_decrypt_token(c: &mut Criterion) { diff --git a/src/api/auth.rs b/src/api/auth.rs index 95f7d471..02d1d5a5 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -49,7 +49,7 @@ where state .provider .get_token_provider() - .validate_token(auth_header.to_string(), None) + .validate_token(auth_header, None) .await .map_err(|_| (StatusCode::UNAUTHORIZED, "not authorized"))?, )) diff --git a/src/api/error.rs b/src/api/error.rs index 0460ae63..a4ee3b33 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -48,15 +48,9 @@ pub enum KeystoneApiError { InvalidToken, #[error("error building token data: {}", source)] - TokenBuilder { + Token { #[from] - source: crate::api::v3::auth::token::types::TokenBuilderError, - }, - - #[error("error building token user data: {}", source)] - TokenUserBuilder { - #[from] - source: crate::api::v3::auth::token::types::UserBuilderError, + source: TokenError, }, #[error("internal server error")] @@ -102,7 +96,7 @@ impl IntoResponse for KeystoneApiError { Json(json!({"error": {"code": StatusCode::INTERNAL_SERVER_ERROR.as_u16(), "message": self.to_string()}})), ).into_response() } - KeystoneApiError::SubjectTokenMissing | KeystoneApiError::InvalidHeader | KeystoneApiError::InvalidToken | KeystoneApiError::TokenBuilder{..} | KeystoneApiError::TokenUserBuilder {..}=> { + KeystoneApiError::SubjectTokenMissing | KeystoneApiError::InvalidHeader | KeystoneApiError::InvalidToken | KeystoneApiError::Token{..} => { (StatusCode::BAD_REQUEST, Json(json!({"error": {"code": StatusCode::BAD_REQUEST.as_u16(), "message": self.to_string()}})), ).into_response() @@ -135,3 +129,23 @@ impl KeystoneApiError { } } } + +#[derive(Debug, Error)] +pub enum TokenError { + #[error("error building token data: {}", source)] + Builder { + #[from] + source: crate::api::v3::auth::token::types::TokenBuilderError, + }, + + #[error("error building token user data: {}", source)] + UserBuilder { + #[from] + source: crate::api::v3::auth::token::types::UserBuilderError, + }, + #[error("error building token user data: {}", source)] + ProjectBuilder { + #[from] + source: crate::api::v3::auth::token::types::ProjectBuilderError, + }, +} diff --git a/src/api/v3/auth/token/common.rs b/src/api/v3/auth/token/common.rs new file mode 100644 index 00000000..773498e9 --- /dev/null +++ b/src/api/v3/auth/token/common.rs @@ -0,0 +1,313 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use crate::api::error::{KeystoneApiError, TokenError}; +use crate::api::v3::auth::token::types::{ProjectBuilder, Token, TokenBuilder, UserBuilder}; +use crate::identity::IdentityApi; +use crate::keystone::ServiceState; +use crate::resource::ResourceApi; +use crate::token::Token as ProviderToken; + +impl Token { + pub async fn from_provider_token( + token: &ProviderToken, + state: &ServiceState, + ) -> Result { + let mut response = TokenBuilder::default(); + response.audit_ids(token.audit_ids().clone()); + response.methods(token.methods().clone()); + response.expires_at(*token.expires_at()); + + let user = state + .provider + .get_identity_provider() + .get_user(&state.db, token.user_id()) + .await + .map_err(KeystoneApiError::identity)? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "user".into(), + identifier: token.user_id().clone(), + })?; + + let user_domain = state + .provider + .get_resource_provider() + .get_domain(&state.db, &user.domain_id) + .await + .map_err(KeystoneApiError::resource)? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "domain".into(), + identifier: user.domain_id.clone(), + })?; + + let mut user_response: UserBuilder = UserBuilder::default(); + user_response.id(user.id); + user_response.name(user.name); + user_response.password_expires_at(user.password_expires_at); + user_response.domain(user_domain.clone()); + response.user(user_response.build().map_err(TokenError::from)?); + + match token { + ProviderToken::Unscoped(_token) => { + // Nothing to do + } + ProviderToken::DomainScope(token) => { + if token.domain_id == user.domain_id { + response.domain(user_domain.clone()); + } else { + let domain = state + .provider + .get_resource_provider() + .get_domain(&state.db, &token.domain_id) + .await + .map_err(KeystoneApiError::resource)? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "domain".into(), + identifier: token.domain_id.clone(), + })?; + response.domain(domain.clone()); + } + } + ProviderToken::ProjectScope(token) => { + let project = state + .provider + .get_resource_provider() + .get_project(&state.db, &token.project_id) + .await + .map_err(KeystoneApiError::resource)? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "project".into(), + identifier: token.project_id.clone(), + })?; + + let mut project_response = ProjectBuilder::default(); + project_response.id(project.id.clone()); + project_response.name(project.name.clone()); + if project.domain_id == user.domain_id { + project_response.domain(user_domain.clone().into()); + } else { + let project_domain = state + .provider + .get_resource_provider() + .get_domain(&state.db, &project.domain_id) + .await + .map_err(KeystoneApiError::resource)? + .ok_or_else(|| KeystoneApiError::NotFound { + resource: "domain".into(), + identifier: user.domain_id.clone(), + })?; + project_response.domain(project_domain.clone().into()); + } + response.project(project_response.build().map_err(TokenError::from)?); + } + ProviderToken::ApplicationCredential(_token) => { + todo!(); + } + } + Ok(response.build().map_err(TokenError::from)?) + } +} + +#[cfg(test)] +mod tests { + use sea_orm::DatabaseConnection; + use std::sync::Arc; + + use crate::api::v3::auth::token::types::Token; + use crate::config::Config; + + use crate::identity::{MockIdentityProvider, types::User}; + use crate::keystone::Service; + + use crate::provider::ProviderBuilder; + + use crate::resource::{ + MockResourceProvider, + types::{Domain, Project}, + }; + + use crate::token::{ + DomainScopeToken, MockTokenProvider, ProjectScopeToken, Token as ProviderToken, + UnscopedToken, + }; + + #[tokio::test] + async fn test_from_unscoped() { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .returning(|_, _| { + Ok(Some(User { + id: "bar".into(), + domain_id: "user_domain_id".into(), + ..Default::default() + })) + }); + + let mut resource_mock = MockResourceProvider::default(); + resource_mock + .expect_get_domain() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .returning(|_, _| { + Ok(Some(Domain { + id: "user_domain_id".into(), + ..Default::default() + })) + }); + let token_mock = MockTokenProvider::default(); + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .resource(resource_mock) + .token(token_mock) + .build() + .unwrap(); + + let state = Arc::new(Service::new(config, db, provider).unwrap()); + + let api_token = Token::from_provider_token( + &ProviderToken::Unscoped(UnscopedToken { + user_id: "bar".into(), + ..Default::default() + }), + &state, + ) + .await + .unwrap(); + assert_eq!("bar", api_token.user.id); + assert_eq!("user_domain_id", api_token.user.domain.id); + assert!(api_token.project.is_none()); + assert!(api_token.domain.is_none()); + } + + #[tokio::test] + async fn test_from_domain_scoped() { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .returning(|_, _| { + Ok(Some(User { + id: "bar".into(), + domain_id: "user_domain_id".into(), + ..Default::default() + })) + }); + + let mut resource_mock = MockResourceProvider::default(); + resource_mock + .expect_get_domain() + .returning(|_, id: &'_ str| { + Ok(Some(Domain { + id: id.to_string(), + ..Default::default() + })) + }); + let token_mock = MockTokenProvider::default(); + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .resource(resource_mock) + .token(token_mock) + .build() + .unwrap(); + + let state = Arc::new(Service::new(config, db, provider).unwrap()); + + let api_token = Token::from_provider_token( + &ProviderToken::DomainScope(DomainScopeToken { + user_id: "bar".into(), + domain_id: "domain_id".into(), + ..Default::default() + }), + &state, + ) + .await + .unwrap(); + + assert_eq!("bar", api_token.user.id); + assert_eq!("user_domain_id", api_token.user.domain.id); + assert_eq!("domain_id", api_token.domain.expect("domain scope").id); + assert!(api_token.project.is_none()); + } + + #[tokio::test] + async fn test_from_project_scoped() { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .returning(|_, _| { + Ok(Some(User { + id: "bar".into(), + domain_id: "user_domain_id".into(), + ..Default::default() + })) + }); + + let mut resource_mock = MockResourceProvider::default(); + resource_mock + .expect_get_domain() + .returning(|_, id: &'_ str| { + Ok(Some(Domain { + id: id.to_string(), + ..Default::default() + })) + }); + resource_mock + .expect_get_project() + .returning(|_, id: &'_ str| { + Ok(Some(Project { + id: id.to_string(), + domain_id: "project_domain_id".into(), + ..Default::default() + })) + }); + let token_mock = MockTokenProvider::default(); + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .resource(resource_mock) + .token(token_mock) + .build() + .unwrap(); + + let state = Arc::new(Service::new(config, db, provider).unwrap()); + + let api_token = Token::from_provider_token( + &ProviderToken::ProjectScope(ProjectScopeToken { + user_id: "bar".into(), + project_id: "project_id".into(), + ..Default::default() + }), + &state, + ) + .await + .unwrap(); + + assert_eq!("bar", api_token.user.id); + assert_eq!("user_domain_id", api_token.user.domain.id); + let project = api_token.project.expect("project_scope"); + assert_eq!("project_domain_id", project.domain.id); + assert_eq!("project_id", project.id); + assert!(api_token.domain.is_none()); + } +} diff --git a/src/api/v3/auth/token/mod.rs b/src/api/v3/auth/token/mod.rs index 63e77475..f0f0fc45 100644 --- a/src/api/v3/auth/token/mod.rs +++ b/src/api/v3/auth/token/mod.rs @@ -17,16 +17,15 @@ use utoipa_axum::{router::OpenApiRouter, routes}; use crate::api::auth::Auth; use crate::api::error::KeystoneApiError; -use crate::identity::IdentityApi; use crate::keystone::ServiceState; -use crate::resource::ResourceApi; use crate::token::TokenApi; -use types::{TokenBuilder, TokenResponse, UserBuilder}; +use types::{Token as ApiResponseToken, TokenResponse}; +mod common; pub mod types; pub(super) fn openapi_router() -> OpenApiRouter { - OpenApiRouter::new().routes(routes!(validate)) + OpenApiRouter::new().routes(routes!(show)) } /// Validate token @@ -41,7 +40,7 @@ pub(super) fn openapi_router() -> OpenApiRouter { tag="auth" )] #[tracing::instrument(name = "api::token_get", level = "debug", skip(state))] -async fn validate( +async fn show( Auth(user_auth): Auth, headers: HeaderMap, State(state): State, @@ -56,46 +55,14 @@ async fn validate( let token = state .provider .get_token_provider() - .validate_token(subject_token, None) + .validate_token(&subject_token, None) .await .map_err(|_| KeystoneApiError::InvalidToken)?; - let mut response = TokenBuilder::default(); - response.audit_ids(token.audit_ids().clone()); - response.methods(token.methods().clone()); - response.expires_at(*token.expires_at()); - - let user = state - .provider - .get_identity_provider() - .get_user(&state.db, token.user_id().clone()) - .await - .map_err(KeystoneApiError::identity)? - .ok_or_else(|| KeystoneApiError::NotFound { - resource: "user".into(), - identifier: token.user_id().clone(), - })?; - - let user_domain = state - .provider - .get_resource_provider() - .get_domain(&state.db, user.domain_id.clone()) - .await - .map_err(KeystoneApiError::resource)? - .ok_or_else(|| KeystoneApiError::NotFound { - resource: "domain".into(), - identifier: user.domain_id.clone(), - })?; - - let mut user_response: UserBuilder = UserBuilder::default(); - user_response.id(user.id); - user_response.name(user.name); - user_response.password_expires_at(user.password_expires_at); - user_response.domain(user_domain); - response.user(user_response.build()?); + let response_token = ApiResponseToken::from_provider_token(&token, &state).await?; Ok(TokenResponse { - token: response.build()?, + token: response_token, }) } @@ -119,31 +86,28 @@ mod tests { use crate::provider::ProviderBuilder; use crate::resource::{MockResourceProvider, types::Domain}; use crate::tests::api::get_mocked_state_unauthed; - use crate::token::{MockTokenProvider, Token, UnscopedToken}; + use crate::token::*; #[tokio::test] async fn test_get() { let db = DatabaseConnection::Disconnected; let config = Config::default(); let mut identity_mock = MockIdentityProvider::default(); - identity_mock - .expect_get_user() - .withf(|_: &DatabaseConnection, id: &String| *id == "bar") - .returning(|_, _| { - Ok(Some(User { - id: "bar".into(), - domain_id: "domain_id".into(), - ..Default::default() - })) - }); + identity_mock.expect_get_user().returning(|_, id: &'_ str| { + Ok(Some(User { + id: id.to_string(), + domain_id: "user_domain_id".into(), + ..Default::default() + })) + }); let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &String| *id == "domain_id") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") .returning(|_, _| { Ok(Some(Domain { - id: "domain_id".into(), + id: "user_domain_id".into(), ..Default::default() })) }); diff --git a/src/api/v3/auth/token/types.rs b/src/api/v3/auth/token/types.rs index dafcba30..92c5a2b2 100644 --- a/src/api/v3/auth/token/types.rs +++ b/src/api/v3/auth/token/types.rs @@ -50,15 +50,21 @@ pub struct Token { /// The date and time when the token expires. pub expires_at: DateTime, + /// A user object. + #[builder(default)] + pub user: User, + /// A project object including the id, name and domain object representing the project the /// token is scoped to. This is only included in tokens that are scoped to a project. #[serde(skip_serializing_if = "Option::is_none")] #[builder(default)] pub project: Option, - /// A user object. + /// A domain object including the id and name representing the domain the token is scoped to. + /// This is only included in tokens that are scoped to a domain. + #[serde(skip_serializing_if = "Option::is_none")] #[builder(default)] - pub user: User, + pub domain: Option, } #[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] @@ -75,12 +81,15 @@ impl IntoResponse for TokenResponse { } /// Project information -#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct Project { /// Project ID pub id: String, /// Project Name pub name: String, + + /// project domain + pub domain: Domain, } /// User information @@ -94,7 +103,6 @@ pub struct User { /// User domain pub domain: Domain, /// User password expiry date - #[serde(skip_serializing_if = "Option::is_none")] pub password_expires_at: Option>, } @@ -108,15 +116,6 @@ pub struct Domain { pub name: String, } -//impl From for User { -// fn from(value: identity_provider_types::User) -> Self { -// Self { -// id: value.id.clone(), -// name: value.name.clone(), -// } -// } -//} - impl From for Domain { fn from(value: resource_provider_types::Domain) -> Self { Self { diff --git a/src/api/v3/group/mod.rs b/src/api/v3/group/mod.rs index 132a0183..922c4e21 100644 --- a/src/api/v3/group/mod.rs +++ b/src/api/v3/group/mod.rs @@ -85,7 +85,7 @@ async fn show( state .provider .get_identity_provider() - .get_group(&state.db, group_id.clone()) + .get_group(&state.db, &group_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -142,7 +142,7 @@ async fn remove( state .provider .get_identity_provider() - .delete_group(&state.db, group_id) + .delete_group(&state.db, &group_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -282,12 +282,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_group() - .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); identity_mock .expect_get_group() - .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(Group { id: "bar".into(), @@ -399,12 +399,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_delete_group() - .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") .returning(|_, _| Err(IdentityProviderError::GroupNotFound("foo".into()))); identity_mock .expect_delete_group() - .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(identity_mock); diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index 4e9423b3..d595d69a 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -84,7 +84,7 @@ async fn show( state .provider .get_identity_provider() - .get_user(&state.db, user_id.clone()) + .get_user(&state.db, &user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -141,7 +141,7 @@ async fn remove( state .provider .get_identity_provider() - .delete_user(&state.db, user_id) + .delete_user(&state.db, &user_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -331,12 +331,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(User { id: "bar".into(), @@ -395,12 +395,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_delete_user() - .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") .returning(|_, _| Err(IdentityProviderError::UserNotFound("foo".into()))); identity_mock .expect_delete_user() - .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(identity_mock); diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index 56000e69..3f5df991 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -63,10 +63,10 @@ impl IdentityBackend for SqlBackend { /// Get single user by ID #[tracing::instrument(level = "debug", skip(self, db))] - async fn get_user( + async fn get_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result, IdentityProviderError> { Ok(get_user(&self.config, db, user_id).await?) } @@ -83,10 +83,10 @@ impl IdentityBackend for SqlBackend { /// Delete user #[tracing::instrument(level = "debug", skip(self, db))] - async fn delete_user( + async fn delete_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result<(), IdentityProviderError> { user::delete(&self.config, db, user_id) .await @@ -105,10 +105,10 @@ impl IdentityBackend for SqlBackend { /// Get single group by ID #[tracing::instrument(level = "debug", skip(self, db))] - async fn get_group( + async fn get_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result, IdentityProviderError> { Ok(group::get(&self.config, db, group_id).await?) } @@ -125,10 +125,10 @@ impl IdentityBackend for SqlBackend { /// Delete group #[tracing::instrument(level = "debug", skip(self, db))] - async fn delete_group( + async fn delete_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result<(), IdentityProviderError> { group::delete(&self.config, db, group_id) .await @@ -222,9 +222,9 @@ async fn list_users( pub async fn get_user( conf: &Config, db: &DatabaseConnection, - user_id: String, + user_id: &str, ) -> Result, IdentityDatabaseError> { - let user_select = DbUser::find_by_id(&user_id); + let user_select = DbUser::find_by_id(user_id); let user_entry: Option = user_select.one(db).await?; @@ -249,7 +249,7 @@ pub async fn get_user( if !federated_user.is_empty() { common::get_federated_user_builder(user, federated_user, user_opts) } else { - return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?; + return Err(IdentityDatabaseError::MalformedUser(user_id.to_string()))?; } } }, @@ -369,7 +369,7 @@ mod tests { .into_connection(); let config = Config::default(); assert_eq!( - get_user(&config, &db, "1".into()).await.unwrap().unwrap(), + get_user(&config, &db, "1").await.unwrap().unwrap(), User { id: "1".into(), domain_id: "foo_domain".into(), diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs index f5bde82c..7b693d49 100644 --- a/src/identity/backends/sql/group.rs +++ b/src/identity/backends/sql/group.rs @@ -47,16 +47,13 @@ pub async fn list( pub async fn get( _conf: &Config, db: &DatabaseConnection, - group_id: String, + group_id: &str, ) -> Result, IdentityDatabaseError> { - Ok(DbGroup::find_by_id(&group_id) - .one(db) - .await? - .map(Into::into)) + Ok(DbGroup::find_by_id(group_id).one(db).await?.map(Into::into)) } pub async fn create( - conf: &Config, + _conf: &Config, db: &DatabaseConnection, group: GroupCreate, ) -> Result { @@ -76,13 +73,13 @@ pub async fn create( pub async fn delete( _conf: &Config, db: &DatabaseConnection, - group_id: String, + group_id: &str, ) -> Result<(), IdentityDatabaseError> { - let res = DbGroup::delete_by_id(&group_id).exec(db).await?; + let res = DbGroup::delete_by_id(group_id).exec(db).await?; if res.rows_affected == 1 { Ok(()) } else { - Err(IdentityDatabaseError::GroupNotFound(group_id)) + Err(IdentityDatabaseError::GroupNotFound(group_id.to_string())) } } @@ -199,7 +196,7 @@ mod tests { let config = Config::default(); assert_eq!( - get(&config, &db, "id".into()).await.unwrap(), + get(&config, &db, "id").await.unwrap(), Some(Group { id: "1".into(), domain_id: "foo_domain".into(), @@ -208,7 +205,7 @@ mod tests { extra: Some(json!({"foo": "bar"})) }) ); - assert!(get(&config, &db, "missing".into()).await.unwrap().is_none()); + assert!(get(&config, &db, "missing").await.unwrap().is_none()); // Checking transaction log assert_eq!( @@ -275,7 +272,7 @@ mod tests { .into_connection(); let config = Config::default(); - delete(&config, &db, "id".into()).await.unwrap(); + delete(&config, &db, "id").await.unwrap(); // Checking transaction log assert_eq!( db.into_transaction_log(), diff --git a/src/identity/backends/sql/user.rs b/src/identity/backends/sql/user.rs index 404e3a00..fbab9187 100644 --- a/src/identity/backends/sql/user.rs +++ b/src/identity/backends/sql/user.rs @@ -58,13 +58,13 @@ pub(super) async fn create( pub async fn delete( _conf: &Config, db: &DatabaseConnection, - user_id: String, + user_id: &str, ) -> Result<(), IdentityDatabaseError> { - let res = DbUser::delete_by_id(&user_id).exec(db).await?; + let res = DbUser::delete_by_id(user_id).exec(db).await?; if res.rows_affected == 1 { Ok(()) } else { - Err(IdentityDatabaseError::UserNotFound(user_id)) + Err(IdentityDatabaseError::UserNotFound(user_id.to_string())) } } @@ -89,7 +89,7 @@ mod tests { .into_connection(); let config = Config::default(); - delete(&config, &db, "id".into()).await.unwrap(); + delete(&config, &db, "id").await.unwrap(); // Checking transaction log assert_eq!( db.into_transaction_log(), diff --git a/src/identity/error.rs b/src/identity/error.rs index 50f648bb..770ab0c1 100644 --- a/src/identity/error.rs +++ b/src/identity/error.rs @@ -32,6 +32,7 @@ pub enum IdentityProviderError { #[error("user {0} not found")] UserNotFound(String), + #[error("group {0} not found")] GroupNotFound(String), diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 43f7e2eb..5c8a5e4f 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -45,10 +45,10 @@ pub trait IdentityApi: Send + Sync + Clone { params: &UserListParameters, ) -> Result, IdentityProviderError>; - async fn get_user( + async fn get_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_user( @@ -57,10 +57,10 @@ pub trait IdentityApi: Send + Sync + Clone { user: UserCreate, ) -> Result; - async fn delete_user( + async fn delete_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn list_groups( @@ -69,10 +69,10 @@ pub trait IdentityApi: Send + Sync + Clone { params: &GroupListParameters, ) -> Result, IdentityProviderError>; - async fn get_group( + async fn get_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_group( @@ -81,10 +81,10 @@ pub trait IdentityApi: Send + Sync + Clone { group: GroupCreate, ) -> Result; - async fn delete_group( + async fn delete_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result<(), IdentityProviderError>; } @@ -102,10 +102,10 @@ mock! { params: &UserListParameters, ) -> Result, IdentityProviderError>; - async fn get_user( + async fn get_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_user( @@ -114,10 +114,10 @@ mock! { user: UserCreate, ) -> Result; - async fn delete_user( + async fn delete_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn list_groups( @@ -126,10 +126,10 @@ mock! { params: &GroupListParameters, ) -> Result, IdentityProviderError>; - async fn get_group( + async fn get_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_group( @@ -138,10 +138,10 @@ mock! { group: GroupCreate, ) -> Result; - async fn delete_group( + async fn delete_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result<(), IdentityProviderError>; } @@ -189,10 +189,10 @@ impl IdentityApi for IdentityProvider { /// Get single user #[tracing::instrument(level = "info", skip(self, db))] - async fn get_user( + async fn get_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result, IdentityProviderError> { self.backend_driver.get_user(db, user_id).await } @@ -214,10 +214,10 @@ impl IdentityApi for IdentityProvider { /// Delete user #[tracing::instrument(level = "info", skip(self, db))] - async fn delete_user( + async fn delete_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result<(), IdentityProviderError> { self.backend_driver.delete_user(db, user_id).await } @@ -234,10 +234,10 @@ impl IdentityApi for IdentityProvider { /// Get single group #[tracing::instrument(level = "info", skip(self, db))] - async fn get_group( + async fn get_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result, IdentityProviderError> { self.backend_driver.get_group(db, group_id).await } @@ -256,10 +256,10 @@ impl IdentityApi for IdentityProvider { /// Delete group #[tracing::instrument(level = "info", skip(self, db))] - async fn delete_group( + async fn delete_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result<(), IdentityProviderError> { self.backend_driver.delete_group(db, group_id).await } diff --git a/src/identity/types.rs b/src/identity/types.rs index f2731a6b..aff26f7a 100644 --- a/src/identity/types.rs +++ b/src/identity/types.rs @@ -19,7 +19,7 @@ use async_trait::async_trait; use dyn_clone::DynClone; use sea_orm::DatabaseConnection; -use crate::identity::Config; +use crate::config::Config; use crate::identity::IdentityProviderError; pub use crate::identity::types::group::{Group, GroupCreate, GroupListParameters}; @@ -40,10 +40,10 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { ) -> Result, IdentityProviderError>; /// Get single user by ID - async fn get_user( + async fn get_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result, IdentityProviderError>; /// Create user @@ -54,10 +54,10 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { ) -> Result; /// Delete user - async fn delete_user( + async fn delete_user<'a>( &self, db: &DatabaseConnection, - user_id: String, + user_id: &'a str, ) -> Result<(), IdentityProviderError>; /// List groups @@ -68,10 +68,10 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { ) -> Result, IdentityProviderError>; /// Get single group by ID - async fn get_group( + async fn get_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result, IdentityProviderError>; /// Create group @@ -82,10 +82,10 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { ) -> Result; /// Delete group by ID - async fn delete_group( + async fn delete_group<'a>( &self, db: &DatabaseConnection, - group_id: String, + group_id: &'a str, ) -> Result<(), IdentityProviderError>; } diff --git a/src/resource/backends/error.rs b/src/resource/backends/error.rs index 27740473..57524d12 100644 --- a/src/resource/backends/error.rs +++ b/src/resource/backends/error.rs @@ -27,12 +27,18 @@ pub enum ResourceDatabaseError { source: serde_json::Error, }, - #[error("building domain data")] + #[error("error building domain data")] DomainBuilderError { #[from] source: DomainBuilderError, }, + #[error("error building project data")] + ProjectBuilderError { + #[from] + source: ProjectBuilderError, + }, + #[error("database data")] Database { #[from] diff --git a/src/resource/backends/sql.rs b/src/resource/backends/sql.rs index 8a3d26da..6120fa59 100644 --- a/src/resource/backends/sql.rs +++ b/src/resource/backends/sql.rs @@ -18,8 +18,6 @@ use sea_orm::entity::*; use sea_orm::query::*; use serde_json::Value; -mod domain; - use super::super::types::*; use crate::config::Config; use crate::db::entity::{prelude::Project as DbProject, project as db_project}; @@ -42,41 +40,86 @@ impl ResourceBackend for SqlBackend { /// Get single domain by ID #[tracing::instrument(level = "debug", skip(self, db))] - async fn get_domain( + async fn get_domain<'a>( &self, db: &DatabaseConnection, - domain_id: String, + domain_id: &'a str, ) -> Result, ResourceProviderError> { Ok(get_domain(&self.config, db, domain_id).await?) } + + /// Get single project by ID + #[tracing::instrument(level = "debug", skip(self, db))] + async fn get_project<'a>( + &self, + db: &DatabaseConnection, + project_id: &'a str, + ) -> Result, ResourceProviderError> { + Ok(get_project(&self.config, db, project_id).await?) + } } pub async fn get_domain( - conf: &Config, + _conf: &Config, db: &DatabaseConnection, - domain_id: String, + domain_id: &str, ) -> Result, ResourceDatabaseError> { let domain_select = - DbProject::find_by_id(&domain_id).filter(db_project::Column::IsDomain.eq(true)); + DbProject::find_by_id(domain_id).filter(db_project::Column::IsDomain.eq(true)); let domain_entry: Option = domain_select.one(db).await?; + domain_entry.map(TryInto::try_into).transpose() +} + +pub async fn get_project( + _conf: &Config, + db: &DatabaseConnection, + domain_id: &str, +) -> Result, ResourceDatabaseError> { + let project_select = + DbProject::find_by_id(domain_id).filter(db_project::Column::IsDomain.eq(false)); + + let project_entry: Option = project_select.one(db).await?; + project_entry.map(TryInto::try_into).transpose() +} + +impl TryFrom for Project { + type Error = ResourceDatabaseError; + + fn try_from(value: db_project::Model) -> Result { + let mut project_builder = ProjectBuilder::default(); + project_builder.id(value.id.clone()); + project_builder.name(value.name.clone()); + project_builder.domain_id(value.domain_id.clone()); + if let Some(description) = &value.description { + project_builder.description(description.clone()); + } + project_builder.enabled(value.enabled.unwrap_or(false)); + if let Some(extra) = &value.extra { + project_builder.extra(serde_json::from_str::(extra).unwrap()); + } - if let Some(domain) = &domain_entry { + Ok(project_builder.build()?) + } +} + +impl TryFrom for Domain { + type Error = ResourceDatabaseError; + + fn try_from(value: db_project::Model) -> Result { let mut domain_builder = DomainBuilder::default(); - domain_builder.id(domain.id.clone()); - domain_builder.name(domain.name.clone()); - if let Some(description) = &domain.description { + domain_builder.id(value.id.clone()); + domain_builder.name(value.name.clone()); + if let Some(description) = &value.description { domain_builder.description(description.clone()); } - domain_builder.enabled(domain.enabled.unwrap_or(false)); - if let Some(extra) = &domain.extra { + domain_builder.enabled(value.enabled.unwrap_or(false)); + if let Some(extra) = &value.extra { domain_builder.extra(serde_json::from_str::(extra).unwrap()); } - return Ok(Some(domain_builder.build()?)); + Ok(domain_builder.build()?) } - - Ok(None) } //#[cfg(test)] diff --git a/src/resource/mod.rs b/src/resource/mod.rs index 2f9697ad..3ae573e4 100644 --- a/src/resource/mod.rs +++ b/src/resource/mod.rs @@ -25,7 +25,7 @@ use crate::config::Config; use crate::plugin_manager::PluginManager; use crate::resource::backends::sql::SqlBackend; use crate::resource::error::ResourceProviderError; -use crate::resource::types::{Domain, ResourceBackend}; +use crate::resource::types::{Domain, Project, ResourceBackend}; #[derive(Clone, Debug)] pub struct ResourceProvider { @@ -34,11 +34,17 @@ pub struct ResourceProvider { #[async_trait] pub trait ResourceApi: Send + Sync + Clone { - async fn get_domain( + async fn get_domain<'a>( &self, db: &DatabaseConnection, - domain_id: String, + domain_id: &'a str, ) -> Result, ResourceProviderError>; + + async fn get_project<'a>( + &self, + db: &DatabaseConnection, + project_id: &'a str, + ) -> Result, ResourceProviderError>; } #[cfg(test)] @@ -49,11 +55,17 @@ mock! { #[async_trait] impl ResourceApi for ResourceProvider { - async fn get_domain( + async fn get_domain<'a>( &self, db: &DatabaseConnection, - domain_id: String, + domain_id: &'a str, ) -> Result, ResourceProviderError>; + + async fn get_project<'a>( + &self, + db: &DatabaseConnection, + project_id: &'a str, + ) -> Result, ResourceProviderError>; } impl Clone for ResourceProvider { @@ -89,11 +101,21 @@ impl ResourceProvider { impl ResourceApi for ResourceProvider { /// Get single domain #[tracing::instrument(level = "info", skip(self, db))] - async fn get_domain( + async fn get_domain<'a>( &self, db: &DatabaseConnection, - domain_id: String, + domain_id: &'a str, ) -> Result, ResourceProviderError> { self.backend_driver.get_domain(db, domain_id).await } + + /// Get single project + #[tracing::instrument(level = "info", skip(self, db))] + async fn get_project<'a>( + &self, + db: &DatabaseConnection, + project_id: &'a str, + ) -> Result, ResourceProviderError> { + self.backend_driver.get_project(db, project_id).await + } } diff --git a/src/resource/types.rs b/src/resource/types.rs index 842b1cc5..4f419571 100644 --- a/src/resource/types.rs +++ b/src/resource/types.rs @@ -13,6 +13,7 @@ // SPDX-License-Identifier: Apache-2.0 pub mod domain; +pub mod project; use async_trait::async_trait; use dyn_clone::DynClone; @@ -22,6 +23,7 @@ use crate::config::Config; use crate::resource::ResourceProviderError; pub use crate::resource::types::domain::{Domain, DomainBuilder, DomainBuilderError}; +pub use crate::resource::types::project::{Project, ProjectBuilder, ProjectBuilderError}; #[async_trait] pub trait ResourceBackend: DynClone + Send + Sync + std::fmt::Debug { @@ -29,11 +31,18 @@ pub trait ResourceBackend: DynClone + Send + Sync + std::fmt::Debug { fn set_config(&mut self, config: Config); /// Get single domain by ID - async fn get_domain( + async fn get_domain<'a>( &self, db: &DatabaseConnection, - domain_id: String, + domain_id: &'a str, ) -> Result, ResourceProviderError>; + + /// Get single project by ID + async fn get_project<'a>( + &self, + db: &DatabaseConnection, + project_id: &'a str, + ) -> Result, ResourceProviderError>; } dyn_clone::clone_trait_object!(ResourceBackend); diff --git a/src/resource/types/domain.rs b/src/resource/types/domain.rs index 34164bae..f36487e9 100644 --- a/src/resource/types/domain.rs +++ b/src/resource/types/domain.rs @@ -27,7 +27,7 @@ pub struct Domain { /// The resource description #[builder(default)] pub description: Option, - /// Additional user properties + /// Additional domain properties #[builder(default)] pub extra: Option, } diff --git a/src/resource/backends/sql/domain.rs b/src/resource/types/project.rs similarity index 50% rename from src/resource/backends/sql/domain.rs rename to src/resource/types/project.rs index 85ce919c..131a5613 100644 --- a/src/resource/backends/sql/domain.rs +++ b/src/resource/types/project.rs @@ -12,13 +12,24 @@ // // SPDX-License-Identifier: Apache-2.0 -//use chrono::Local; -//use sea_orm::DatabaseConnection; -//use sea_orm::entity::*; -// -//use crate::config::Config; -// use crate::db::entity::{prelude::Domain as DbDomain, domain}; -//use crate::resource::backends::error::ResourceDatabaseError; +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; +use serde_json::Value; -#[cfg(test)] -mod tests {} +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[builder(setter(strip_option, into))] +pub struct Project { + /// The project ID. + pub id: String, + /// The project name. + pub name: String, + /// The project domain_id. + pub domain_id: String, + pub enabled: bool, + /// The resource description + #[builder(default)] + pub description: Option, + /// Additional project properties + #[builder(default)] + pub extra: Option, +} diff --git a/src/token/fernet.rs b/src/token/fernet.rs index f76c105d..3f15121d 100644 --- a/src/token/fernet.rs +++ b/src/token/fernet.rs @@ -118,12 +118,12 @@ impl FernetTokenProvider { /// /// 1. Decrypt as Fernet /// 2. Unpack MessagePack payload - pub fn decrypt(&self, credential: String) -> Result { + pub fn decrypt(&self, credential: &str) -> Result { // TODO: Implement fernet keys change watching. Keystone loads them from FS on every // request and in the best case it costs 15µs. let payload = match &self.fernet { - Some(fernet) => fernet.decrypt(credential.as_ref())?, - _ => self.get_fernet()?.decrypt(credential.as_ref())?, + Some(fernet) => fernet.decrypt(credential)?, + _ => self.get_fernet()?.decrypt(credential)?, }; self.parse(&mut payload.as_slice()) } @@ -148,7 +148,7 @@ impl TokenBackend for FernetTokenProvider { } /// Extract token - fn extract(&self, credential: String) -> Result { + fn extract(&self, credential: &str) -> Result { self.decrypt(credential) } } @@ -187,7 +187,7 @@ pub(super) mod tests { backend.set_config(config); backend.load_keys().unwrap(); - if let Token::Unscoped(decrypted) = backend.decrypt(token.into()).unwrap() { + if let Token::Unscoped(decrypted) = backend.decrypt(token).unwrap() { assert_eq!(decrypted.user_id, "4b7d364ad87d400bbd91798e3c15e9c2"); assert_eq!(decrypted.methods, vec!["token"]); assert_eq!( @@ -212,7 +212,7 @@ pub(super) mod tests { backend.set_config(config); backend.load_keys().unwrap(); - if let Token::DomainScope(decrypted) = backend.decrypt(token.into()).unwrap() { + if let Token::DomainScope(decrypted) = backend.decrypt(token).unwrap() { assert_eq!(decrypted.user_id, "4b7d364ad87d400bbd91798e3c15e9c2"); assert_eq!(decrypted.domain_id, "default"); assert_eq!(decrypted.methods, vec!["password"]); @@ -235,7 +235,7 @@ pub(super) mod tests { backend.set_config(config); backend.load_keys().unwrap(); - if let Token::ProjectScope(decrypted) = backend.decrypt(token.into()).unwrap() { + if let Token::ProjectScope(decrypted) = backend.decrypt(token).unwrap() { assert_eq!(decrypted.user_id, "4b7d364ad87d400bbd91798e3c15e9c2"); assert_eq!(decrypted.project_id, "97cd761d581b485792a4afc8cc6a998d"); assert_eq!(decrypted.methods, vec!["password"]); @@ -258,7 +258,7 @@ pub(super) mod tests { backend.set_config(config); backend.load_keys().unwrap(); - if let Token::ApplicationCredential(decrypted) = backend.decrypt(token.into()).unwrap() { + if let Token::ApplicationCredential(decrypted) = backend.decrypt(token).unwrap() { assert_eq!(decrypted.user_id, "4b7d364ad87d400bbd91798e3c15e9c2"); assert_eq!(decrypted.project_id, "97cd761d581b485792a4afc8cc6a998d"); assert_eq!(decrypted.methods, vec!["application_credential"]); diff --git a/src/token/mod.rs b/src/token/mod.rs index 3fa25a2e..aab1f983 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -55,9 +55,9 @@ impl TokenProvider { #[async_trait] pub trait TokenApi: Send + Sync + Clone { - async fn validate_token( + async fn validate_token<'a>( &self, - credential: String, + credential: &'a str, window_seconds: Option, ) -> Result; } @@ -66,9 +66,9 @@ pub trait TokenApi: Send + Sync + Clone { impl TokenApi for TokenProvider { /// Validate token #[tracing::instrument(level = "info", skip(self))] - async fn validate_token( + async fn validate_token<'a>( &self, - credential: String, + credential: &'a str, window_seconds: Option, ) -> Result { let token = self.backend_driver.extract(credential)?; @@ -92,9 +92,9 @@ mock! { #[async_trait] impl TokenApi for TokenProvider { - async fn validate_token( + async fn validate_token<'a>( &self, - credential: String, + credential: &'a str, window_seconds: Option, ) -> Result; } diff --git a/src/token/types.rs b/src/token/types.rs index 7ba41f8f..5c392de2 100644 --- a/src/token/types.rs +++ b/src/token/types.rs @@ -73,7 +73,7 @@ pub trait TokenBackend: DynClone + Send + Sync + std::fmt::Debug { fn set_config(&mut self, g: Config); /// Extract the token from string - fn extract(&self, credential: String) -> Result; + fn extract(&self, credential: &str) -> Result; } dyn_clone::clone_trait_object!(TokenBackend);