From 57b331c0f024de00a2ec2a6caef47390e8a9f84d Mon Sep 17 00:00:00 2001 From: Artem Goncharov Date: Sun, 23 Feb 2025 11:00:48 +0100 Subject: [PATCH] feat: Improve auth - add required auth to all necessary methods - implement mocks for provider Since some Keystone endpoints have different auth requirements it cannot be handled in the middleware and should be instead processed with extractors. Access to the state from extractor is not supporting state with generics. Due to that it is required to again rework the state, but now we finally have the proper mock support. --- .github/workflows/linters.yml | 4 +- Cargo.lock | 112 ++++++++-- Cargo.toml | 8 +- benches/fernet_token.rs | 2 +- loadtest/Cargo.toml | 5 +- src/api/auth.rs | 68 +++--- src/api/error.rs | 2 +- src/api/mod.rs | 7 +- src/api/v3/group/mod.rs | 283 ++++++++++++++++-------- src/api/v3/group/types.rs | 18 +- src/api/v3/mod.rs | 7 +- src/api/v3/user/mod.rs | 272 +++++++++++++++-------- src/api/v3/user/types.rs | 39 +++- src/bin/keystone.rs | 30 +-- src/identity/backends/sql.rs | 47 ++-- src/identity/backends/sql/group.rs | 13 +- src/identity/backends/sql/local_user.rs | 13 +- src/identity/backends/sql/password.rs | 2 +- src/identity/backends/sql/user.rs | 2 +- src/identity/mod.rs | 272 ++++++++--------------- src/identity/types/group.rs | 6 +- src/identity/types/user.rs | 2 +- src/keystone.rs | 23 +- src/lib.rs | 2 +- src/provider.rs | 72 ++---- src/tests.rs | 15 ++ src/tests/api.rs | 59 +++++ src/token/fernet.rs | 19 +- src/token/fernet_utils.rs | 6 +- src/token/mod.rs | 34 +-- 30 files changed, 841 insertions(+), 603 deletions(-) create mode 100644 src/tests.rs create mode 100644 src/tests/api.rs diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 8fffebe0..71ed1e5c 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -13,8 +13,6 @@ on: - 'Cargo.toml' - 'Cargo.lock' - '.github/workflows/linters.yml' - - 'keystone/**' - - 'fuzz/**' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} @@ -22,7 +20,7 @@ concurrency: env: CARGO_TERM_COLOR: always - rust_min: 1.76.0 + rust_min: 1.85.0 jobs: rustfmt: diff --git a/Cargo.lock b/Cargo.lock index 5d0723d7..ace92415 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -496,9 +496,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" dependencies = [ "jobserver", "libc", @@ -968,6 +968,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dyn-clone" version = "1.0.18" @@ -1116,6 +1122,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "funty" version = "2.0.0" @@ -1635,9 +1647,9 @@ dependencies = [ [[package]] name = "inout" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" dependencies = [ "generic-array", ] @@ -1765,9 +1777,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.25" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "matchit" @@ -1836,6 +1848,44 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.98", +] + +[[package]] +name = "mockall_double" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1ca96e5ac35256ae3e13536edd39b172b88f41615e1d7b653c8ad24524113e8" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.98", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1987,6 +2037,8 @@ dependencies = [ "eyre", "fernet", "http-body-util", + "mockall", + "mockall_double", "regex", "rmp", "sea-orm", @@ -2244,6 +2296,32 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" @@ -3194,6 +3272,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" @@ -3672,9 +3756,9 @@ checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" [[package]] name = "uuid" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" +checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1" dependencies = [ "getrandom 0.3.1", "serde", @@ -4192,27 +4276,27 @@ dependencies = [ [[package]] name = "zstd" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "7.2.1" +version = "7.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +checksum = "f3051792fbdc2e1e143244dc28c60f73d8470e93f3f9cbd0ead44da5ed802722" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.13+zstd.1.5.6" +version = "2.0.14+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index c9538b49..6c607dea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "openstack_keystone" version = "0.1.0" -edition = "2021" +edition = "2024" license = "Apache-2.0" authors = ["Artem Goncharov (gtema)"] -rust-version = "1.83" # MSRV +rust-version = "1.85" # MSRV repository = "https://github.com/gtema/keystone" [[bin]] @@ -29,6 +29,7 @@ derive_builder = { version = "^0.20" } dyn-clone = { version = "^1.0" } eyre = { version = "^0.6" } fernet = { version = "^0.2" } +mockall_double = { version = "^0.3" } regex = { version = "^1.11"} rmp = { version = "^0.8" } sea-orm = { version = "^1.1", features = ["sqlx-mysql", "sqlx-postgres", "runtime-tokio"] } @@ -44,11 +45,12 @@ tracing-subscriber = { version = "^0.3" } utoipa = { version = "^5.3", features = ["axum_extras", "chrono"] } utoipa-axum = { version = "^0.2" } utoipa-swagger-ui = { version = "^9.0", features = ["axum", "vendored"], default-features = false } -uuid = { version = "^1.13", features = ["v4"] } +uuid = { version = "^1.14", features = ["v4"] } [dev-dependencies] criterion = { version = "^0.5", features = ["async_tokio"] } http-body-util = "^0.1" +mockall = { version = "^0.13" } sea-orm = { version = "*", features = ["mock"]} tempfile = { version = "^3.17" } diff --git a/benches/fernet_token.rs b/benches/fernet_token.rs index 0a75e07d..795bd129 100644 --- a/benches/fernet_token.rs +++ b/benches/fernet_token.rs @@ -1,4 +1,4 @@ -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use std::fs::File; use std::io::Write; use std::path::PathBuf; diff --git a/loadtest/Cargo.toml b/loadtest/Cargo.toml index 0a0c5cc3..c3326ab5 100644 --- a/loadtest/Cargo.toml +++ b/loadtest/Cargo.toml @@ -1,8 +1,9 @@ [package] name = "load_test" version = "0.1.0" +rust-version = "1.85" # MSRV edition = "2024" [dependencies] -goose = "0.17.2" -tokio = "1.43.0" +goose = { version = "^0.17" } +tokio = { version = "^1.43" } diff --git a/src/api/auth.rs b/src/api/auth.rs index d71b985a..95f7d471 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -13,51 +13,45 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ - extract::{Request, State}, - http::StatusCode, - middleware::Next, - response::Response, + extract::{FromRef, FromRequestParts}, + http::{StatusCode, request::Parts}, }; use std::sync::Arc; use crate::keystone::ServiceState; -use crate::provider::Provider; use crate::token::{Token, TokenApi}; #[derive(Debug, Clone)] -pub struct Auth { - pub token: Token, -} +pub struct Auth(pub Token); -pub async fn auth

( - State(state): State>>, - mut req: Request, - next: Next, -) -> Result +impl FromRequestParts for Auth where - P: Provider, + ServiceState: FromRef, + S: Send + Sync, { - let auth_header = req - .headers() - .get("X-Auth-Token") - .and_then(|header| header.to_str().ok()); - - let auth_header = if let Some(auth_header) = auth_header { - auth_header - } else { - return Err(StatusCode::UNAUTHORIZED); - }; - - // insert the current user into a request extension so the handler can - // extract it - state - .provider - .get_token_provider() - .validate_token(auth_header.to_string(), None) - .await - .map(|token| { - req.extensions_mut().insert(Auth { token }); - }) - .map_err(|_| StatusCode::UNAUTHORIZED)?; - Ok(next.run(req).await) + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let auth_header = parts + .headers + .get("X-Auth-Token") + .and_then(|header| header.to_str().ok()); + + let auth_header = if let Some(auth_header) = auth_header { + auth_header + } else { + return Err((StatusCode::UNAUTHORIZED, "not authorized")); + }; + + let state = Arc::from_ref(state); + + Ok(Self( + state + .provider + .get_token_provider() + .validate_token(auth_header.to_string(), None) + .await + .map_err(|_| (StatusCode::UNAUTHORIZED, "not authorized"))?, + )) + } } diff --git a/src/api/error.rs b/src/api/error.rs index 454f435a..cddd073c 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use serde_json::json; use thiserror::Error; diff --git a/src/api/mod.rs b/src/api/mod.rs index 00f470cd..d0d0f90e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -12,12 +12,10 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::sync::Arc; use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; use crate::keystone::ServiceState; -use crate::provider::Provider; pub mod auth; pub mod error; @@ -27,9 +25,6 @@ pub mod v3; #[openapi(info(version = "3.14.0"))] pub struct ApiDoc; -pub fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub fn openapi_router() -> OpenApiRouter { OpenApiRouter::new().nest("/v3", v3::openapi_router()) } diff --git a/src/api/v3/group/mod.rs b/src/api/v3/group/mod.rs index befcbd83..132a0183 100644 --- a/src/api/v3/group/mod.rs +++ b/src/api/v3/group/mod.rs @@ -13,26 +13,22 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, debug_handler, extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, - Json, }; -use std::sync::Arc; use utoipa_axum::{router::OpenApiRouter, routes}; +use crate::api::auth::Auth; use crate::api::error::KeystoneApiError; use crate::identity::IdentityApi; use crate::keystone::ServiceState; -use crate::provider::Provider; use types::{Group, GroupCreateRequest, GroupList, GroupListParameters, GroupResponse}; mod types; -pub(super) fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .routes(routes!(list, create)) .routes(routes!(show, remove)) @@ -51,13 +47,11 @@ where tag="groups" )] #[tracing::instrument(name = "api::group_list", level = "debug", skip(state))] -async fn list

( +async fn list( + Auth(user_auth): Auth, Query(query): Query, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { let groups: Vec = state .provider .get_identity_provider() @@ -83,17 +77,15 @@ where tag="groups" )] #[tracing::instrument(name = "api::group_get", level = "debug", skip(state))] -async fn show

( +async fn show( + Auth(user_auth): Auth, Path(group_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .get_group(&state.db, &group_id) + .get_group(&state.db, group_id.clone()) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -114,13 +106,12 @@ where tag="groups" )] #[tracing::instrument(name = "api::create_group", level = "debug", skip(state))] -async fn create

( - State(state): State>>, +#[debug_handler] +async fn create( + Auth(user_auth): Auth, + State(state): State, Json(req): Json, -) -> Result -where - P: Provider, -{ +) -> Result { let res = state .provider .get_identity_provider() @@ -143,17 +134,15 @@ where tag="groups" )] #[tracing::instrument(name = "api::group_delete", level = "debug", skip(state))] -async fn remove

( +async fn remove( + Auth(user_auth): Auth, Path(group_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .delete_group(&state.db, &group_id) + .delete_group(&state.db, group_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -163,68 +152,164 @@ where mod tests { use axum::{ body::Body, - http::{self, Request, StatusCode}, + http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` use sea_orm::DatabaseConnection; - use std::sync::Arc; + use serde_json::json; + use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; use super::openapi_router; - use crate::api::v3::group::types::*; - use crate::config::Config; - use crate::identity::IdentityApi; - use crate::keystone::ServiceState; - use crate::provider::{FakeProviderApi, Provider}; + use crate::api::v3::group::types::{ + Group as ApiGroup, GroupCreate as ApiGroupCreate, GroupCreateRequest, GroupList, + GroupResponse, + }; + use crate::identity::{ + MockIdentityProvider, + error::IdentityProviderError, + types::{Group, GroupCreate, GroupListParameters}, + }; + + use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed}; #[tokio::test] async fn test_list() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); - let mut api = openapi_router().with_state(state); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_groups() + .withf(|_: &DatabaseConnection, _: &GroupListParameters| true) + .returning(|_, _| { + Ok(vec![Group { + id: "1".into(), + name: "2".into(), + ..Default::default() + }]) + }); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); let response = api .as_service() - .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .oneshot( + Request::builder() + .uri("/") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _res: GroupList = serde_json::from_slice(&body).unwrap(); + let res: GroupList = serde_json::from_slice(&body).unwrap(); + assert_eq!( + vec![ApiGroup { + id: "1".into(), + name: "2".into(), + // for some reason when deserializing missing value appears still as an empty + // object + extra: Some(json!({})), + ..Default::default() + }], + res.groups + ); } #[tokio::test] - async fn test_get() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + async fn test_list_qp() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_groups() + .withf(|_: &DatabaseConnection, qp: &GroupListParameters| { + GroupListParameters { + domain_id: Some("domain".into()), + name: Some("name".into()), + } == *qp + }) + .returning(|_, _| Ok(Vec::new())); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) - .with_state(state.clone()); + .with_state(state); - let group = crate::identity::types::GroupCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; + let response = api + .as_service() + .oneshot( + Request::builder() + .uri("/?domain_id=domain&name=name") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let _res: GroupList = serde_json::from_slice(&body).unwrap(); + } - let created_group = state - .provider - .get_identity_provider() - .create_group(&DatabaseConnection::Disconnected, group) + #[tokio::test] + async fn test_list_unauth() { + let state = get_mocked_state_unauthed(); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn test_get() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Ok(None)); + + identity_mock + .expect_get_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| { + Ok(Some(Group { + id: "bar".into(), + ..Default::default() + })) + }); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state.clone()); + let response = api .as_service() - .oneshot(Request::builder().uri("/foo").body(Body::empty()).unwrap()) + .oneshot( + Request::builder() + .uri("/foo") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) .await .unwrap(); @@ -234,7 +319,8 @@ mod tests { .as_service() .oneshot( Request::builder() - .uri(format!("/{}", created_group.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -244,22 +330,42 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _user: GroupResponse = serde_json::from_slice(&body).unwrap(); + let res: GroupResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!( + ApiGroup { + id: "bar".into(), + extra: Some(json!({})), + ..Default::default() + }, + res.group, + ); } #[tokio::test] async fn test_create() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_create_group() + .withf(|_: &DatabaseConnection, req: &GroupCreate| { + req.domain_id == "domain" && req.name == "name" + }) + .returning(|_, req| { + Ok(Group { + id: "bar".into(), + domain_id: req.domain_id, + name: req.name, + ..Default::default() + }) + }); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); let req = GroupCreateRequest { - group: GroupCreate { + group: ApiGroupCreate { domain_id: "domain".into(), name: "name".into(), ..Default::default() @@ -271,8 +377,9 @@ mod tests { .oneshot( Request::builder() .method("POST") - .header(http::header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_TYPE, "application/json") .uri("/") + .header("x-auth-token", "foo") .body(Body::from(serde_json::to_string(&req).unwrap())) .unwrap(), ) @@ -284,38 +391,35 @@ mod tests { let body = response.into_body().collect().await.unwrap().to_bytes(); let res: GroupResponse = serde_json::from_slice(&body).unwrap(); assert_eq!(res.group.name, req.group.name); + assert_eq!(res.group.domain_id, req.group.domain_id); } #[tokio::test] async fn test_delete() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_delete_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Err(IdentityProviderError::GroupNotFound("foo".into()))); + + identity_mock + .expect_delete_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| Ok(())); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); - let group = crate::identity::types::GroupCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; - - let created_group = state - .provider - .get_identity_provider() - .create_group(&DatabaseConnection::Disconnected, group) - .await - .unwrap(); - let response = api .as_service() .oneshot( Request::builder() .method("DELETE") .uri("/foo") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -329,7 +433,8 @@ mod tests { .oneshot( Request::builder() .method("DELETE") - .uri(format!("/{}", created_group.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) diff --git a/src/api/v3/group/types.rs b/src/api/v3/group/types.rs index 3faa8c7a..391ddb1a 100644 --- a/src/api/v3/group/types.rs +++ b/src/api/v3/group/types.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -23,7 +23,7 @@ use utoipa::{IntoParams, ToSchema}; use crate::identity::types; -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct Group { /// Group ID pub id: String, @@ -33,17 +33,17 @@ pub struct Group { pub name: String, /// Group description pub description: Option, - #[serde(flatten)] + #[serde(flatten, skip_serializing_if = "Option::is_none")] pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupResponse { /// group object pub group: Group, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupCreate { /// Group domain ID pub domain_id: String, @@ -51,11 +51,11 @@ pub struct GroupCreate { pub name: String, /// Group description pub description: Option, - #[serde(flatten)] + #[serde(default, flatten, skip_serializing_if = "Option::is_none")] pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupCreateRequest { /// Group object pub group: GroupCreate, @@ -77,7 +77,7 @@ impl From for types::GroupCreate { fn from(value: GroupCreateRequest) -> Self { let group = value.group; Self { - id: String::new(), + id: None, name: group.name, domain_id: group.domain_id, extra: group.extra, @@ -105,7 +105,7 @@ impl IntoResponse for types::Group { } /// Groups -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupList { /// Collection of group objects pub groups: Vec, diff --git a/src/api/v3/mod.rs b/src/api/v3/mod.rs index 31777f02..5c4afc38 100644 --- a/src/api/v3/mod.rs +++ b/src/api/v3/mod.rs @@ -12,19 +12,14 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::sync::Arc; use utoipa_axum::router::OpenApiRouter; use crate::keystone::ServiceState; -use crate::provider::Provider; pub mod group; pub mod user; -pub(super) fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .nest("/users", user::openapi_router()) .nest("/groups", group::openapi_router()) diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index dedfdca6..4e9423b3 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -13,27 +13,22 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ - extract::{Extension, Path, Query, State}, + Json, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, - Json, }; -use std::sync::Arc; use utoipa_axum::{router::OpenApiRouter, routes}; use crate::api::auth::Auth; use crate::api::error::KeystoneApiError; use crate::identity::IdentityApi; use crate::keystone::ServiceState; -use crate::provider::Provider; use types::{User, UserCreateRequest, UserList, UserListParameters, UserResponse}; mod types; -pub(super) fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .routes(routes!(list, create)) .routes(routes!(show, remove)) @@ -52,14 +47,11 @@ where tag="users" )] #[tracing::instrument(name = "api::user_list", level = "debug", skip(state))] -async fn list

( - Extension(current_user): Extension, +async fn list( + Auth(user_auth): Auth, Query(query): Query, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { let users: Vec = state .provider .get_identity_provider() @@ -84,17 +76,15 @@ where tag="users" )] #[tracing::instrument(name = "api::user_get", level = "debug", skip(state))] -async fn show

( +async fn show( + Auth(user_auth): Auth, Path(user_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state.db, user_id.clone()) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -115,14 +105,12 @@ where tag="users" )] #[tracing::instrument(name = "api::create_user", level = "debug", skip(state))] -async fn create

( +async fn create( + Auth(user_auth): Auth, Query(query): Query, - State(state): State>>, + State(state): State, Json(req): Json, -) -> Result -where - P: Provider, -{ +) -> Result { let user = state .provider .get_identity_provider() @@ -145,17 +133,15 @@ where tag="users" )] #[tracing::instrument(name = "api::user_delete", level = "debug", skip(state))] -async fn remove

( +async fn remove( + Auth(user_auth): Auth, Path(user_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .delete_user(&state.db, &user_id) + .delete_user(&state.db, user_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -166,31 +152,44 @@ mod tests { use axum::{ body::Body, http::{self, Request, StatusCode}, - middleware, }; use http_body_util::BodyExt; // for `collect` use sea_orm::DatabaseConnection; - use std::sync::Arc; + use serde_json::json; + use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; use super::openapi_router; - use crate::api::auth::auth; - use crate::api::v3::user::types::*; - use crate::config::Config; - use crate::identity::IdentityApi; - use crate::keystone::ServiceState; - use crate::provider::{FakeProviderApi, Provider}; + use crate::api::v3::user::types::{ + User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, UserResponse, + }; + use crate::identity::{ + MockIdentityProvider, + error::IdentityProviderError, + types::{User, UserCreate, UserListParameters}, + }; + + use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed}; #[tokio::test] async fn test_list() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_users() + .withf(|_: &DatabaseConnection, _: &UserListParameters| true) + .returning(|_, _| { + Ok(vec![User { + id: "1".into(), + name: "2".into(), + ..Default::default() + }]) + }); + + let state = get_mocked_state(identity_mock); + let mut api = openapi_router() .layer(TraceLayer::new_for_http()) - .route_layer(middleware::from_fn_with_state(state.clone(), auth)) .with_state(state); let response = api @@ -208,22 +207,98 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _users: UserList = serde_json::from_slice(&body).unwrap(); + let res: UserList = serde_json::from_slice(&body).unwrap(); + assert_eq!( + vec![ApiUser { + id: "1".into(), + name: "2".into(), + // object + extra: Some(json!({})), + ..Default::default() + }], + res.users + ); + } + + #[tokio::test] + async fn test_list_qp() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_users() + .withf(|_: &DatabaseConnection, qp: &UserListParameters| { + UserListParameters { + domain_id: Some("domain".into()), + name: Some("name".into()), + } == *qp + }) + .returning(|_, _| Ok(Vec::new())); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot( + Request::builder() + .uri("/?domain_id=domain&name=name") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let _res: UserList = serde_json::from_slice(&body).unwrap(); + } + + #[tokio::test] + async fn test_list_unauth() { + let state = get_mocked_state_unauthed(); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn test_create() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_create_user() + .withf(|_: &DatabaseConnection, req: &UserCreate| { + req.domain_id == "domain" && req.name == "name" + }) + .returning(|_, req| { + Ok(User { + id: "bar".into(), + domain_id: req.domain_id, + name: req.name, + ..Default::default() + }) + }); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); let user = UserCreateRequest { - user: UserCreate { + user: ApiUserCreate { domain_id: "domain".into(), name: "name".into(), ..Default::default() @@ -237,6 +312,7 @@ mod tests { .method("POST") .header(http::header::CONTENT_TYPE, "application/json") .uri("/") + .header("x-auth-token", "foo") .body(Body::from(serde_json::to_string(&user).unwrap())) .unwrap(), ) @@ -252,31 +328,37 @@ mod tests { #[tokio::test] async fn test_get() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Ok(None)); + + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| { + Ok(Some(User { + id: "bar".into(), + ..Default::default() + })) + }); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); - let user = crate::identity::types::UserCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; - - let created_user = state - .provider - .get_identity_provider() - .create_user(&DatabaseConnection::Disconnected, user) - .await - .unwrap(); - let response = api .as_service() - .oneshot(Request::builder().uri("/foo").body(Body::empty()).unwrap()) + .oneshot( + Request::builder() + .uri("/foo") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) .await .unwrap(); @@ -286,7 +368,8 @@ mod tests { .as_service() .oneshot( Request::builder() - .uri(format!("/{}", created_user.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -296,39 +379,43 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _user: UserResponse = serde_json::from_slice(&body).unwrap(); + let res: UserResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!( + ApiUser { + id: "bar".into(), + extra: Some(json!({})), + ..Default::default() + }, + res.user, + ); } #[tokio::test] async fn test_delete() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_delete_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Err(IdentityProviderError::UserNotFound("foo".into()))); + + identity_mock + .expect_delete_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| Ok(())); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); - let user = crate::identity::types::UserCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; - - let created_user = state - .provider - .get_identity_provider() - .create_user(&DatabaseConnection::Disconnected, user) - .await - .unwrap(); - let response = api .as_service() .oneshot( Request::builder() .method("DELETE") .uri("/foo") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -342,7 +429,8 @@ mod tests { .oneshot( Request::builder() .method("DELETE") - .uri(format!("/{}", created_user.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) diff --git a/src/api/v3/user/types.rs b/src/api/v3/user/types.rs index 29ef9111..ca153876 100644 --- a/src/api/v3/user/types.rs +++ b/src/api/v3/user/types.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -24,7 +24,7 @@ use utoipa::{IntoParams, ToSchema}; use crate::identity::types; -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct User { /// User ID pub id: String, @@ -41,25 +41,28 @@ pub struct User { /// default project, the default project is ignored at token creation. (Since v3.1) /// Additionally, if your default project is not valid, a token is issued without an explicit /// scope of authorization. + #[serde(skip_serializing_if = "Option::is_none")] pub default_project_id: Option, - #[serde(flatten)] + #[serde(flatten, skip_serializing_if = "Option::is_none")] pub extra: Option, /// The date and time when the password expires. The time zone is UTC. + #[serde(skip_serializing_if = "Option::is_none")] pub password_expires_at: Option>, /// The resource options for the user. Available resource options are /// ignore_change_password_upon_first_use, ignore_password_expiry, /// ignore_lockout_failure_attempts, lock_password, multi_factor_auth_enabled, and /// multi_factor_auth_rules ignore_user_inactivity. + #[serde(skip_serializing_if = "Option::is_none")] pub options: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserResponse { /// User object pub user: User, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserCreate { /// User domain ID pub domain_id: String, @@ -87,7 +90,7 @@ pub struct UserCreate { pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserUpdateRequest { /// The user name. Must be unique within the owning domain. pub name: Option, @@ -113,7 +116,7 @@ pub struct UserUpdateRequest { pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserOptions { #[serde(skip_serializing_if = "Option::is_none")] pub ignore_change_password_upon_first_use: Option, @@ -159,7 +162,7 @@ impl From for types::UserOptions { } } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserCreateRequest { /// User object pub user: UserCreate, @@ -167,6 +170,20 @@ pub struct UserCreateRequest { impl From for User { fn from(value: types::User) -> Self { + let opts: UserOptions = value.options.clone().into(); + // We only want to see user options if there is at least 1 option set + let opts = if opts.ignore_change_password_upon_first_use.is_some() + || opts.ignore_password_expiry.is_some() + || opts.ignore_lockout_failure_attempts.is_some() + || opts.lock_password.is_some() + || opts.ignore_user_inactivity.is_some() + || opts.multi_factor_auth_rules.is_some() + || opts.multi_factor_auth_enabled.is_some() + { + Some(opts) + } else { + None + }; Self { id: value.id, domain_id: value.domain_id, @@ -175,7 +192,7 @@ impl From for User { default_project_id: value.default_project_id, extra: value.extra, password_expires_at: value.password_expires_at, - options: Some(value.options.into()), + options: opts, } } } @@ -216,7 +233,7 @@ impl IntoResponse for types::User { } /// Users -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserList { /// Collection of user objects pub users: Vec, @@ -235,7 +252,7 @@ impl IntoResponse for UserList { } } -#[derive(Clone, Debug, Default, Deserialize, Serialize, IntoParams)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, IntoParams)] pub struct UserListParameters { /// Filter users by Domain ID pub domain_id: Option, diff --git a/src/bin/keystone.rs b/src/bin/keystone.rs index cf1978cb..889c6b85 100644 --- a/src/bin/keystone.rs +++ b/src/bin/keystone.rs @@ -12,10 +12,7 @@ // // SPDX-License-Identifier: Apache-2.0 -use axum::{ - http::{self, header, HeaderName, Request}, - middleware::{self}, -}; +use axum::http::{self, HeaderName, Request, header}; use clap::Parser; use color_eyre::eyre::{Report, Result}; use sea_orm::ConnectOptions; @@ -23,15 +20,14 @@ use sea_orm::Database; use std::io; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; -use tokio::net::TcpListener; -use tokio::signal; +use tokio::{net::TcpListener, signal}; use tower::ServiceBuilder; use tower_http::{ + LatencyUnit, ServiceBuilderExt, request_id::{MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer}, trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}, - LatencyUnit, ServiceBuilderExt, }; -use tracing::{info_span, Level}; +use tracing::{Level, info_span}; use tracing_subscriber::{filter::LevelFilter, prelude::*}; use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; @@ -39,11 +35,10 @@ use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; use openstack_keystone::api; -use openstack_keystone::api::auth::auth; use openstack_keystone::config::Config; -use openstack_keystone::keystone::ServiceState; +use openstack_keystone::keystone::{Service, ServiceState}; use openstack_keystone::plugin_manager::PluginManager; -use openstack_keystone::provider::{Provider, ProviderApi}; +use openstack_keystone::provider::Provider; /// Simple program to greet a person #[derive(Parser, Debug)] @@ -101,12 +96,12 @@ async fn main() -> Result<(), Report> { let plugin_manager = PluginManager::default(); - let provider = ProviderApi::new(cfg.clone(), plugin_manager)?; + let provider = Provider::new(cfg.clone(), plugin_manager)?; - let shared_state = Arc::new(ServiceState::new(cfg, conn, provider).unwrap()); + let shared_state = Arc::new(Service::new(cfg, conn, provider).unwrap()); let (router, api) = OpenApiRouter::with_openapi(api::ApiDoc::openapi()) - .merge(api::openapi_router::()) + .merge(api::openapi_router()) .split_for_parts(); let x_request_id = HeaderName::from_static("x-openstack-request-id"); @@ -147,21 +142,18 @@ async fn main() -> Result<(), Report> { .layer(PropagateRequestIdLayer::new(x_request_id)); let app = router - // Router::new() - //.merge(api::router()) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", api)) - .route_layer(middleware::from_fn_with_state(shared_state.clone(), auth)) .layer(middleware) .with_state(shared_state.clone()); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080)); let listener = TcpListener::bind(&address).await?; Ok(axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(shutdown_signal(shared_state.clone())) + .with_graceful_shutdown(shutdown_signal(shared_state)) .await?) } -async fn shutdown_signal(state: Arc>) { +async fn shutdown_signal(state: ServiceState) { let ctrl_c = async { signal::ctrl_c() .await diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index 85672869..56000e69 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use async_trait::async_trait; +use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; -use sea_orm::DatabaseConnection; mod common; mod federated_user; @@ -33,9 +33,9 @@ use crate::db::entity::{ prelude::{FederatedUser, LocalUser, NonlocalUser, User as DbUser, UserOption}, user as db_user, user_option as db_user_option, }; +use crate::identity::IdentityProviderError; use crate::identity::backends::error::IdentityDatabaseError; use crate::identity::password_hashing; -use crate::identity::IdentityProviderError; #[derive(Clone, Debug, Default)] pub struct SqlBackend { @@ -231,26 +231,29 @@ pub async fn get_user( if let Some(user) = &user_entry { let user_opts: Vec = user.find_related(UserOption).all(db).await?; - let user_builder: UserBuilder = if let Some(local_user_with_passwords) = - local_user::load_local_user_with_passwords(db, &user_id).await? - { - common::get_local_user_builder( - conf, - user, - local_user_with_passwords.0, - Some(local_user_with_passwords.1), - user_opts, - ) - } else if let Some(nonlocal_user) = user.find_related(NonlocalUser).one(db).await? { - common::get_nonlocal_user_builder(user, nonlocal_user, user_opts) - } else { - let federated_user = user.find_related(FederatedUser).all(db).await?; - if !federated_user.is_empty() { - common::get_federated_user_builder(user, federated_user, user_opts) - } else { - return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?; - } - }; + let user_builder: UserBuilder = + match local_user::load_local_user_with_passwords(db, &user_id).await? { + Some(local_user_with_passwords) => common::get_local_user_builder( + conf, + user, + local_user_with_passwords.0, + Some(local_user_with_passwords.1), + user_opts, + ), + _ => match user.find_related(NonlocalUser).one(db).await? { + Some(nonlocal_user) => { + common::get_nonlocal_user_builder(user, nonlocal_user, user_opts) + } + _ => { + let federated_user = user.find_related(FederatedUser).all(db).await?; + if !federated_user.is_empty() { + common::get_federated_user_builder(user, federated_user, user_opts) + } else { + return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?; + } + } + }, + }; return Ok(Some(user_builder.build()?)); } diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs index 5c9b8c08..f5bde82c 100644 --- a/src/identity/backends/sql/group.rs +++ b/src/identity/backends/sql/group.rs @@ -12,15 +12,16 @@ // // SPDX-License-Identifier: Apache-2.0 +use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; -use sea_orm::DatabaseConnection; use serde_json::Value; +use serde_json::json; use crate::db::entity::{group, prelude::Group as DbGroup}; +use crate::identity::Config; use crate::identity::backends::sql::IdentityDatabaseError; use crate::identity::types::{Group, GroupCreate, GroupListParameters}; -use crate::identity::Config; pub async fn list( _conf: &Config, @@ -60,7 +61,7 @@ pub async fn create( group: GroupCreate, ) -> Result { let entry = group::ActiveModel { - id: Set(group.id.clone()), + id: Set(group.id.clone().unwrap_or_default()), domain_id: Set(group.domain_id.clone()), name: Set(group.name.clone()), description: Set(group.description.clone()), @@ -94,7 +95,7 @@ impl From for Group { domain_id: value.domain_id.clone(), extra: value .extra - .map(|x| serde_json::from_str::(&x).unwrap_or(Value::Null)), + .map(|x| serde_json::from_str::(&x).unwrap_or(json!(true))), } } } @@ -107,8 +108,8 @@ mod tests { use serde_json::json; use crate::db::entity::group; - use crate::identity::types::group::GroupListParametersBuilder; use crate::identity::Config; + use crate::identity::types::group::GroupListParametersBuilder; use super::*; @@ -236,7 +237,7 @@ mod tests { let config = Config::default(); let req = GroupCreate { - id: "1".into(), + id: Some("1".into()), domain_id: "foo_domain".into(), name: "group".into(), description: Some("fake".into()), diff --git a/src/identity/backends/sql/local_user.rs b/src/identity/backends/sql/local_user.rs index d63fcc49..b1f3cbc6 100644 --- a/src/identity/backends/sql/local_user.rs +++ b/src/identity/backends/sql/local_user.rs @@ -12,9 +12,9 @@ // // SPDX-License-Identifier: Apache-2.0 +use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; -use sea_orm::DatabaseConnection; use std::collections::HashMap; use crate::config::Config; @@ -30,7 +30,10 @@ pub async fn load_local_user_with_passwords>( db: &DatabaseConnection, user_id: S, ) -> Result< - Option<(local_user::Model, impl IntoIterator)>, + Option<( + local_user::Model, + impl IntoIterator + use, + )>, IdentityDatabaseError, > { let results: Vec<(local_user::Model, Vec)> = LocalUser::find() @@ -48,11 +51,7 @@ pub async fn load_local_user_with_passwords>( pub async fn load_local_users_passwords>>( db: &DatabaseConnection, user_ids: L, -) -> Result< - //impl IntoIterator>>, - Vec>>, - IdentityDatabaseError, -> { +) -> Result>>, IdentityDatabaseError> { let ids: Vec> = user_ids.into_iter().collect(); // Collect local user IDs that we need to query let keys: Vec = ids.iter().filter_map(Option::as_ref).copied().collect(); diff --git a/src/identity/backends/sql/password.rs b/src/identity/backends/sql/password.rs index 003108e0..f1983023 100644 --- a/src/identity/backends/sql/password.rs +++ b/src/identity/backends/sql/password.rs @@ -13,8 +13,8 @@ // SPDX-License-Identifier: Apache-2.0 use chrono::{DateTime, Local, Utc}; -use sea_orm::entity::*; use sea_orm::DatabaseConnection; +use sea_orm::entity::*; use crate::db::entity::password; use crate::identity::backends::error::IdentityDatabaseError; diff --git a/src/identity/backends/sql/user.rs b/src/identity/backends/sql/user.rs index 17521a19..404e3a00 100644 --- a/src/identity/backends/sql/user.rs +++ b/src/identity/backends/sql/user.rs @@ -13,8 +13,8 @@ // SPDX-License-Identifier: Apache-2.0 use chrono::Local; -use sea_orm::entity::*; use sea_orm::DatabaseConnection; +use sea_orm::entity::*; use crate::config::Config; use crate::db::entity::{prelude::User as DbUser, user}; diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 64954d5e..43f7e2eb 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -13,12 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use async_trait::async_trait; -use sea_orm::DatabaseConnection; #[cfg(test)] -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; +use mockall::mock; +use sea_orm::DatabaseConnection; use uuid::Uuid; pub mod backends; @@ -29,9 +26,10 @@ pub(crate) mod types; use crate::config::Config; use crate::identity::backends::sql::SqlBackend; use crate::identity::error::IdentityProviderError; -use crate::identity::types::IdentityBackend; -use crate::identity::types::{Group, GroupCreate, GroupListParameters}; -use crate::identity::types::{User, UserCreate, UserListParameters}; +use crate::identity::types::{ + IdentityBackend, + {Group, GroupCreate, GroupListParameters, User, UserCreate, UserListParameters}, +}; use crate::plugin_manager::PluginManager; #[derive(Clone, Debug)] @@ -47,10 +45,10 @@ pub trait IdentityApi: Send + Sync + Clone { params: &UserListParameters, ) -> Result, IdentityProviderError>; - async fn get_user + std::fmt::Debug + Send + Sync>( + async fn get_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result, IdentityProviderError>; async fn create_user( @@ -59,10 +57,10 @@ pub trait IdentityApi: Send + Sync + Clone { user: UserCreate, ) -> Result; - async fn delete_user + std::fmt::Debug + Send + Sync>( + async fn delete_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result<(), IdentityProviderError>; async fn list_groups( @@ -71,10 +69,10 @@ pub trait IdentityApi: Send + Sync + Clone { params: &GroupListParameters, ) -> Result, IdentityProviderError>; - async fn get_group + std::fmt::Debug + Send + Sync>( + async fn get_group( &self, db: &DatabaseConnection, - group_id: S, + group_id: String, ) -> Result, IdentityProviderError>; async fn create_group( @@ -83,13 +81,76 @@ pub trait IdentityApi: Send + Sync + Clone { group: GroupCreate, ) -> Result; - async fn delete_group + std::fmt::Debug + Send + Sync>( + async fn delete_group( &self, db: &DatabaseConnection, - group_id: S, + group_id: String, ) -> Result<(), IdentityProviderError>; } +#[cfg(test)] +mock! { + pub IdentityProvider { + pub fn new(cfg: &Config, plugin_manager: &PluginManager) -> Result; + } + + #[async_trait] + impl IdentityApi for IdentityProvider { + async fn list_users( + &self, + db: &DatabaseConnection, + params: &UserListParameters, + ) -> Result, IdentityProviderError>; + + async fn get_user( + &self, + db: &DatabaseConnection, + user_id: String, + ) -> Result, IdentityProviderError>; + + async fn create_user( + &self, + db: &DatabaseConnection, + user: UserCreate, + ) -> Result; + + async fn delete_user( + &self, + db: &DatabaseConnection, + user_id: String, + ) -> Result<(), IdentityProviderError>; + + async fn list_groups( + &self, + db: &DatabaseConnection, + params: &GroupListParameters, + ) -> Result, IdentityProviderError>; + + async fn get_group( + &self, + db: &DatabaseConnection, + group_id: String, + ) -> Result, IdentityProviderError>; + + async fn create_group( + &self, + db: &DatabaseConnection, + group: GroupCreate, + ) -> Result; + + async fn delete_group( + &self, + db: &DatabaseConnection, + group_id: String, + ) -> Result<(), IdentityProviderError>; + } + + impl Clone for IdentityProvider { + fn clone(&self) -> Self; + } + +} + impl IdentityProvider { pub fn new( config: &Config, @@ -128,14 +189,12 @@ impl IdentityApi for IdentityProvider { /// Get single user #[tracing::instrument(level = "info", skip(self, db))] - async fn get_user + std::fmt::Debug + Send + Sync>( + async fn get_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result, IdentityProviderError> { - self.backend_driver - .get_user(db, user_id.as_ref().to_string()) - .await + self.backend_driver.get_user(db, user_id).await } /// Create user @@ -155,14 +214,12 @@ impl IdentityApi for IdentityProvider { /// Delete user #[tracing::instrument(level = "info", skip(self, db))] - async fn delete_user + std::fmt::Debug + Send + Sync>( + async fn delete_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result<(), IdentityProviderError> { - self.backend_driver - .delete_user(db, user_id.as_ref().to_string()) - .await + self.backend_driver.delete_user(db, user_id).await } /// List groups @@ -177,14 +234,12 @@ impl IdentityApi for IdentityProvider { /// Get single group #[tracing::instrument(level = "info", skip(self, db))] - async fn get_group + std::fmt::Debug + Send + Sync>( + async fn get_group( &self, db: &DatabaseConnection, - group_id: S, + group_id: String, ) -> Result, IdentityProviderError> { - self.backend_driver - .get_group(db, group_id.as_ref().to_string()) - .await + self.backend_driver.get_group(db, group_id).await } /// Create group @@ -195,164 +250,17 @@ impl IdentityApi for IdentityProvider { group: GroupCreate, ) -> Result { let mut res = group; - res.id = Uuid::new_v4().into(); + res.id = Some(Uuid::new_v4().into()); self.backend_driver.create_group(db, res).await } /// Delete group #[tracing::instrument(level = "info", skip(self, db))] - async fn delete_group + std::fmt::Debug + Send + Sync>( + async fn delete_group( &self, db: &DatabaseConnection, - group_id: S, - ) -> Result<(), IdentityProviderError> { - self.backend_driver - .delete_group(db, group_id.as_ref().to_string()) - .await - } -} - -#[cfg(test)] -#[derive(Clone, Debug, Default)] -pub(crate) struct FakeIdentityProvider { - users: Arc>>, - groups: Arc>>, -} - -#[cfg(test)] -impl From for User { - fn from(value: UserCreate) -> Self { - Self { - id: value.id, - name: value.name, - domain_id: value.domain_id, - ..Default::default() - } - } -} -#[cfg(test)] -impl From for Group { - fn from(value: GroupCreate) -> Self { - Self { - id: value.id, - name: value.name, - domain_id: value.domain_id, - ..Default::default() - } - } -} - -#[cfg(test)] -#[async_trait] -impl IdentityApi for FakeIdentityProvider { - /// List users - async fn list_users( - &self, - _db: &DatabaseConnection, - _params: &types::UserListParameters, - ) -> Result, IdentityProviderError> { - Ok(self - .users - .lock() - .unwrap() - .values() - .cloned() - .collect::>()) - } - - /// Get single user - async fn get_user + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - user_id: S, - ) -> Result, IdentityProviderError> { - Ok(self.users.lock().unwrap().get(user_id.as_ref()).cloned()) - } - - /// Create user - async fn create_user( - &self, - _db: &DatabaseConnection, - user: UserCreate, - ) -> Result { - let mut mod_user = user; - mod_user.id = Uuid::new_v4().into(); - let res = User::from(mod_user); - self.users - .lock() - .unwrap() - .insert(res.id.clone(), res.clone()); - Ok(res) - } - - /// Delete user - async fn delete_user + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - user_id: S, - ) -> Result<(), IdentityProviderError> { - Ok(self - .users - .lock() - .unwrap() - .remove(user_id.as_ref()) - .map(|_| ()) - .ok_or(IdentityProviderError::UserNotFound( - user_id.as_ref().to_string(), - ))?) - } - - async fn list_groups( - &self, - _db: &DatabaseConnection, - _params: &GroupListParameters, - ) -> Result, IdentityProviderError> { - Ok(self - .groups - .lock() - .unwrap() - .values() - .cloned() - .collect::>()) - } - - /// Get single user - async fn get_group + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - group_id: S, - ) -> Result, IdentityProviderError> { - Ok(self.groups.lock().unwrap().get(group_id.as_ref()).cloned()) - } - - /// Create group - async fn create_group( - &self, - _db: &DatabaseConnection, - group: GroupCreate, - ) -> Result { - let mut req = group; - req.id = Uuid::new_v4().into(); - let res = Group::from(req); - self.groups - .lock() - .unwrap() - .insert(res.id.clone(), res.clone()); - Ok(res) - } - /// - /// Delete group - async fn delete_group + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - id: S, + group_id: String, ) -> Result<(), IdentityProviderError> { - Ok(self - .groups - .lock() - .unwrap() - .remove(id.as_ref()) - .map(|_| ()) - .ok_or(IdentityProviderError::UserNotFound(id.as_ref().to_string()))?) + self.backend_driver.delete_group(db, group_id).await } } diff --git a/src/identity/types/group.rs b/src/identity/types/group.rs index befbad88..e4a2657f 100644 --- a/src/identity/types/group.rs +++ b/src/identity/types/group.rs @@ -32,7 +32,7 @@ pub struct Group { pub name: String, } -#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[builder(setter(strip_option, into))] pub struct GroupListParameters { /// Filter groups by the domain @@ -41,7 +41,7 @@ pub struct GroupListParameters { pub name: Option, } -#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize, PartialEq)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[builder(setter(strip_option, into))] pub struct GroupCreate { /// The description of the group. @@ -52,7 +52,7 @@ pub struct GroupCreate { #[builder(default)] pub extra: Option, /// The ID of the group. - pub id: String, + pub id: Option, /// The user name. Must be unique within the owning domain. pub name: String, } diff --git a/src/identity/types/user.rs b/src/identity/types/user.rs index c3d2bb06..72e2e4f3 100644 --- a/src/identity/types/user.rs +++ b/src/identity/types/user.rs @@ -140,7 +140,7 @@ pub struct FederationProtocol { pub unique_id: String, } -#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] pub struct UserListParameters { /// Filter users by the domain pub domain_id: Option, diff --git a/src/keystone.rs b/src/keystone.rs index a6d8e058..05ea6c5d 100644 --- a/src/keystone.rs +++ b/src/keystone.rs @@ -12,9 +12,9 @@ // // SPDX-License-Identifier: Apache-2.0 +use axum::extract::FromRef; use sea_orm::DatabaseConnection; -//use tokio::sync::RwLock; - +use std::sync::Arc; use tracing::info; use crate::config::Config; @@ -24,17 +24,22 @@ use crate::provider::Provider; // Placing ServiceState behind Arc is necessary to address DatabaseConnection not implementing // Clone //#[derive(Clone)] -pub struct ServiceState

{ +#[derive(FromRef)] +pub struct Service { pub config: Config, - pub provider: P, + pub provider: Provider, + #[from_ref(skip)] pub db: DatabaseConnection, } -impl

ServiceState

-where - P: Provider, -{ - pub fn new(cfg: Config, db: DatabaseConnection, provider: P) -> Result { +pub type ServiceState = Arc; + +impl Service { + pub fn new( + cfg: Config, + db: DatabaseConnection, + provider: Provider, + ) -> Result { Ok(Self { config: cfg.clone(), provider, diff --git a/src/lib.rs b/src/lib.rs index a9735f7c..af4629c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,4 +24,4 @@ pub mod resource; pub mod token; #[cfg(test)] -mod tests {} +mod tests; diff --git a/src/provider.rs b/src/provider.rs index e7ecacd7..ce49a6d5 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -11,28 +11,35 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 +use derive_builder::Builder; +use mockall_double::double; use crate::config::Config; use crate::error::KeystoneError; -use crate::identity::{IdentityApi, IdentityProvider}; +use crate::identity::IdentityApi; +#[double] +use crate::identity::IdentityProvider; use crate::plugin_manager::PluginManager; -use crate::token::{TokenApi, TokenProvider}; -#[cfg(test)] -use crate::{identity::FakeIdentityProvider, token::FakeTokenProvider}; - -pub trait Provider: Clone + Send + Sync { - fn get_identity_provider(&self) -> &impl IdentityApi; - fn get_token_provider(&self) -> &impl TokenApi; -} - -#[derive(Clone)] -pub struct ProviderApi { +use crate::token::TokenApi; +#[double] +use crate::token::TokenProvider; + +//pub trait Provider: Clone + Send + Sync { +// fn get_identity_provider(&self) -> &impl IdentityApi; +// fn get_token_provider(&self) -> &impl TokenApi; +//} + +#[derive(Builder, Clone)] +// It is necessary to use the owned pattern since otherwise builder invokes clone which immediately +// confuses mockall used in tests +#[builder(pattern = "owned")] +pub struct Provider { pub config: Config, identity: IdentityProvider, token: TokenProvider, } -impl ProviderApi { +impl Provider { pub fn new(cfg: Config, plugin_manager: PluginManager) -> Result { let identity_provider = IdentityProvider::new(&cfg, &plugin_manager)?; let token_provider = TokenProvider::new(&cfg)?; @@ -43,47 +50,12 @@ impl ProviderApi { token: token_provider, }) } -} - -impl Provider for ProviderApi { - fn get_identity_provider(&self) -> &impl IdentityApi { - &self.identity - } - - fn get_token_provider(&self) -> &impl TokenApi { - &self.token - } -} - -#[cfg(test)] -#[derive(Clone)] -pub struct FakeProviderApi { - pub config: Config, - identity: FakeIdentityProvider, - token: FakeTokenProvider, -} - -#[cfg(test)] -impl FakeProviderApi { - pub fn new(cfg: Config) -> Result { - let identity_provider = FakeIdentityProvider::default(); - let token_provider = FakeTokenProvider::default(); - - Ok(Self { - config: cfg, - identity: identity_provider, - token: token_provider, - }) - } -} -#[cfg(test)] -impl Provider for FakeProviderApi { - fn get_identity_provider(&self) -> &impl IdentityApi { + pub fn get_identity_provider(&self) -> &impl IdentityApi { &self.identity } - fn get_token_provider(&self) -> &impl TokenApi { + pub fn get_token_provider(&self) -> &impl TokenApi { &self.token } } diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 00000000..bafcca80 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,15 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod api; diff --git a/src/tests/api.rs b/src/tests/api.rs new file mode 100644 index 00000000..b2db2c3a --- /dev/null +++ b/src/tests/api.rs @@ -0,0 +1,59 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use std::sync::Arc; + +use crate::config::Config; +use crate::identity::MockIdentityProvider; +use crate::keystone::{Service, ServiceState}; +use crate::provider::ProviderBuilder; +use crate::token::{MockTokenProvider, Token, TokenProviderError}; + +pub(crate) fn get_mocked_state_unauthed() -> ServiceState { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let identity_mock = MockIdentityProvider::default(); + let mut token_mock = MockTokenProvider::default(); + token_mock + .expect_validate_token() + .returning(|_, _| Err(TokenProviderError::InvalidToken)); + + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .token(token_mock) + .build() + .unwrap(); + + Arc::new(Service::new(config, db, provider).unwrap()) +} + +pub(crate) fn get_mocked_state(identity_mock: MockIdentityProvider) -> ServiceState { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let mut token_mock = MockTokenProvider::default(); + token_mock + .expect_validate_token() + .returning(|_, _| Ok(Token::default())); + + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .token(token_mock) + .build() + .unwrap(); + + Arc::new(Service::new(config, db, provider).unwrap()) +} diff --git a/src/token/fernet.rs b/src/token/fernet.rs index f4540bf6..2b8dd2ba 100644 --- a/src/token/fernet.rs +++ b/src/token/fernet.rs @@ -12,10 +12,10 @@ // // SPDX-License-Identifier: Apache-2.0 -use base64::{engine::general_purpose::URL_SAFE, Engine as _}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE}; use chrono::{DateTime, Utc}; use fernet::{Fernet, MultiFernet}; -use rmp::{decode::*, Marker}; +use rmp::{Marker, decode::*}; use std::collections::BTreeMap; use std::fmt; use std::io; @@ -24,9 +24,9 @@ use uuid::Uuid; use crate::config::Config; use crate::token::{ + TokenProviderError, fernet_utils::FernetUtils, types::{Token, TokenBackend}, - TokenProviderError, }; #[derive(Default, Clone)] @@ -118,7 +118,7 @@ fn read_time(rd: &mut &[u8]) -> Result, TokenProviderError> { fn decode_auth_methods( value: usize, auth_map: &BTreeMap, -) -> Result, TokenProviderError> { +) -> Result + use<>, TokenProviderError> { let mut results: Vec = Vec::new(); let mut auth: usize = value; for (idx, name) in auth_map.iter() { @@ -141,7 +141,9 @@ fn decode_auth_methods( } /// Decode array of audit ids from the payload -fn read_audit_ids(rd: &mut &[u8]) -> Result, TokenProviderError> { +fn read_audit_ids( + rd: &mut &[u8], +) -> Result + use<>, TokenProviderError> { if let Marker::FixArray(len) = read_marker(rd).map_err(ValueReadError::from)? { let mut result: Vec = Vec::new(); for _ in 0..len { @@ -198,10 +200,9 @@ impl FernetTokenProvider { pub fn decrypt(&self, credential: String) -> Result { // TODO: Implement fernet keys change watching. Keystone loads them from FS on every // request and in the best case it costs 15µs. - let payload = if let Some(fernet) = &self.fernet { - fernet.decrypt(credential.as_ref())? - } else { - self.get_fernet()?.decrypt(credential.as_ref())? + let payload = match &self.fernet { + Some(fernet) => fernet.decrypt(credential.as_ref())?, + _ => self.get_fernet()?.decrypt(credential.as_ref())?, }; self.parse(&mut payload.as_slice()) } diff --git a/src/token/fernet_utils.rs b/src/token/fernet_utils.rs index 2423337b..19e249f6 100644 --- a/src/token/fernet_utils.rs +++ b/src/token/fernet_utils.rs @@ -30,7 +30,9 @@ impl FernetUtils { Ok(self.key_repository.exists()) } - pub fn load_keys(&self) -> Result, TokenProviderError> { + pub fn load_keys( + &self, + ) -> Result + use<>, TokenProviderError> { let mut keys: BTreeMap = BTreeMap::new(); if self.validate_key_repository()? { for entry in fs::read_dir(&self.key_repository)? { @@ -49,7 +51,7 @@ impl FernetUtils { } pub async fn load_keys_async( &self, - ) -> Result, TokenProviderError> { + ) -> Result + use<>, TokenProviderError> { let mut keys: BTreeMap = BTreeMap::new(); if self.validate_key_repository()? { let mut entries = fs_async::read_dir(&self.key_repository).await?; diff --git a/src/token/mod.rs b/src/token/mod.rs index a8b81579..db2602fb 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -14,6 +14,8 @@ use async_trait::async_trait; use chrono::{Local, TimeDelta}; +#[cfg(test)] +use mockall::mock; mod error; pub mod fernet; @@ -75,22 +77,22 @@ impl TokenApi for TokenProvider { } #[cfg(test)] -#[derive(Clone, Debug, Default)] -pub(crate) struct FakeTokenProvider {} +mock! { + pub TokenProvider { + pub fn new(cfg: &Config) -> Result; + } -#[cfg(test)] -#[async_trait] -impl TokenApi for FakeTokenProvider { - /// Validate token - #[tracing::instrument(level = "info", skip(self))] - async fn validate_token( - &self, - _credential: String, - _window_seconds: Option, - ) -> Result { - Ok(Token { - user_id: String::new(), - ..Default::default() - }) + #[async_trait] + impl TokenApi for TokenProvider { + async fn validate_token( + &self, + credential: String, + window_seconds: Option, + ) -> Result; } + + impl Clone for TokenProvider { + fn clone(&self) -> Self; + } + }