diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 8fffebe0..71ed1e5c 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -13,8 +13,6 @@ on: - 'Cargo.toml' - 'Cargo.lock' - '.github/workflows/linters.yml' - - 'keystone/**' - - 'fuzz/**' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} @@ -22,7 +20,7 @@ concurrency: env: CARGO_TERM_COLOR: always - rust_min: 1.76.0 + rust_min: 1.85.0 jobs: rustfmt: diff --git a/Cargo.lock b/Cargo.lock index 5d0723d7..ace92415 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -496,9 +496,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" dependencies = [ "jobserver", "libc", @@ -968,6 +968,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dyn-clone" version = "1.0.18" @@ -1116,6 +1122,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "funty" version = "2.0.0" @@ -1635,9 +1647,9 @@ dependencies = [ [[package]] name = "inout" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" dependencies = [ "generic-array", ] @@ -1765,9 +1777,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.25" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "matchit" @@ -1836,6 +1848,44 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.98", +] + +[[package]] +name = "mockall_double" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1ca96e5ac35256ae3e13536edd39b172b88f41615e1d7b653c8ad24524113e8" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.98", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1987,6 +2037,8 @@ dependencies = [ "eyre", "fernet", "http-body-util", + "mockall", + "mockall_double", "regex", "rmp", "sea-orm", @@ -2244,6 +2296,32 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "predicates" +version = "3.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" + +[[package]] +name = "predicates-tree" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" @@ -3194,6 +3272,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "thiserror" version = "1.0.69" @@ -3672,9 +3756,9 @@ checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" [[package]] name = "uuid" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" +checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1" dependencies = [ "getrandom 0.3.1", "serde", @@ -4192,27 +4276,27 @@ dependencies = [ [[package]] name = "zstd" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "7.2.1" +version = "7.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +checksum = "f3051792fbdc2e1e143244dc28c60f73d8470e93f3f9cbd0ead44da5ed802722" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.13+zstd.1.5.6" +version = "2.0.14+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index c9538b49..6c607dea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "openstack_keystone" version = "0.1.0" -edition = "2021" +edition = "2024" license = "Apache-2.0" authors = ["Artem Goncharov (gtema)"] -rust-version = "1.83" # MSRV +rust-version = "1.85" # MSRV repository = "https://github.com/gtema/keystone" [[bin]] @@ -29,6 +29,7 @@ derive_builder = { version = "^0.20" } dyn-clone = { version = "^1.0" } eyre = { version = "^0.6" } fernet = { version = "^0.2" } +mockall_double = { version = "^0.3" } regex = { version = "^1.11"} rmp = { version = "^0.8" } sea-orm = { version = "^1.1", features = ["sqlx-mysql", "sqlx-postgres", "runtime-tokio"] } @@ -44,11 +45,12 @@ tracing-subscriber = { version = "^0.3" } utoipa = { version = "^5.3", features = ["axum_extras", "chrono"] } utoipa-axum = { version = "^0.2" } utoipa-swagger-ui = { version = "^9.0", features = ["axum", "vendored"], default-features = false } -uuid = { version = "^1.13", features = ["v4"] } +uuid = { version = "^1.14", features = ["v4"] } [dev-dependencies] criterion = { version = "^0.5", features = ["async_tokio"] } http-body-util = "^0.1" +mockall = { version = "^0.13" } sea-orm = { version = "*", features = ["mock"]} tempfile = { version = "^3.17" } diff --git a/benches/fernet_token.rs b/benches/fernet_token.rs index 0a75e07d..795bd129 100644 --- a/benches/fernet_token.rs +++ b/benches/fernet_token.rs @@ -1,4 +1,4 @@ -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use std::fs::File; use std::io::Write; use std::path::PathBuf; diff --git a/loadtest/Cargo.toml b/loadtest/Cargo.toml index 0a0c5cc3..c3326ab5 100644 --- a/loadtest/Cargo.toml +++ b/loadtest/Cargo.toml @@ -1,8 +1,9 @@ [package] name = "load_test" version = "0.1.0" +rust-version = "1.85" # MSRV edition = "2024" [dependencies] -goose = "0.17.2" -tokio = "1.43.0" +goose = { version = "^0.17" } +tokio = { version = "^1.43" } diff --git a/src/api/auth.rs b/src/api/auth.rs index d71b985a..95f7d471 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -13,51 +13,45 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ - extract::{Request, State}, - http::StatusCode, - middleware::Next, - response::Response, + extract::{FromRef, FromRequestParts}, + http::{StatusCode, request::Parts}, }; use std::sync::Arc; use crate::keystone::ServiceState; -use crate::provider::Provider; use crate::token::{Token, TokenApi}; #[derive(Debug, Clone)] -pub struct Auth { - pub token: Token, -} +pub struct Auth(pub Token); -pub async fn auth

( - State(state): State>>, - mut req: Request, - next: Next, -) -> Result +impl FromRequestParts for Auth where - P: Provider, + ServiceState: FromRef, + S: Send + Sync, { - let auth_header = req - .headers() - .get("X-Auth-Token") - .and_then(|header| header.to_str().ok()); - - let auth_header = if let Some(auth_header) = auth_header { - auth_header - } else { - return Err(StatusCode::UNAUTHORIZED); - }; - - // insert the current user into a request extension so the handler can - // extract it - state - .provider - .get_token_provider() - .validate_token(auth_header.to_string(), None) - .await - .map(|token| { - req.extensions_mut().insert(Auth { token }); - }) - .map_err(|_| StatusCode::UNAUTHORIZED)?; - Ok(next.run(req).await) + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let auth_header = parts + .headers + .get("X-Auth-Token") + .and_then(|header| header.to_str().ok()); + + let auth_header = if let Some(auth_header) = auth_header { + auth_header + } else { + return Err((StatusCode::UNAUTHORIZED, "not authorized")); + }; + + let state = Arc::from_ref(state); + + Ok(Self( + state + .provider + .get_token_provider() + .validate_token(auth_header.to_string(), None) + .await + .map_err(|_| (StatusCode::UNAUTHORIZED, "not authorized"))?, + )) + } } diff --git a/src/api/error.rs b/src/api/error.rs index 454f435a..cddd073c 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use serde_json::json; use thiserror::Error; diff --git a/src/api/mod.rs b/src/api/mod.rs index 00f470cd..d0d0f90e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -12,12 +12,10 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::sync::Arc; use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; use crate::keystone::ServiceState; -use crate::provider::Provider; pub mod auth; pub mod error; @@ -27,9 +25,6 @@ pub mod v3; #[openapi(info(version = "3.14.0"))] pub struct ApiDoc; -pub fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub fn openapi_router() -> OpenApiRouter { OpenApiRouter::new().nest("/v3", v3::openapi_router()) } diff --git a/src/api/v3/group/mod.rs b/src/api/v3/group/mod.rs index befcbd83..132a0183 100644 --- a/src/api/v3/group/mod.rs +++ b/src/api/v3/group/mod.rs @@ -13,26 +13,22 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, debug_handler, extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, - Json, }; -use std::sync::Arc; use utoipa_axum::{router::OpenApiRouter, routes}; +use crate::api::auth::Auth; use crate::api::error::KeystoneApiError; use crate::identity::IdentityApi; use crate::keystone::ServiceState; -use crate::provider::Provider; use types::{Group, GroupCreateRequest, GroupList, GroupListParameters, GroupResponse}; mod types; -pub(super) fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .routes(routes!(list, create)) .routes(routes!(show, remove)) @@ -51,13 +47,11 @@ where tag="groups" )] #[tracing::instrument(name = "api::group_list", level = "debug", skip(state))] -async fn list

( +async fn list( + Auth(user_auth): Auth, Query(query): Query, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { let groups: Vec = state .provider .get_identity_provider() @@ -83,17 +77,15 @@ where tag="groups" )] #[tracing::instrument(name = "api::group_get", level = "debug", skip(state))] -async fn show

( +async fn show( + Auth(user_auth): Auth, Path(group_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .get_group(&state.db, &group_id) + .get_group(&state.db, group_id.clone()) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -114,13 +106,12 @@ where tag="groups" )] #[tracing::instrument(name = "api::create_group", level = "debug", skip(state))] -async fn create

( - State(state): State>>, +#[debug_handler] +async fn create( + Auth(user_auth): Auth, + State(state): State, Json(req): Json, -) -> Result -where - P: Provider, -{ +) -> Result { let res = state .provider .get_identity_provider() @@ -143,17 +134,15 @@ where tag="groups" )] #[tracing::instrument(name = "api::group_delete", level = "debug", skip(state))] -async fn remove

( +async fn remove( + Auth(user_auth): Auth, Path(group_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .delete_group(&state.db, &group_id) + .delete_group(&state.db, group_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -163,68 +152,164 @@ where mod tests { use axum::{ body::Body, - http::{self, Request, StatusCode}, + http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` use sea_orm::DatabaseConnection; - use std::sync::Arc; + use serde_json::json; + use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; use super::openapi_router; - use crate::api::v3::group::types::*; - use crate::config::Config; - use crate::identity::IdentityApi; - use crate::keystone::ServiceState; - use crate::provider::{FakeProviderApi, Provider}; + use crate::api::v3::group::types::{ + Group as ApiGroup, GroupCreate as ApiGroupCreate, GroupCreateRequest, GroupList, + GroupResponse, + }; + use crate::identity::{ + MockIdentityProvider, + error::IdentityProviderError, + types::{Group, GroupCreate, GroupListParameters}, + }; + + use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed}; #[tokio::test] async fn test_list() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); - let mut api = openapi_router().with_state(state); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_groups() + .withf(|_: &DatabaseConnection, _: &GroupListParameters| true) + .returning(|_, _| { + Ok(vec![Group { + id: "1".into(), + name: "2".into(), + ..Default::default() + }]) + }); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); let response = api .as_service() - .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .oneshot( + Request::builder() + .uri("/") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _res: GroupList = serde_json::from_slice(&body).unwrap(); + let res: GroupList = serde_json::from_slice(&body).unwrap(); + assert_eq!( + vec![ApiGroup { + id: "1".into(), + name: "2".into(), + // for some reason when deserializing missing value appears still as an empty + // object + extra: Some(json!({})), + ..Default::default() + }], + res.groups + ); } #[tokio::test] - async fn test_get() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + async fn test_list_qp() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_groups() + .withf(|_: &DatabaseConnection, qp: &GroupListParameters| { + GroupListParameters { + domain_id: Some("domain".into()), + name: Some("name".into()), + } == *qp + }) + .returning(|_, _| Ok(Vec::new())); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) - .with_state(state.clone()); + .with_state(state); - let group = crate::identity::types::GroupCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; + let response = api + .as_service() + .oneshot( + Request::builder() + .uri("/?domain_id=domain&name=name") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let _res: GroupList = serde_json::from_slice(&body).unwrap(); + } - let created_group = state - .provider - .get_identity_provider() - .create_group(&DatabaseConnection::Disconnected, group) + #[tokio::test] + async fn test_list_unauth() { + let state = get_mocked_state_unauthed(); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn test_get() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Ok(None)); + + identity_mock + .expect_get_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| { + Ok(Some(Group { + id: "bar".into(), + ..Default::default() + })) + }); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state.clone()); + let response = api .as_service() - .oneshot(Request::builder().uri("/foo").body(Body::empty()).unwrap()) + .oneshot( + Request::builder() + .uri("/foo") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) .await .unwrap(); @@ -234,7 +319,8 @@ mod tests { .as_service() .oneshot( Request::builder() - .uri(format!("/{}", created_group.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -244,22 +330,42 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _user: GroupResponse = serde_json::from_slice(&body).unwrap(); + let res: GroupResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!( + ApiGroup { + id: "bar".into(), + extra: Some(json!({})), + ..Default::default() + }, + res.group, + ); } #[tokio::test] async fn test_create() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_create_group() + .withf(|_: &DatabaseConnection, req: &GroupCreate| { + req.domain_id == "domain" && req.name == "name" + }) + .returning(|_, req| { + Ok(Group { + id: "bar".into(), + domain_id: req.domain_id, + name: req.name, + ..Default::default() + }) + }); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); let req = GroupCreateRequest { - group: GroupCreate { + group: ApiGroupCreate { domain_id: "domain".into(), name: "name".into(), ..Default::default() @@ -271,8 +377,9 @@ mod tests { .oneshot( Request::builder() .method("POST") - .header(http::header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_TYPE, "application/json") .uri("/") + .header("x-auth-token", "foo") .body(Body::from(serde_json::to_string(&req).unwrap())) .unwrap(), ) @@ -284,38 +391,35 @@ mod tests { let body = response.into_body().collect().await.unwrap().to_bytes(); let res: GroupResponse = serde_json::from_slice(&body).unwrap(); assert_eq!(res.group.name, req.group.name); + assert_eq!(res.group.domain_id, req.group.domain_id); } #[tokio::test] async fn test_delete() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_delete_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Err(IdentityProviderError::GroupNotFound("foo".into()))); + + identity_mock + .expect_delete_group() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| Ok(())); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); - let group = crate::identity::types::GroupCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; - - let created_group = state - .provider - .get_identity_provider() - .create_group(&DatabaseConnection::Disconnected, group) - .await - .unwrap(); - let response = api .as_service() .oneshot( Request::builder() .method("DELETE") .uri("/foo") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -329,7 +433,8 @@ mod tests { .oneshot( Request::builder() .method("DELETE") - .uri(format!("/{}", created_group.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) diff --git a/src/api/v3/group/types.rs b/src/api/v3/group/types.rs index 3faa8c7a..391ddb1a 100644 --- a/src/api/v3/group/types.rs +++ b/src/api/v3/group/types.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -23,7 +23,7 @@ use utoipa::{IntoParams, ToSchema}; use crate::identity::types; -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct Group { /// Group ID pub id: String, @@ -33,17 +33,17 @@ pub struct Group { pub name: String, /// Group description pub description: Option, - #[serde(flatten)] + #[serde(flatten, skip_serializing_if = "Option::is_none")] pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupResponse { /// group object pub group: Group, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupCreate { /// Group domain ID pub domain_id: String, @@ -51,11 +51,11 @@ pub struct GroupCreate { pub name: String, /// Group description pub description: Option, - #[serde(flatten)] + #[serde(default, flatten, skip_serializing_if = "Option::is_none")] pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupCreateRequest { /// Group object pub group: GroupCreate, @@ -77,7 +77,7 @@ impl From for types::GroupCreate { fn from(value: GroupCreateRequest) -> Self { let group = value.group; Self { - id: String::new(), + id: None, name: group.name, domain_id: group.domain_id, extra: group.extra, @@ -105,7 +105,7 @@ impl IntoResponse for types::Group { } /// Groups -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct GroupList { /// Collection of group objects pub groups: Vec, diff --git a/src/api/v3/mod.rs b/src/api/v3/mod.rs index 31777f02..5c4afc38 100644 --- a/src/api/v3/mod.rs +++ b/src/api/v3/mod.rs @@ -12,19 +12,14 @@ // // SPDX-License-Identifier: Apache-2.0 -use std::sync::Arc; use utoipa_axum::router::OpenApiRouter; use crate::keystone::ServiceState; -use crate::provider::Provider; pub mod group; pub mod user; -pub(super) fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .nest("/users", user::openapi_router()) .nest("/groups", group::openapi_router()) diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index dedfdca6..4e9423b3 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -13,27 +13,22 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ - extract::{Extension, Path, Query, State}, + Json, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, - Json, }; -use std::sync::Arc; use utoipa_axum::{router::OpenApiRouter, routes}; use crate::api::auth::Auth; use crate::api::error::KeystoneApiError; use crate::identity::IdentityApi; use crate::keystone::ServiceState; -use crate::provider::Provider; use types::{User, UserCreateRequest, UserList, UserListParameters, UserResponse}; mod types; -pub(super) fn openapi_router

() -> OpenApiRouter>> -where - P: Provider + 'static, -{ +pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .routes(routes!(list, create)) .routes(routes!(show, remove)) @@ -52,14 +47,11 @@ where tag="users" )] #[tracing::instrument(name = "api::user_list", level = "debug", skip(state))] -async fn list

( - Extension(current_user): Extension, +async fn list( + Auth(user_auth): Auth, Query(query): Query, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { let users: Vec = state .provider .get_identity_provider() @@ -84,17 +76,15 @@ where tag="users" )] #[tracing::instrument(name = "api::user_get", level = "debug", skip(state))] -async fn show

( +async fn show( + Auth(user_auth): Auth, Path(user_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state.db, user_id.clone()) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -115,14 +105,12 @@ where tag="users" )] #[tracing::instrument(name = "api::create_user", level = "debug", skip(state))] -async fn create

( +async fn create( + Auth(user_auth): Auth, Query(query): Query, - State(state): State>>, + State(state): State, Json(req): Json, -) -> Result -where - P: Provider, -{ +) -> Result { let user = state .provider .get_identity_provider() @@ -145,17 +133,15 @@ where tag="users" )] #[tracing::instrument(name = "api::user_delete", level = "debug", skip(state))] -async fn remove

( +async fn remove( + Auth(user_auth): Auth, Path(user_id): Path, - State(state): State>>, -) -> Result -where - P: Provider, -{ + State(state): State, +) -> Result { state .provider .get_identity_provider() - .delete_user(&state.db, &user_id) + .delete_user(&state.db, user_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -166,31 +152,44 @@ mod tests { use axum::{ body::Body, http::{self, Request, StatusCode}, - middleware, }; use http_body_util::BodyExt; // for `collect` use sea_orm::DatabaseConnection; - use std::sync::Arc; + use serde_json::json; + use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; use super::openapi_router; - use crate::api::auth::auth; - use crate::api::v3::user::types::*; - use crate::config::Config; - use crate::identity::IdentityApi; - use crate::keystone::ServiceState; - use crate::provider::{FakeProviderApi, Provider}; + use crate::api::v3::user::types::{ + User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, UserResponse, + }; + use crate::identity::{ + MockIdentityProvider, + error::IdentityProviderError, + types::{User, UserCreate, UserListParameters}, + }; + + use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed}; #[tokio::test] async fn test_list() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_users() + .withf(|_: &DatabaseConnection, _: &UserListParameters| true) + .returning(|_, _| { + Ok(vec![User { + id: "1".into(), + name: "2".into(), + ..Default::default() + }]) + }); + + let state = get_mocked_state(identity_mock); + let mut api = openapi_router() .layer(TraceLayer::new_for_http()) - .route_layer(middleware::from_fn_with_state(state.clone(), auth)) .with_state(state); let response = api @@ -208,22 +207,98 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _users: UserList = serde_json::from_slice(&body).unwrap(); + let res: UserList = serde_json::from_slice(&body).unwrap(); + assert_eq!( + vec![ApiUser { + id: "1".into(), + name: "2".into(), + // object + extra: Some(json!({})), + ..Default::default() + }], + res.users + ); + } + + #[tokio::test] + async fn test_list_qp() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_users() + .withf(|_: &DatabaseConnection, qp: &UserListParameters| { + UserListParameters { + domain_id: Some("domain".into()), + name: Some("name".into()), + } == *qp + }) + .returning(|_, _| Ok(Vec::new())); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot( + Request::builder() + .uri("/?domain_id=domain&name=name") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let _res: UserList = serde_json::from_slice(&body).unwrap(); + } + + #[tokio::test] + async fn test_list_unauth() { + let state = get_mocked_state_unauthed(); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); } #[tokio::test] async fn test_create() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_create_user() + .withf(|_: &DatabaseConnection, req: &UserCreate| { + req.domain_id == "domain" && req.name == "name" + }) + .returning(|_, req| { + Ok(User { + id: "bar".into(), + domain_id: req.domain_id, + name: req.name, + ..Default::default() + }) + }); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); let user = UserCreateRequest { - user: UserCreate { + user: ApiUserCreate { domain_id: "domain".into(), name: "name".into(), ..Default::default() @@ -237,6 +312,7 @@ mod tests { .method("POST") .header(http::header::CONTENT_TYPE, "application/json") .uri("/") + .header("x-auth-token", "foo") .body(Body::from(serde_json::to_string(&user).unwrap())) .unwrap(), ) @@ -252,31 +328,37 @@ mod tests { #[tokio::test] async fn test_get() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Ok(None)); + + identity_mock + .expect_get_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| { + Ok(Some(User { + id: "bar".into(), + ..Default::default() + })) + }); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); - let user = crate::identity::types::UserCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; - - let created_user = state - .provider - .get_identity_provider() - .create_user(&DatabaseConnection::Disconnected, user) - .await - .unwrap(); - let response = api .as_service() - .oneshot(Request::builder().uri("/foo").body(Body::empty()).unwrap()) + .oneshot( + Request::builder() + .uri("/foo") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) .await .unwrap(); @@ -286,7 +368,8 @@ mod tests { .as_service() .oneshot( Request::builder() - .uri(format!("/{}", created_user.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -296,39 +379,43 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); let body = response.into_body().collect().await.unwrap().to_bytes(); - let _user: UserResponse = serde_json::from_slice(&body).unwrap(); + let res: UserResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!( + ApiUser { + id: "bar".into(), + extra: Some(json!({})), + ..Default::default() + }, + res.user, + ); } #[tokio::test] async fn test_delete() { - let db = DatabaseConnection::Disconnected; - let config = Config::default(); - let provider = FakeProviderApi::new(config.clone()).unwrap(); - let state = Arc::new(ServiceState::new(config, db, provider).unwrap()); + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_delete_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "foo") + .returning(|_, _| Err(IdentityProviderError::UserNotFound("foo".into()))); + + identity_mock + .expect_delete_user() + .withf(|_: &DatabaseConnection, id: &String| *id == "bar") + .returning(|_, _| Ok(())); + + let state = get_mocked_state(identity_mock); let mut api = openapi_router() .layer(TraceLayer::new_for_http()) .with_state(state.clone()); - let user = crate::identity::types::UserCreate { - domain_id: "domain".into(), - name: "name".into(), - ..Default::default() - }; - - let created_user = state - .provider - .get_identity_provider() - .create_user(&DatabaseConnection::Disconnected, user) - .await - .unwrap(); - let response = api .as_service() .oneshot( Request::builder() .method("DELETE") .uri("/foo") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) @@ -342,7 +429,8 @@ mod tests { .oneshot( Request::builder() .method("DELETE") - .uri(format!("/{}", created_user.id)) + .uri("/bar") + .header("x-auth-token", "foo") .body(Body::empty()) .unwrap(), ) diff --git a/src/api/v3/user/types.rs b/src/api/v3/user/types.rs index 29ef9111..ca153876 100644 --- a/src/api/v3/user/types.rs +++ b/src/api/v3/user/types.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -24,7 +24,7 @@ use utoipa::{IntoParams, ToSchema}; use crate::identity::types; -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct User { /// User ID pub id: String, @@ -41,25 +41,28 @@ pub struct User { /// default project, the default project is ignored at token creation. (Since v3.1) /// Additionally, if your default project is not valid, a token is issued without an explicit /// scope of authorization. + #[serde(skip_serializing_if = "Option::is_none")] pub default_project_id: Option, - #[serde(flatten)] + #[serde(flatten, skip_serializing_if = "Option::is_none")] pub extra: Option, /// The date and time when the password expires. The time zone is UTC. + #[serde(skip_serializing_if = "Option::is_none")] pub password_expires_at: Option>, /// The resource options for the user. Available resource options are /// ignore_change_password_upon_first_use, ignore_password_expiry, /// ignore_lockout_failure_attempts, lock_password, multi_factor_auth_enabled, and /// multi_factor_auth_rules ignore_user_inactivity. + #[serde(skip_serializing_if = "Option::is_none")] pub options: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserResponse { /// User object pub user: User, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserCreate { /// User domain ID pub domain_id: String, @@ -87,7 +90,7 @@ pub struct UserCreate { pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserUpdateRequest { /// The user name. Must be unique within the owning domain. pub name: Option, @@ -113,7 +116,7 @@ pub struct UserUpdateRequest { pub extra: Option, } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserOptions { #[serde(skip_serializing_if = "Option::is_none")] pub ignore_change_password_upon_first_use: Option, @@ -159,7 +162,7 @@ impl From for types::UserOptions { } } -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserCreateRequest { /// User object pub user: UserCreate, @@ -167,6 +170,20 @@ pub struct UserCreateRequest { impl From for User { fn from(value: types::User) -> Self { + let opts: UserOptions = value.options.clone().into(); + // We only want to see user options if there is at least 1 option set + let opts = if opts.ignore_change_password_upon_first_use.is_some() + || opts.ignore_password_expiry.is_some() + || opts.ignore_lockout_failure_attempts.is_some() + || opts.lock_password.is_some() + || opts.ignore_user_inactivity.is_some() + || opts.multi_factor_auth_rules.is_some() + || opts.multi_factor_auth_enabled.is_some() + { + Some(opts) + } else { + None + }; Self { id: value.id, domain_id: value.domain_id, @@ -175,7 +192,7 @@ impl From for User { default_project_id: value.default_project_id, extra: value.extra, password_expires_at: value.password_expires_at, - options: Some(value.options.into()), + options: opts, } } } @@ -216,7 +233,7 @@ impl IntoResponse for types::User { } /// Users -#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)] pub struct UserList { /// Collection of user objects pub users: Vec, @@ -235,7 +252,7 @@ impl IntoResponse for UserList { } } -#[derive(Clone, Debug, Default, Deserialize, Serialize, IntoParams)] +#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, IntoParams)] pub struct UserListParameters { /// Filter users by Domain ID pub domain_id: Option, diff --git a/src/bin/keystone.rs b/src/bin/keystone.rs index cf1978cb..889c6b85 100644 --- a/src/bin/keystone.rs +++ b/src/bin/keystone.rs @@ -12,10 +12,7 @@ // // SPDX-License-Identifier: Apache-2.0 -use axum::{ - http::{self, header, HeaderName, Request}, - middleware::{self}, -}; +use axum::http::{self, HeaderName, Request, header}; use clap::Parser; use color_eyre::eyre::{Report, Result}; use sea_orm::ConnectOptions; @@ -23,15 +20,14 @@ use sea_orm::Database; use std::io; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; -use tokio::net::TcpListener; -use tokio::signal; +use tokio::{net::TcpListener, signal}; use tower::ServiceBuilder; use tower_http::{ + LatencyUnit, ServiceBuilderExt, request_id::{MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer}, trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}, - LatencyUnit, ServiceBuilderExt, }; -use tracing::{info_span, Level}; +use tracing::{Level, info_span}; use tracing_subscriber::{filter::LevelFilter, prelude::*}; use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; @@ -39,11 +35,10 @@ use utoipa_swagger_ui::SwaggerUi; use uuid::Uuid; use openstack_keystone::api; -use openstack_keystone::api::auth::auth; use openstack_keystone::config::Config; -use openstack_keystone::keystone::ServiceState; +use openstack_keystone::keystone::{Service, ServiceState}; use openstack_keystone::plugin_manager::PluginManager; -use openstack_keystone::provider::{Provider, ProviderApi}; +use openstack_keystone::provider::Provider; /// Simple program to greet a person #[derive(Parser, Debug)] @@ -101,12 +96,12 @@ async fn main() -> Result<(), Report> { let plugin_manager = PluginManager::default(); - let provider = ProviderApi::new(cfg.clone(), plugin_manager)?; + let provider = Provider::new(cfg.clone(), plugin_manager)?; - let shared_state = Arc::new(ServiceState::new(cfg, conn, provider).unwrap()); + let shared_state = Arc::new(Service::new(cfg, conn, provider).unwrap()); let (router, api) = OpenApiRouter::with_openapi(api::ApiDoc::openapi()) - .merge(api::openapi_router::()) + .merge(api::openapi_router()) .split_for_parts(); let x_request_id = HeaderName::from_static("x-openstack-request-id"); @@ -147,21 +142,18 @@ async fn main() -> Result<(), Report> { .layer(PropagateRequestIdLayer::new(x_request_id)); let app = router - // Router::new() - //.merge(api::router()) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", api)) - .route_layer(middleware::from_fn_with_state(shared_state.clone(), auth)) .layer(middleware) .with_state(shared_state.clone()); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080)); let listener = TcpListener::bind(&address).await?; Ok(axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(shutdown_signal(shared_state.clone())) + .with_graceful_shutdown(shutdown_signal(shared_state)) .await?) } -async fn shutdown_signal(state: Arc>) { +async fn shutdown_signal(state: ServiceState) { let ctrl_c = async { signal::ctrl_c() .await diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index 85672869..56000e69 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -13,9 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use async_trait::async_trait; +use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; -use sea_orm::DatabaseConnection; mod common; mod federated_user; @@ -33,9 +33,9 @@ use crate::db::entity::{ prelude::{FederatedUser, LocalUser, NonlocalUser, User as DbUser, UserOption}, user as db_user, user_option as db_user_option, }; +use crate::identity::IdentityProviderError; use crate::identity::backends::error::IdentityDatabaseError; use crate::identity::password_hashing; -use crate::identity::IdentityProviderError; #[derive(Clone, Debug, Default)] pub struct SqlBackend { @@ -231,26 +231,29 @@ pub async fn get_user( if let Some(user) = &user_entry { let user_opts: Vec = user.find_related(UserOption).all(db).await?; - let user_builder: UserBuilder = if let Some(local_user_with_passwords) = - local_user::load_local_user_with_passwords(db, &user_id).await? - { - common::get_local_user_builder( - conf, - user, - local_user_with_passwords.0, - Some(local_user_with_passwords.1), - user_opts, - ) - } else if let Some(nonlocal_user) = user.find_related(NonlocalUser).one(db).await? { - common::get_nonlocal_user_builder(user, nonlocal_user, user_opts) - } else { - let federated_user = user.find_related(FederatedUser).all(db).await?; - if !federated_user.is_empty() { - common::get_federated_user_builder(user, federated_user, user_opts) - } else { - return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?; - } - }; + let user_builder: UserBuilder = + match local_user::load_local_user_with_passwords(db, &user_id).await? { + Some(local_user_with_passwords) => common::get_local_user_builder( + conf, + user, + local_user_with_passwords.0, + Some(local_user_with_passwords.1), + user_opts, + ), + _ => match user.find_related(NonlocalUser).one(db).await? { + Some(nonlocal_user) => { + common::get_nonlocal_user_builder(user, nonlocal_user, user_opts) + } + _ => { + let federated_user = user.find_related(FederatedUser).all(db).await?; + if !federated_user.is_empty() { + common::get_federated_user_builder(user, federated_user, user_opts) + } else { + return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?; + } + } + }, + }; return Ok(Some(user_builder.build()?)); } diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs index 5c9b8c08..f5bde82c 100644 --- a/src/identity/backends/sql/group.rs +++ b/src/identity/backends/sql/group.rs @@ -12,15 +12,16 @@ // // SPDX-License-Identifier: Apache-2.0 +use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; -use sea_orm::DatabaseConnection; use serde_json::Value; +use serde_json::json; use crate::db::entity::{group, prelude::Group as DbGroup}; +use crate::identity::Config; use crate::identity::backends::sql::IdentityDatabaseError; use crate::identity::types::{Group, GroupCreate, GroupListParameters}; -use crate::identity::Config; pub async fn list( _conf: &Config, @@ -60,7 +61,7 @@ pub async fn create( group: GroupCreate, ) -> Result { let entry = group::ActiveModel { - id: Set(group.id.clone()), + id: Set(group.id.clone().unwrap_or_default()), domain_id: Set(group.domain_id.clone()), name: Set(group.name.clone()), description: Set(group.description.clone()), @@ -94,7 +95,7 @@ impl From for Group { domain_id: value.domain_id.clone(), extra: value .extra - .map(|x| serde_json::from_str::(&x).unwrap_or(Value::Null)), + .map(|x| serde_json::from_str::(&x).unwrap_or(json!(true))), } } } @@ -107,8 +108,8 @@ mod tests { use serde_json::json; use crate::db::entity::group; - use crate::identity::types::group::GroupListParametersBuilder; use crate::identity::Config; + use crate::identity::types::group::GroupListParametersBuilder; use super::*; @@ -236,7 +237,7 @@ mod tests { let config = Config::default(); let req = GroupCreate { - id: "1".into(), + id: Some("1".into()), domain_id: "foo_domain".into(), name: "group".into(), description: Some("fake".into()), diff --git a/src/identity/backends/sql/local_user.rs b/src/identity/backends/sql/local_user.rs index d63fcc49..b1f3cbc6 100644 --- a/src/identity/backends/sql/local_user.rs +++ b/src/identity/backends/sql/local_user.rs @@ -12,9 +12,9 @@ // // SPDX-License-Identifier: Apache-2.0 +use sea_orm::DatabaseConnection; use sea_orm::entity::*; use sea_orm::query::*; -use sea_orm::DatabaseConnection; use std::collections::HashMap; use crate::config::Config; @@ -30,7 +30,10 @@ pub async fn load_local_user_with_passwords>( db: &DatabaseConnection, user_id: S, ) -> Result< - Option<(local_user::Model, impl IntoIterator)>, + Option<( + local_user::Model, + impl IntoIterator + use, + )>, IdentityDatabaseError, > { let results: Vec<(local_user::Model, Vec)> = LocalUser::find() @@ -48,11 +51,7 @@ pub async fn load_local_user_with_passwords>( pub async fn load_local_users_passwords>>( db: &DatabaseConnection, user_ids: L, -) -> Result< - //impl IntoIterator>>, - Vec>>, - IdentityDatabaseError, -> { +) -> Result>>, IdentityDatabaseError> { let ids: Vec> = user_ids.into_iter().collect(); // Collect local user IDs that we need to query let keys: Vec = ids.iter().filter_map(Option::as_ref).copied().collect(); diff --git a/src/identity/backends/sql/password.rs b/src/identity/backends/sql/password.rs index 003108e0..f1983023 100644 --- a/src/identity/backends/sql/password.rs +++ b/src/identity/backends/sql/password.rs @@ -13,8 +13,8 @@ // SPDX-License-Identifier: Apache-2.0 use chrono::{DateTime, Local, Utc}; -use sea_orm::entity::*; use sea_orm::DatabaseConnection; +use sea_orm::entity::*; use crate::db::entity::password; use crate::identity::backends::error::IdentityDatabaseError; diff --git a/src/identity/backends/sql/user.rs b/src/identity/backends/sql/user.rs index 17521a19..404e3a00 100644 --- a/src/identity/backends/sql/user.rs +++ b/src/identity/backends/sql/user.rs @@ -13,8 +13,8 @@ // SPDX-License-Identifier: Apache-2.0 use chrono::Local; -use sea_orm::entity::*; use sea_orm::DatabaseConnection; +use sea_orm::entity::*; use crate::config::Config; use crate::db::entity::{prelude::User as DbUser, user}; diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 64954d5e..43f7e2eb 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -13,12 +13,9 @@ // SPDX-License-Identifier: Apache-2.0 use async_trait::async_trait; -use sea_orm::DatabaseConnection; #[cfg(test)] -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; +use mockall::mock; +use sea_orm::DatabaseConnection; use uuid::Uuid; pub mod backends; @@ -29,9 +26,10 @@ pub(crate) mod types; use crate::config::Config; use crate::identity::backends::sql::SqlBackend; use crate::identity::error::IdentityProviderError; -use crate::identity::types::IdentityBackend; -use crate::identity::types::{Group, GroupCreate, GroupListParameters}; -use crate::identity::types::{User, UserCreate, UserListParameters}; +use crate::identity::types::{ + IdentityBackend, + {Group, GroupCreate, GroupListParameters, User, UserCreate, UserListParameters}, +}; use crate::plugin_manager::PluginManager; #[derive(Clone, Debug)] @@ -47,10 +45,10 @@ pub trait IdentityApi: Send + Sync + Clone { params: &UserListParameters, ) -> Result, IdentityProviderError>; - async fn get_user + std::fmt::Debug + Send + Sync>( + async fn get_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result, IdentityProviderError>; async fn create_user( @@ -59,10 +57,10 @@ pub trait IdentityApi: Send + Sync + Clone { user: UserCreate, ) -> Result; - async fn delete_user + std::fmt::Debug + Send + Sync>( + async fn delete_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result<(), IdentityProviderError>; async fn list_groups( @@ -71,10 +69,10 @@ pub trait IdentityApi: Send + Sync + Clone { params: &GroupListParameters, ) -> Result, IdentityProviderError>; - async fn get_group + std::fmt::Debug + Send + Sync>( + async fn get_group( &self, db: &DatabaseConnection, - group_id: S, + group_id: String, ) -> Result, IdentityProviderError>; async fn create_group( @@ -83,13 +81,76 @@ pub trait IdentityApi: Send + Sync + Clone { group: GroupCreate, ) -> Result; - async fn delete_group + std::fmt::Debug + Send + Sync>( + async fn delete_group( &self, db: &DatabaseConnection, - group_id: S, + group_id: String, ) -> Result<(), IdentityProviderError>; } +#[cfg(test)] +mock! { + pub IdentityProvider { + pub fn new(cfg: &Config, plugin_manager: &PluginManager) -> Result; + } + + #[async_trait] + impl IdentityApi for IdentityProvider { + async fn list_users( + &self, + db: &DatabaseConnection, + params: &UserListParameters, + ) -> Result, IdentityProviderError>; + + async fn get_user( + &self, + db: &DatabaseConnection, + user_id: String, + ) -> Result, IdentityProviderError>; + + async fn create_user( + &self, + db: &DatabaseConnection, + user: UserCreate, + ) -> Result; + + async fn delete_user( + &self, + db: &DatabaseConnection, + user_id: String, + ) -> Result<(), IdentityProviderError>; + + async fn list_groups( + &self, + db: &DatabaseConnection, + params: &GroupListParameters, + ) -> Result, IdentityProviderError>; + + async fn get_group( + &self, + db: &DatabaseConnection, + group_id: String, + ) -> Result, IdentityProviderError>; + + async fn create_group( + &self, + db: &DatabaseConnection, + group: GroupCreate, + ) -> Result; + + async fn delete_group( + &self, + db: &DatabaseConnection, + group_id: String, + ) -> Result<(), IdentityProviderError>; + } + + impl Clone for IdentityProvider { + fn clone(&self) -> Self; + } + +} + impl IdentityProvider { pub fn new( config: &Config, @@ -128,14 +189,12 @@ impl IdentityApi for IdentityProvider { /// Get single user #[tracing::instrument(level = "info", skip(self, db))] - async fn get_user + std::fmt::Debug + Send + Sync>( + async fn get_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result, IdentityProviderError> { - self.backend_driver - .get_user(db, user_id.as_ref().to_string()) - .await + self.backend_driver.get_user(db, user_id).await } /// Create user @@ -155,14 +214,12 @@ impl IdentityApi for IdentityProvider { /// Delete user #[tracing::instrument(level = "info", skip(self, db))] - async fn delete_user + std::fmt::Debug + Send + Sync>( + async fn delete_user( &self, db: &DatabaseConnection, - user_id: S, + user_id: String, ) -> Result<(), IdentityProviderError> { - self.backend_driver - .delete_user(db, user_id.as_ref().to_string()) - .await + self.backend_driver.delete_user(db, user_id).await } /// List groups @@ -177,14 +234,12 @@ impl IdentityApi for IdentityProvider { /// Get single group #[tracing::instrument(level = "info", skip(self, db))] - async fn get_group + std::fmt::Debug + Send + Sync>( + async fn get_group( &self, db: &DatabaseConnection, - group_id: S, + group_id: String, ) -> Result, IdentityProviderError> { - self.backend_driver - .get_group(db, group_id.as_ref().to_string()) - .await + self.backend_driver.get_group(db, group_id).await } /// Create group @@ -195,164 +250,17 @@ impl IdentityApi for IdentityProvider { group: GroupCreate, ) -> Result { let mut res = group; - res.id = Uuid::new_v4().into(); + res.id = Some(Uuid::new_v4().into()); self.backend_driver.create_group(db, res).await } /// Delete group #[tracing::instrument(level = "info", skip(self, db))] - async fn delete_group + std::fmt::Debug + Send + Sync>( + async fn delete_group( &self, db: &DatabaseConnection, - group_id: S, - ) -> Result<(), IdentityProviderError> { - self.backend_driver - .delete_group(db, group_id.as_ref().to_string()) - .await - } -} - -#[cfg(test)] -#[derive(Clone, Debug, Default)] -pub(crate) struct FakeIdentityProvider { - users: Arc>>, - groups: Arc>>, -} - -#[cfg(test)] -impl From for User { - fn from(value: UserCreate) -> Self { - Self { - id: value.id, - name: value.name, - domain_id: value.domain_id, - ..Default::default() - } - } -} -#[cfg(test)] -impl From for Group { - fn from(value: GroupCreate) -> Self { - Self { - id: value.id, - name: value.name, - domain_id: value.domain_id, - ..Default::default() - } - } -} - -#[cfg(test)] -#[async_trait] -impl IdentityApi for FakeIdentityProvider { - /// List users - async fn list_users( - &self, - _db: &DatabaseConnection, - _params: &types::UserListParameters, - ) -> Result, IdentityProviderError> { - Ok(self - .users - .lock() - .unwrap() - .values() - .cloned() - .collect::>()) - } - - /// Get single user - async fn get_user + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - user_id: S, - ) -> Result, IdentityProviderError> { - Ok(self.users.lock().unwrap().get(user_id.as_ref()).cloned()) - } - - /// Create user - async fn create_user( - &self, - _db: &DatabaseConnection, - user: UserCreate, - ) -> Result { - let mut mod_user = user; - mod_user.id = Uuid::new_v4().into(); - let res = User::from(mod_user); - self.users - .lock() - .unwrap() - .insert(res.id.clone(), res.clone()); - Ok(res) - } - - /// Delete user - async fn delete_user + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - user_id: S, - ) -> Result<(), IdentityProviderError> { - Ok(self - .users - .lock() - .unwrap() - .remove(user_id.as_ref()) - .map(|_| ()) - .ok_or(IdentityProviderError::UserNotFound( - user_id.as_ref().to_string(), - ))?) - } - - async fn list_groups( - &self, - _db: &DatabaseConnection, - _params: &GroupListParameters, - ) -> Result, IdentityProviderError> { - Ok(self - .groups - .lock() - .unwrap() - .values() - .cloned() - .collect::>()) - } - - /// Get single user - async fn get_group + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - group_id: S, - ) -> Result, IdentityProviderError> { - Ok(self.groups.lock().unwrap().get(group_id.as_ref()).cloned()) - } - - /// Create group - async fn create_group( - &self, - _db: &DatabaseConnection, - group: GroupCreate, - ) -> Result { - let mut req = group; - req.id = Uuid::new_v4().into(); - let res = Group::from(req); - self.groups - .lock() - .unwrap() - .insert(res.id.clone(), res.clone()); - Ok(res) - } - /// - /// Delete group - async fn delete_group + std::fmt::Debug + Send + Sync>( - &self, - _db: &DatabaseConnection, - id: S, + group_id: String, ) -> Result<(), IdentityProviderError> { - Ok(self - .groups - .lock() - .unwrap() - .remove(id.as_ref()) - .map(|_| ()) - .ok_or(IdentityProviderError::UserNotFound(id.as_ref().to_string()))?) + self.backend_driver.delete_group(db, group_id).await } } diff --git a/src/identity/types/group.rs b/src/identity/types/group.rs index befbad88..e4a2657f 100644 --- a/src/identity/types/group.rs +++ b/src/identity/types/group.rs @@ -32,7 +32,7 @@ pub struct Group { pub name: String, } -#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[builder(setter(strip_option, into))] pub struct GroupListParameters { /// Filter groups by the domain @@ -41,7 +41,7 @@ pub struct GroupListParameters { pub name: Option, } -#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize, PartialEq)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] #[builder(setter(strip_option, into))] pub struct GroupCreate { /// The description of the group. @@ -52,7 +52,7 @@ pub struct GroupCreate { #[builder(default)] pub extra: Option, /// The ID of the group. - pub id: String, + pub id: Option, /// The user name. Must be unique within the owning domain. pub name: String, } diff --git a/src/identity/types/user.rs b/src/identity/types/user.rs index c3d2bb06..72e2e4f3 100644 --- a/src/identity/types/user.rs +++ b/src/identity/types/user.rs @@ -140,7 +140,7 @@ pub struct FederationProtocol { pub unique_id: String, } -#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize)] +#[derive(Builder, Clone, Debug, Default, Deserialize, PartialEq, Serialize)] pub struct UserListParameters { /// Filter users by the domain pub domain_id: Option, diff --git a/src/keystone.rs b/src/keystone.rs index a6d8e058..05ea6c5d 100644 --- a/src/keystone.rs +++ b/src/keystone.rs @@ -12,9 +12,9 @@ // // SPDX-License-Identifier: Apache-2.0 +use axum::extract::FromRef; use sea_orm::DatabaseConnection; -//use tokio::sync::RwLock; - +use std::sync::Arc; use tracing::info; use crate::config::Config; @@ -24,17 +24,22 @@ use crate::provider::Provider; // Placing ServiceState behind Arc is necessary to address DatabaseConnection not implementing // Clone //#[derive(Clone)] -pub struct ServiceState

{ +#[derive(FromRef)] +pub struct Service { pub config: Config, - pub provider: P, + pub provider: Provider, + #[from_ref(skip)] pub db: DatabaseConnection, } -impl

ServiceState

-where - P: Provider, -{ - pub fn new(cfg: Config, db: DatabaseConnection, provider: P) -> Result { +pub type ServiceState = Arc; + +impl Service { + pub fn new( + cfg: Config, + db: DatabaseConnection, + provider: Provider, + ) -> Result { Ok(Self { config: cfg.clone(), provider, diff --git a/src/lib.rs b/src/lib.rs index a9735f7c..af4629c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,4 +24,4 @@ pub mod resource; pub mod token; #[cfg(test)] -mod tests {} +mod tests; diff --git a/src/provider.rs b/src/provider.rs index e7ecacd7..ce49a6d5 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -11,28 +11,35 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 +use derive_builder::Builder; +use mockall_double::double; use crate::config::Config; use crate::error::KeystoneError; -use crate::identity::{IdentityApi, IdentityProvider}; +use crate::identity::IdentityApi; +#[double] +use crate::identity::IdentityProvider; use crate::plugin_manager::PluginManager; -use crate::token::{TokenApi, TokenProvider}; -#[cfg(test)] -use crate::{identity::FakeIdentityProvider, token::FakeTokenProvider}; - -pub trait Provider: Clone + Send + Sync { - fn get_identity_provider(&self) -> &impl IdentityApi; - fn get_token_provider(&self) -> &impl TokenApi; -} - -#[derive(Clone)] -pub struct ProviderApi { +use crate::token::TokenApi; +#[double] +use crate::token::TokenProvider; + +//pub trait Provider: Clone + Send + Sync { +// fn get_identity_provider(&self) -> &impl IdentityApi; +// fn get_token_provider(&self) -> &impl TokenApi; +//} + +#[derive(Builder, Clone)] +// It is necessary to use the owned pattern since otherwise builder invokes clone which immediately +// confuses mockall used in tests +#[builder(pattern = "owned")] +pub struct Provider { pub config: Config, identity: IdentityProvider, token: TokenProvider, } -impl ProviderApi { +impl Provider { pub fn new(cfg: Config, plugin_manager: PluginManager) -> Result { let identity_provider = IdentityProvider::new(&cfg, &plugin_manager)?; let token_provider = TokenProvider::new(&cfg)?; @@ -43,47 +50,12 @@ impl ProviderApi { token: token_provider, }) } -} - -impl Provider for ProviderApi { - fn get_identity_provider(&self) -> &impl IdentityApi { - &self.identity - } - - fn get_token_provider(&self) -> &impl TokenApi { - &self.token - } -} - -#[cfg(test)] -#[derive(Clone)] -pub struct FakeProviderApi { - pub config: Config, - identity: FakeIdentityProvider, - token: FakeTokenProvider, -} - -#[cfg(test)] -impl FakeProviderApi { - pub fn new(cfg: Config) -> Result { - let identity_provider = FakeIdentityProvider::default(); - let token_provider = FakeTokenProvider::default(); - - Ok(Self { - config: cfg, - identity: identity_provider, - token: token_provider, - }) - } -} -#[cfg(test)] -impl Provider for FakeProviderApi { - fn get_identity_provider(&self) -> &impl IdentityApi { + pub fn get_identity_provider(&self) -> &impl IdentityApi { &self.identity } - fn get_token_provider(&self) -> &impl TokenApi { + pub fn get_token_provider(&self) -> &impl TokenApi { &self.token } } diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 00000000..bafcca80 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,15 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod api; diff --git a/src/tests/api.rs b/src/tests/api.rs new file mode 100644 index 00000000..b2db2c3a --- /dev/null +++ b/src/tests/api.rs @@ -0,0 +1,59 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +use sea_orm::DatabaseConnection; +use std::sync::Arc; + +use crate::config::Config; +use crate::identity::MockIdentityProvider; +use crate::keystone::{Service, ServiceState}; +use crate::provider::ProviderBuilder; +use crate::token::{MockTokenProvider, Token, TokenProviderError}; + +pub(crate) fn get_mocked_state_unauthed() -> ServiceState { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let identity_mock = MockIdentityProvider::default(); + let mut token_mock = MockTokenProvider::default(); + token_mock + .expect_validate_token() + .returning(|_, _| Err(TokenProviderError::InvalidToken)); + + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .token(token_mock) + .build() + .unwrap(); + + Arc::new(Service::new(config, db, provider).unwrap()) +} + +pub(crate) fn get_mocked_state(identity_mock: MockIdentityProvider) -> ServiceState { + let db = DatabaseConnection::Disconnected; + let config = Config::default(); + let mut token_mock = MockTokenProvider::default(); + token_mock + .expect_validate_token() + .returning(|_, _| Ok(Token::default())); + + let provider = ProviderBuilder::default() + .config(config.clone()) + .identity(identity_mock) + .token(token_mock) + .build() + .unwrap(); + + Arc::new(Service::new(config, db, provider).unwrap()) +} diff --git a/src/token/fernet.rs b/src/token/fernet.rs index f4540bf6..2b8dd2ba 100644 --- a/src/token/fernet.rs +++ b/src/token/fernet.rs @@ -12,10 +12,10 @@ // // SPDX-License-Identifier: Apache-2.0 -use base64::{engine::general_purpose::URL_SAFE, Engine as _}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE}; use chrono::{DateTime, Utc}; use fernet::{Fernet, MultiFernet}; -use rmp::{decode::*, Marker}; +use rmp::{Marker, decode::*}; use std::collections::BTreeMap; use std::fmt; use std::io; @@ -24,9 +24,9 @@ use uuid::Uuid; use crate::config::Config; use crate::token::{ + TokenProviderError, fernet_utils::FernetUtils, types::{Token, TokenBackend}, - TokenProviderError, }; #[derive(Default, Clone)] @@ -118,7 +118,7 @@ fn read_time(rd: &mut &[u8]) -> Result, TokenProviderError> { fn decode_auth_methods( value: usize, auth_map: &BTreeMap, -) -> Result, TokenProviderError> { +) -> Result + use<>, TokenProviderError> { let mut results: Vec = Vec::new(); let mut auth: usize = value; for (idx, name) in auth_map.iter() { @@ -141,7 +141,9 @@ fn decode_auth_methods( } /// Decode array of audit ids from the payload -fn read_audit_ids(rd: &mut &[u8]) -> Result, TokenProviderError> { +fn read_audit_ids( + rd: &mut &[u8], +) -> Result + use<>, TokenProviderError> { if let Marker::FixArray(len) = read_marker(rd).map_err(ValueReadError::from)? { let mut result: Vec = Vec::new(); for _ in 0..len { @@ -198,10 +200,9 @@ impl FernetTokenProvider { pub fn decrypt(&self, credential: String) -> Result { // TODO: Implement fernet keys change watching. Keystone loads them from FS on every // request and in the best case it costs 15µs. - let payload = if let Some(fernet) = &self.fernet { - fernet.decrypt(credential.as_ref())? - } else { - self.get_fernet()?.decrypt(credential.as_ref())? + let payload = match &self.fernet { + Some(fernet) => fernet.decrypt(credential.as_ref())?, + _ => self.get_fernet()?.decrypt(credential.as_ref())?, }; self.parse(&mut payload.as_slice()) } diff --git a/src/token/fernet_utils.rs b/src/token/fernet_utils.rs index 2423337b..19e249f6 100644 --- a/src/token/fernet_utils.rs +++ b/src/token/fernet_utils.rs @@ -30,7 +30,9 @@ impl FernetUtils { Ok(self.key_repository.exists()) } - pub fn load_keys(&self) -> Result, TokenProviderError> { + pub fn load_keys( + &self, + ) -> Result + use<>, TokenProviderError> { let mut keys: BTreeMap = BTreeMap::new(); if self.validate_key_repository()? { for entry in fs::read_dir(&self.key_repository)? { @@ -49,7 +51,7 @@ impl FernetUtils { } pub async fn load_keys_async( &self, - ) -> Result, TokenProviderError> { + ) -> Result + use<>, TokenProviderError> { let mut keys: BTreeMap = BTreeMap::new(); if self.validate_key_repository()? { let mut entries = fs_async::read_dir(&self.key_repository).await?; diff --git a/src/token/mod.rs b/src/token/mod.rs index a8b81579..db2602fb 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -14,6 +14,8 @@ use async_trait::async_trait; use chrono::{Local, TimeDelta}; +#[cfg(test)] +use mockall::mock; mod error; pub mod fernet; @@ -75,22 +77,22 @@ impl TokenApi for TokenProvider { } #[cfg(test)] -#[derive(Clone, Debug, Default)] -pub(crate) struct FakeTokenProvider {} +mock! { + pub TokenProvider { + pub fn new(cfg: &Config) -> Result; + } -#[cfg(test)] -#[async_trait] -impl TokenApi for FakeTokenProvider { - /// Validate token - #[tracing::instrument(level = "info", skip(self))] - async fn validate_token( - &self, - _credential: String, - _window_seconds: Option, - ) -> Result { - Ok(Token { - user_id: String::new(), - ..Default::default() - }) + #[async_trait] + impl TokenApi for TokenProvider { + async fn validate_token( + &self, + credential: String, + window_seconds: Option, + ) -> Result; } + + impl Clone for TokenProvider { + fn clone(&self) -> Self; + } + }