diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml
index 8fffebe0..71ed1e5c 100644
--- a/.github/workflows/linters.yml
+++ b/.github/workflows/linters.yml
@@ -13,8 +13,6 @@ on:
- 'Cargo.toml'
- 'Cargo.lock'
- '.github/workflows/linters.yml'
- - 'keystone/**'
- - 'fuzz/**'
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
@@ -22,7 +20,7 @@ concurrency:
env:
CARGO_TERM_COLOR: always
- rust_min: 1.76.0
+ rust_min: 1.85.0
jobs:
rustfmt:
diff --git a/Cargo.lock b/Cargo.lock
index 5d0723d7..ace92415 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -496,9 +496,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
-version = "1.2.14"
+version = "1.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9"
+checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af"
dependencies = [
"jobserver",
"libc",
@@ -968,6 +968,12 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
+[[package]]
+name = "downcast"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
+
[[package]]
name = "dyn-clone"
version = "1.0.18"
@@ -1116,6 +1122,12 @@ dependencies = [
"percent-encoding",
]
+[[package]]
+name = "fragile"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
+
[[package]]
name = "funty"
version = "2.0.0"
@@ -1635,9 +1647,9 @@ dependencies = [
[[package]]
name = "inout"
-version = "0.1.3"
+version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
+checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"generic-array",
]
@@ -1765,9 +1777,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e"
[[package]]
name = "log"
-version = "0.4.25"
+version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
+checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]]
name = "matchit"
@@ -1836,6 +1848,44 @@ dependencies = [
"windows-sys 0.52.0",
]
+[[package]]
+name = "mockall"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
+dependencies = [
+ "cfg-if",
+ "downcast",
+ "fragile",
+ "mockall_derive",
+ "predicates",
+ "predicates-tree",
+]
+
+[[package]]
+name = "mockall_derive"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
+dependencies = [
+ "cfg-if",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
+]
+
+[[package]]
+name = "mockall_double"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1ca96e5ac35256ae3e13536edd39b172b88f41615e1d7b653c8ad24524113e8"
+dependencies = [
+ "cfg-if",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.98",
+]
+
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@@ -1987,6 +2037,8 @@ dependencies = [
"eyre",
"fernet",
"http-body-util",
+ "mockall",
+ "mockall_double",
"regex",
"rmp",
"sea-orm",
@@ -2244,6 +2296,32 @@ dependencies = [
"zerocopy",
]
+[[package]]
+name = "predicates"
+version = "3.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
+dependencies = [
+ "anstyle",
+ "predicates-core",
+]
+
+[[package]]
+name = "predicates-core"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
+
+[[package]]
+name = "predicates-tree"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
+dependencies = [
+ "predicates-core",
+ "termtree",
+]
+
[[package]]
name = "proc-macro-crate"
version = "3.2.0"
@@ -3194,6 +3272,12 @@ dependencies = [
"windows-sys 0.59.0",
]
+[[package]]
+name = "termtree"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
+
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -3672,9 +3756,9 @@ checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d"
[[package]]
name = "uuid"
-version = "1.13.2"
+version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6"
+checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1"
dependencies = [
"getrandom 0.3.1",
"serde",
@@ -4192,27 +4276,27 @@ dependencies = [
[[package]]
name = "zstd"
-version = "0.13.2"
+version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
+checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
-version = "7.2.1"
+version = "7.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
+checksum = "f3051792fbdc2e1e143244dc28c60f73d8470e93f3f9cbd0ead44da5ed802722"
dependencies = [
"zstd-sys",
]
[[package]]
name = "zstd-sys"
-version = "2.0.13+zstd.1.5.6"
+version = "2.0.14+zstd.1.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
+checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5"
dependencies = [
"cc",
"pkg-config",
diff --git a/Cargo.toml b/Cargo.toml
index c9538b49..6c607dea 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,10 +1,10 @@
[package]
name = "openstack_keystone"
version = "0.1.0"
-edition = "2021"
+edition = "2024"
license = "Apache-2.0"
authors = ["Artem Goncharov (gtema)"]
-rust-version = "1.83" # MSRV
+rust-version = "1.85" # MSRV
repository = "https://github.com/gtema/keystone"
[[bin]]
@@ -29,6 +29,7 @@ derive_builder = { version = "^0.20" }
dyn-clone = { version = "^1.0" }
eyre = { version = "^0.6" }
fernet = { version = "^0.2" }
+mockall_double = { version = "^0.3" }
regex = { version = "^1.11"}
rmp = { version = "^0.8" }
sea-orm = { version = "^1.1", features = ["sqlx-mysql", "sqlx-postgres", "runtime-tokio"] }
@@ -44,11 +45,12 @@ tracing-subscriber = { version = "^0.3" }
utoipa = { version = "^5.3", features = ["axum_extras", "chrono"] }
utoipa-axum = { version = "^0.2" }
utoipa-swagger-ui = { version = "^9.0", features = ["axum", "vendored"], default-features = false }
-uuid = { version = "^1.13", features = ["v4"] }
+uuid = { version = "^1.14", features = ["v4"] }
[dev-dependencies]
criterion = { version = "^0.5", features = ["async_tokio"] }
http-body-util = "^0.1"
+mockall = { version = "^0.13" }
sea-orm = { version = "*", features = ["mock"]}
tempfile = { version = "^3.17" }
diff --git a/benches/fernet_token.rs b/benches/fernet_token.rs
index 0a75e07d..795bd129 100644
--- a/benches/fernet_token.rs
+++ b/benches/fernet_token.rs
@@ -1,4 +1,4 @@
-use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
diff --git a/loadtest/Cargo.toml b/loadtest/Cargo.toml
index 0a0c5cc3..c3326ab5 100644
--- a/loadtest/Cargo.toml
+++ b/loadtest/Cargo.toml
@@ -1,8 +1,9 @@
[package]
name = "load_test"
version = "0.1.0"
+rust-version = "1.85" # MSRV
edition = "2024"
[dependencies]
-goose = "0.17.2"
-tokio = "1.43.0"
+goose = { version = "^0.17" }
+tokio = { version = "^1.43" }
diff --git a/src/api/auth.rs b/src/api/auth.rs
index d71b985a..95f7d471 100644
--- a/src/api/auth.rs
+++ b/src/api/auth.rs
@@ -13,51 +13,45 @@
// SPDX-License-Identifier: Apache-2.0
use axum::{
- extract::{Request, State},
- http::StatusCode,
- middleware::Next,
- response::Response,
+ extract::{FromRef, FromRequestParts},
+ http::{StatusCode, request::Parts},
};
use std::sync::Arc;
use crate::keystone::ServiceState;
-use crate::provider::Provider;
use crate::token::{Token, TokenApi};
#[derive(Debug, Clone)]
-pub struct Auth {
- pub token: Token,
-}
+pub struct Auth(pub Token);
-pub async fn auth
(
- State(state): State>>,
- mut req: Request,
- next: Next,
-) -> Result
+impl FromRequestParts for Auth
where
- P: Provider,
+ ServiceState: FromRef,
+ S: Send + Sync,
{
- let auth_header = req
- .headers()
- .get("X-Auth-Token")
- .and_then(|header| header.to_str().ok());
-
- let auth_header = if let Some(auth_header) = auth_header {
- auth_header
- } else {
- return Err(StatusCode::UNAUTHORIZED);
- };
-
- // insert the current user into a request extension so the handler can
- // extract it
- state
- .provider
- .get_token_provider()
- .validate_token(auth_header.to_string(), None)
- .await
- .map(|token| {
- req.extensions_mut().insert(Auth { token });
- })
- .map_err(|_| StatusCode::UNAUTHORIZED)?;
- Ok(next.run(req).await)
+ type Rejection = (StatusCode, &'static str);
+
+ async fn from_request_parts(parts: &mut Parts, state: &S) -> Result {
+ let auth_header = parts
+ .headers
+ .get("X-Auth-Token")
+ .and_then(|header| header.to_str().ok());
+
+ let auth_header = if let Some(auth_header) = auth_header {
+ auth_header
+ } else {
+ return Err((StatusCode::UNAUTHORIZED, "not authorized"));
+ };
+
+ let state = Arc::from_ref(state);
+
+ Ok(Self(
+ state
+ .provider
+ .get_token_provider()
+ .validate_token(auth_header.to_string(), None)
+ .await
+ .map_err(|_| (StatusCode::UNAUTHORIZED, "not authorized"))?,
+ ))
+ }
}
diff --git a/src/api/error.rs b/src/api/error.rs
index 454f435a..cddd073c 100644
--- a/src/api/error.rs
+++ b/src/api/error.rs
@@ -13,9 +13,9 @@
// SPDX-License-Identifier: Apache-2.0
use axum::{
+ Json,
http::StatusCode,
response::{IntoResponse, Response},
- Json,
};
use serde_json::json;
use thiserror::Error;
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 00f470cd..d0d0f90e 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -12,12 +12,10 @@
//
// SPDX-License-Identifier: Apache-2.0
-use std::sync::Arc;
use utoipa::OpenApi;
use utoipa_axum::router::OpenApiRouter;
use crate::keystone::ServiceState;
-use crate::provider::Provider;
pub mod auth;
pub mod error;
@@ -27,9 +25,6 @@ pub mod v3;
#[openapi(info(version = "3.14.0"))]
pub struct ApiDoc;
-pub fn openapi_router() -> OpenApiRouter>>
-where
- P: Provider + 'static,
-{
+pub fn openapi_router() -> OpenApiRouter {
OpenApiRouter::new().nest("/v3", v3::openapi_router())
}
diff --git a/src/api/v3/group/mod.rs b/src/api/v3/group/mod.rs
index befcbd83..132a0183 100644
--- a/src/api/v3/group/mod.rs
+++ b/src/api/v3/group/mod.rs
@@ -13,26 +13,22 @@
// SPDX-License-Identifier: Apache-2.0
use axum::{
+ Json, debug_handler,
extract::{Path, Query, State},
http::StatusCode,
response::IntoResponse,
- Json,
};
-use std::sync::Arc;
use utoipa_axum::{router::OpenApiRouter, routes};
+use crate::api::auth::Auth;
use crate::api::error::KeystoneApiError;
use crate::identity::IdentityApi;
use crate::keystone::ServiceState;
-use crate::provider::Provider;
use types::{Group, GroupCreateRequest, GroupList, GroupListParameters, GroupResponse};
mod types;
-pub(super) fn openapi_router() -> OpenApiRouter>>
-where
- P: Provider + 'static,
-{
+pub(super) fn openapi_router() -> OpenApiRouter {
OpenApiRouter::new()
.routes(routes!(list, create))
.routes(routes!(show, remove))
@@ -51,13 +47,11 @@ where
tag="groups"
)]
#[tracing::instrument(name = "api::group_list", level = "debug", skip(state))]
-async fn list(
+async fn list(
+ Auth(user_auth): Auth,
Query(query): Query,
- State(state): State>>,
-) -> Result
-where
- P: Provider,
-{
+ State(state): State,
+) -> Result {
let groups: Vec = state
.provider
.get_identity_provider()
@@ -83,17 +77,15 @@ where
tag="groups"
)]
#[tracing::instrument(name = "api::group_get", level = "debug", skip(state))]
-async fn show(
+async fn show(
+ Auth(user_auth): Auth,
Path(group_id): Path,
- State(state): State>>,
-) -> Result
-where
- P: Provider,
-{
+ State(state): State,
+) -> Result {
state
.provider
.get_identity_provider()
- .get_group(&state.db, &group_id)
+ .get_group(&state.db, group_id.clone())
.await
.map(|x| {
x.ok_or_else(|| KeystoneApiError::NotFound {
@@ -114,13 +106,12 @@ where
tag="groups"
)]
#[tracing::instrument(name = "api::create_group", level = "debug", skip(state))]
-async fn create(
- State(state): State>>,
+#[debug_handler]
+async fn create(
+ Auth(user_auth): Auth,
+ State(state): State,
Json(req): Json,
-) -> Result
-where
- P: Provider,
-{
+) -> Result {
let res = state
.provider
.get_identity_provider()
@@ -143,17 +134,15 @@ where
tag="groups"
)]
#[tracing::instrument(name = "api::group_delete", level = "debug", skip(state))]
-async fn remove(
+async fn remove(
+ Auth(user_auth): Auth,
Path(group_id): Path,
- State(state): State>>,
-) -> Result
-where
- P: Provider,
-{
+ State(state): State,
+) -> Result {
state
.provider
.get_identity_provider()
- .delete_group(&state.db, &group_id)
+ .delete_group(&state.db, group_id)
.await
.map_err(KeystoneApiError::identity)?;
Ok((StatusCode::NO_CONTENT).into_response())
@@ -163,68 +152,164 @@ where
mod tests {
use axum::{
body::Body,
- http::{self, Request, StatusCode},
+ http::{Request, StatusCode, header},
};
use http_body_util::BodyExt; // for `collect`
use sea_orm::DatabaseConnection;
- use std::sync::Arc;
+ use serde_json::json;
+
use tower::ServiceExt; // for `call`, `oneshot`, and `ready`
use tower_http::trace::TraceLayer;
use super::openapi_router;
- use crate::api::v3::group::types::*;
- use crate::config::Config;
- use crate::identity::IdentityApi;
- use crate::keystone::ServiceState;
- use crate::provider::{FakeProviderApi, Provider};
+ use crate::api::v3::group::types::{
+ Group as ApiGroup, GroupCreate as ApiGroupCreate, GroupCreateRequest, GroupList,
+ GroupResponse,
+ };
+ use crate::identity::{
+ MockIdentityProvider,
+ error::IdentityProviderError,
+ types::{Group, GroupCreate, GroupListParameters},
+ };
+
+ use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed};
#[tokio::test]
async fn test_list() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
- let mut api = openapi_router().with_state(state);
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_list_groups()
+ .withf(|_: &DatabaseConnection, _: &GroupListParameters| true)
+ .returning(|_, _| {
+ Ok(vec![Group {
+ id: "1".into(),
+ name: "2".into(),
+ ..Default::default()
+ }])
+ });
+
+ let state = get_mocked_state(identity_mock);
+
+ let mut api = openapi_router()
+ .layer(TraceLayer::new_for_http())
+ .with_state(state);
let response = api
.as_service()
- .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
+ .oneshot(
+ Request::builder()
+ .uri("/")
+ .header("x-auth-token", "foo")
+ .body(Body::empty())
+ .unwrap(),
+ )
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
- let _res: GroupList = serde_json::from_slice(&body).unwrap();
+ let res: GroupList = serde_json::from_slice(&body).unwrap();
+ assert_eq!(
+ vec![ApiGroup {
+ id: "1".into(),
+ name: "2".into(),
+ // for some reason when deserializing missing value appears still as an empty
+ // object
+ extra: Some(json!({})),
+ ..Default::default()
+ }],
+ res.groups
+ );
}
#[tokio::test]
- async fn test_get() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ async fn test_list_qp() {
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_list_groups()
+ .withf(|_: &DatabaseConnection, qp: &GroupListParameters| {
+ GroupListParameters {
+ domain_id: Some("domain".into()),
+ name: Some("name".into()),
+ } == *qp
+ })
+ .returning(|_, _| Ok(Vec::new()));
+
+ let state = get_mocked_state(identity_mock);
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
- .with_state(state.clone());
+ .with_state(state);
- let group = crate::identity::types::GroupCreate {
- domain_id: "domain".into(),
- name: "name".into(),
- ..Default::default()
- };
+ let response = api
+ .as_service()
+ .oneshot(
+ Request::builder()
+ .uri("/?domain_id=domain&name=name")
+ .header("x-auth-token", "foo")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::OK);
+
+ let body = response.into_body().collect().await.unwrap().to_bytes();
+ let _res: GroupList = serde_json::from_slice(&body).unwrap();
+ }
- let created_group = state
- .provider
- .get_identity_provider()
- .create_group(&DatabaseConnection::Disconnected, group)
+ #[tokio::test]
+ async fn test_list_unauth() {
+ let state = get_mocked_state_unauthed();
+
+ let mut api = openapi_router()
+ .layer(TraceLayer::new_for_http())
+ .with_state(state);
+
+ let response = api
+ .as_service()
+ .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
+ assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+ }
+
+ #[tokio::test]
+ async fn test_get() {
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_get_group()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "foo")
+ .returning(|_, _| Ok(None));
+
+ identity_mock
+ .expect_get_group()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "bar")
+ .returning(|_, _| {
+ Ok(Some(Group {
+ id: "bar".into(),
+ ..Default::default()
+ }))
+ });
+
+ let state = get_mocked_state(identity_mock);
+
+ let mut api = openapi_router()
+ .layer(TraceLayer::new_for_http())
+ .with_state(state.clone());
+
let response = api
.as_service()
- .oneshot(Request::builder().uri("/foo").body(Body::empty()).unwrap())
+ .oneshot(
+ Request::builder()
+ .uri("/foo")
+ .header("x-auth-token", "foo")
+ .body(Body::empty())
+ .unwrap(),
+ )
.await
.unwrap();
@@ -234,7 +319,8 @@ mod tests {
.as_service()
.oneshot(
Request::builder()
- .uri(format!("/{}", created_group.id))
+ .uri("/bar")
+ .header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
@@ -244,22 +330,42 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
- let _user: GroupResponse = serde_json::from_slice(&body).unwrap();
+ let res: GroupResponse = serde_json::from_slice(&body).unwrap();
+ assert_eq!(
+ ApiGroup {
+ id: "bar".into(),
+ extra: Some(json!({})),
+ ..Default::default()
+ },
+ res.group,
+ );
}
#[tokio::test]
async fn test_create() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_create_group()
+ .withf(|_: &DatabaseConnection, req: &GroupCreate| {
+ req.domain_id == "domain" && req.name == "name"
+ })
+ .returning(|_, req| {
+ Ok(Group {
+ id: "bar".into(),
+ domain_id: req.domain_id,
+ name: req.name,
+ ..Default::default()
+ })
+ });
+
+ let state = get_mocked_state(identity_mock);
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
let req = GroupCreateRequest {
- group: GroupCreate {
+ group: ApiGroupCreate {
domain_id: "domain".into(),
name: "name".into(),
..Default::default()
@@ -271,8 +377,9 @@ mod tests {
.oneshot(
Request::builder()
.method("POST")
- .header(http::header::CONTENT_TYPE, "application/json")
+ .header(header::CONTENT_TYPE, "application/json")
.uri("/")
+ .header("x-auth-token", "foo")
.body(Body::from(serde_json::to_string(&req).unwrap()))
.unwrap(),
)
@@ -284,38 +391,35 @@ mod tests {
let body = response.into_body().collect().await.unwrap().to_bytes();
let res: GroupResponse = serde_json::from_slice(&body).unwrap();
assert_eq!(res.group.name, req.group.name);
+ assert_eq!(res.group.domain_id, req.group.domain_id);
}
#[tokio::test]
async fn test_delete() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_delete_group()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "foo")
+ .returning(|_, _| Err(IdentityProviderError::GroupNotFound("foo".into())));
+
+ identity_mock
+ .expect_delete_group()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "bar")
+ .returning(|_, _| Ok(()));
+
+ let state = get_mocked_state(identity_mock);
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
- let group = crate::identity::types::GroupCreate {
- domain_id: "domain".into(),
- name: "name".into(),
- ..Default::default()
- };
-
- let created_group = state
- .provider
- .get_identity_provider()
- .create_group(&DatabaseConnection::Disconnected, group)
- .await
- .unwrap();
-
let response = api
.as_service()
.oneshot(
Request::builder()
.method("DELETE")
.uri("/foo")
+ .header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
@@ -329,7 +433,8 @@ mod tests {
.oneshot(
Request::builder()
.method("DELETE")
- .uri(format!("/{}", created_group.id))
+ .uri("/bar")
+ .header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
diff --git a/src/api/v3/group/types.rs b/src/api/v3/group/types.rs
index 3faa8c7a..391ddb1a 100644
--- a/src/api/v3/group/types.rs
+++ b/src/api/v3/group/types.rs
@@ -13,9 +13,9 @@
// SPDX-License-Identifier: Apache-2.0
use axum::{
+ Json,
http::StatusCode,
response::{IntoResponse, Response},
- Json,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -23,7 +23,7 @@ use utoipa::{IntoParams, ToSchema};
use crate::identity::types;
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct Group {
/// Group ID
pub id: String,
@@ -33,17 +33,17 @@ pub struct Group {
pub name: String,
/// Group description
pub description: Option,
- #[serde(flatten)]
+ #[serde(flatten, skip_serializing_if = "Option::is_none")]
pub extra: Option,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct GroupResponse {
/// group object
pub group: Group,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct GroupCreate {
/// Group domain ID
pub domain_id: String,
@@ -51,11 +51,11 @@ pub struct GroupCreate {
pub name: String,
/// Group description
pub description: Option,
- #[serde(flatten)]
+ #[serde(default, flatten, skip_serializing_if = "Option::is_none")]
pub extra: Option,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct GroupCreateRequest {
/// Group object
pub group: GroupCreate,
@@ -77,7 +77,7 @@ impl From for types::GroupCreate {
fn from(value: GroupCreateRequest) -> Self {
let group = value.group;
Self {
- id: String::new(),
+ id: None,
name: group.name,
domain_id: group.domain_id,
extra: group.extra,
@@ -105,7 +105,7 @@ impl IntoResponse for types::Group {
}
/// Groups
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct GroupList {
/// Collection of group objects
pub groups: Vec,
diff --git a/src/api/v3/mod.rs b/src/api/v3/mod.rs
index 31777f02..5c4afc38 100644
--- a/src/api/v3/mod.rs
+++ b/src/api/v3/mod.rs
@@ -12,19 +12,14 @@
//
// SPDX-License-Identifier: Apache-2.0
-use std::sync::Arc;
use utoipa_axum::router::OpenApiRouter;
use crate::keystone::ServiceState;
-use crate::provider::Provider;
pub mod group;
pub mod user;
-pub(super) fn openapi_router() -> OpenApiRouter>>
-where
- P: Provider + 'static,
-{
+pub(super) fn openapi_router() -> OpenApiRouter {
OpenApiRouter::new()
.nest("/users", user::openapi_router())
.nest("/groups", group::openapi_router())
diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs
index dedfdca6..4e9423b3 100644
--- a/src/api/v3/user/mod.rs
+++ b/src/api/v3/user/mod.rs
@@ -13,27 +13,22 @@
// SPDX-License-Identifier: Apache-2.0
use axum::{
- extract::{Extension, Path, Query, State},
+ Json,
+ extract::{Path, Query, State},
http::StatusCode,
response::IntoResponse,
- Json,
};
-use std::sync::Arc;
use utoipa_axum::{router::OpenApiRouter, routes};
use crate::api::auth::Auth;
use crate::api::error::KeystoneApiError;
use crate::identity::IdentityApi;
use crate::keystone::ServiceState;
-use crate::provider::Provider;
use types::{User, UserCreateRequest, UserList, UserListParameters, UserResponse};
mod types;
-pub(super) fn openapi_router() -> OpenApiRouter>>
-where
- P: Provider + 'static,
-{
+pub(super) fn openapi_router() -> OpenApiRouter {
OpenApiRouter::new()
.routes(routes!(list, create))
.routes(routes!(show, remove))
@@ -52,14 +47,11 @@ where
tag="users"
)]
#[tracing::instrument(name = "api::user_list", level = "debug", skip(state))]
-async fn list(
- Extension(current_user): Extension,
+async fn list(
+ Auth(user_auth): Auth,
Query(query): Query,
- State(state): State>>,
-) -> Result
-where
- P: Provider,
-{
+ State(state): State,
+) -> Result {
let users: Vec = state
.provider
.get_identity_provider()
@@ -84,17 +76,15 @@ where
tag="users"
)]
#[tracing::instrument(name = "api::user_get", level = "debug", skip(state))]
-async fn show(
+async fn show(
+ Auth(user_auth): Auth,
Path(user_id): Path,
- State(state): State>>,
-) -> Result
-where
- P: Provider,
-{
+ State(state): State,
+) -> Result {
state
.provider
.get_identity_provider()
- .get_user(&state.db, &user_id)
+ .get_user(&state.db, user_id.clone())
.await
.map(|x| {
x.ok_or_else(|| KeystoneApiError::NotFound {
@@ -115,14 +105,12 @@ where
tag="users"
)]
#[tracing::instrument(name = "api::create_user", level = "debug", skip(state))]
-async fn create(
+async fn create(
+ Auth(user_auth): Auth,
Query(query): Query,
- State(state): State>>,
+ State(state): State,
Json(req): Json,
-) -> Result
-where
- P: Provider,
-{
+) -> Result {
let user = state
.provider
.get_identity_provider()
@@ -145,17 +133,15 @@ where
tag="users"
)]
#[tracing::instrument(name = "api::user_delete", level = "debug", skip(state))]
-async fn remove(
+async fn remove(
+ Auth(user_auth): Auth,
Path(user_id): Path,
- State(state): State>>,
-) -> Result
-where
- P: Provider,
-{
+ State(state): State,
+) -> Result {
state
.provider
.get_identity_provider()
- .delete_user(&state.db, &user_id)
+ .delete_user(&state.db, user_id)
.await
.map_err(KeystoneApiError::identity)?;
Ok((StatusCode::NO_CONTENT).into_response())
@@ -166,31 +152,44 @@ mod tests {
use axum::{
body::Body,
http::{self, Request, StatusCode},
- middleware,
};
use http_body_util::BodyExt; // for `collect`
use sea_orm::DatabaseConnection;
- use std::sync::Arc;
+ use serde_json::json;
+
use tower::ServiceExt; // for `call`, `oneshot`, and `ready`
use tower_http::trace::TraceLayer;
use super::openapi_router;
- use crate::api::auth::auth;
- use crate::api::v3::user::types::*;
- use crate::config::Config;
- use crate::identity::IdentityApi;
- use crate::keystone::ServiceState;
- use crate::provider::{FakeProviderApi, Provider};
+ use crate::api::v3::user::types::{
+ User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, UserResponse,
+ };
+ use crate::identity::{
+ MockIdentityProvider,
+ error::IdentityProviderError,
+ types::{User, UserCreate, UserListParameters},
+ };
+
+ use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed};
#[tokio::test]
async fn test_list() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_list_users()
+ .withf(|_: &DatabaseConnection, _: &UserListParameters| true)
+ .returning(|_, _| {
+ Ok(vec![User {
+ id: "1".into(),
+ name: "2".into(),
+ ..Default::default()
+ }])
+ });
+
+ let state = get_mocked_state(identity_mock);
+
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
- .route_layer(middleware::from_fn_with_state(state.clone(), auth))
.with_state(state);
let response = api
@@ -208,22 +207,98 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
- let _users: UserList = serde_json::from_slice(&body).unwrap();
+ let res: UserList = serde_json::from_slice(&body).unwrap();
+ assert_eq!(
+ vec![ApiUser {
+ id: "1".into(),
+ name: "2".into(),
+ // object
+ extra: Some(json!({})),
+ ..Default::default()
+ }],
+ res.users
+ );
+ }
+
+ #[tokio::test]
+ async fn test_list_qp() {
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_list_users()
+ .withf(|_: &DatabaseConnection, qp: &UserListParameters| {
+ UserListParameters {
+ domain_id: Some("domain".into()),
+ name: Some("name".into()),
+ } == *qp
+ })
+ .returning(|_, _| Ok(Vec::new()));
+
+ let state = get_mocked_state(identity_mock);
+
+ let mut api = openapi_router()
+ .layer(TraceLayer::new_for_http())
+ .with_state(state);
+
+ let response = api
+ .as_service()
+ .oneshot(
+ Request::builder()
+ .uri("/?domain_id=domain&name=name")
+ .header("x-auth-token", "foo")
+ .body(Body::empty())
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::OK);
+
+ let body = response.into_body().collect().await.unwrap().to_bytes();
+ let _res: UserList = serde_json::from_slice(&body).unwrap();
+ }
+
+ #[tokio::test]
+ async fn test_list_unauth() {
+ let state = get_mocked_state_unauthed();
+
+ let mut api = openapi_router()
+ .layer(TraceLayer::new_for_http())
+ .with_state(state);
+
+ let response = api
+ .as_service()
+ .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
+ .await
+ .unwrap();
+
+ assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_create() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_create_user()
+ .withf(|_: &DatabaseConnection, req: &UserCreate| {
+ req.domain_id == "domain" && req.name == "name"
+ })
+ .returning(|_, req| {
+ Ok(User {
+ id: "bar".into(),
+ domain_id: req.domain_id,
+ name: req.name,
+ ..Default::default()
+ })
+ });
+
+ let state = get_mocked_state(identity_mock);
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
let user = UserCreateRequest {
- user: UserCreate {
+ user: ApiUserCreate {
domain_id: "domain".into(),
name: "name".into(),
..Default::default()
@@ -237,6 +312,7 @@ mod tests {
.method("POST")
.header(http::header::CONTENT_TYPE, "application/json")
.uri("/")
+ .header("x-auth-token", "foo")
.body(Body::from(serde_json::to_string(&user).unwrap()))
.unwrap(),
)
@@ -252,31 +328,37 @@ mod tests {
#[tokio::test]
async fn test_get() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_get_user()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "foo")
+ .returning(|_, _| Ok(None));
+
+ identity_mock
+ .expect_get_user()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "bar")
+ .returning(|_, _| {
+ Ok(Some(User {
+ id: "bar".into(),
+ ..Default::default()
+ }))
+ });
+
+ let state = get_mocked_state(identity_mock);
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
- let user = crate::identity::types::UserCreate {
- domain_id: "domain".into(),
- name: "name".into(),
- ..Default::default()
- };
-
- let created_user = state
- .provider
- .get_identity_provider()
- .create_user(&DatabaseConnection::Disconnected, user)
- .await
- .unwrap();
-
let response = api
.as_service()
- .oneshot(Request::builder().uri("/foo").body(Body::empty()).unwrap())
+ .oneshot(
+ Request::builder()
+ .uri("/foo")
+ .header("x-auth-token", "foo")
+ .body(Body::empty())
+ .unwrap(),
+ )
.await
.unwrap();
@@ -286,7 +368,8 @@ mod tests {
.as_service()
.oneshot(
Request::builder()
- .uri(format!("/{}", created_user.id))
+ .uri("/bar")
+ .header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
@@ -296,39 +379,43 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
let body = response.into_body().collect().await.unwrap().to_bytes();
- let _user: UserResponse = serde_json::from_slice(&body).unwrap();
+ let res: UserResponse = serde_json::from_slice(&body).unwrap();
+ assert_eq!(
+ ApiUser {
+ id: "bar".into(),
+ extra: Some(json!({})),
+ ..Default::default()
+ },
+ res.user,
+ );
}
#[tokio::test]
async fn test_delete() {
- let db = DatabaseConnection::Disconnected;
- let config = Config::default();
- let provider = FakeProviderApi::new(config.clone()).unwrap();
- let state = Arc::new(ServiceState::new(config, db, provider).unwrap());
+ let mut identity_mock = MockIdentityProvider::default();
+ identity_mock
+ .expect_delete_user()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "foo")
+ .returning(|_, _| Err(IdentityProviderError::UserNotFound("foo".into())));
+
+ identity_mock
+ .expect_delete_user()
+ .withf(|_: &DatabaseConnection, id: &String| *id == "bar")
+ .returning(|_, _| Ok(()));
+
+ let state = get_mocked_state(identity_mock);
let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
.with_state(state.clone());
- let user = crate::identity::types::UserCreate {
- domain_id: "domain".into(),
- name: "name".into(),
- ..Default::default()
- };
-
- let created_user = state
- .provider
- .get_identity_provider()
- .create_user(&DatabaseConnection::Disconnected, user)
- .await
- .unwrap();
-
let response = api
.as_service()
.oneshot(
Request::builder()
.method("DELETE")
.uri("/foo")
+ .header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
@@ -342,7 +429,8 @@ mod tests {
.oneshot(
Request::builder()
.method("DELETE")
- .uri(format!("/{}", created_user.id))
+ .uri("/bar")
+ .header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
diff --git a/src/api/v3/user/types.rs b/src/api/v3/user/types.rs
index 29ef9111..ca153876 100644
--- a/src/api/v3/user/types.rs
+++ b/src/api/v3/user/types.rs
@@ -13,9 +13,9 @@
// SPDX-License-Identifier: Apache-2.0
use axum::{
+ Json,
http::StatusCode,
response::{IntoResponse, Response},
- Json,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
@@ -24,7 +24,7 @@ use utoipa::{IntoParams, ToSchema};
use crate::identity::types;
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct User {
/// User ID
pub id: String,
@@ -41,25 +41,28 @@ pub struct User {
/// default project, the default project is ignored at token creation. (Since v3.1)
/// Additionally, if your default project is not valid, a token is issued without an explicit
/// scope of authorization.
+ #[serde(skip_serializing_if = "Option::is_none")]
pub default_project_id: Option,
- #[serde(flatten)]
+ #[serde(flatten, skip_serializing_if = "Option::is_none")]
pub extra: Option,
/// The date and time when the password expires. The time zone is UTC.
+ #[serde(skip_serializing_if = "Option::is_none")]
pub password_expires_at: Option>,
/// The resource options for the user. Available resource options are
/// ignore_change_password_upon_first_use, ignore_password_expiry,
/// ignore_lockout_failure_attempts, lock_password, multi_factor_auth_enabled, and
/// multi_factor_auth_rules ignore_user_inactivity.
+ #[serde(skip_serializing_if = "Option::is_none")]
pub options: Option,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct UserResponse {
/// User object
pub user: User,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct UserCreate {
/// User domain ID
pub domain_id: String,
@@ -87,7 +90,7 @@ pub struct UserCreate {
pub extra: Option,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct UserUpdateRequest {
/// The user name. Must be unique within the owning domain.
pub name: Option,
@@ -113,7 +116,7 @@ pub struct UserUpdateRequest {
pub extra: Option,
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct UserOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub ignore_change_password_upon_first_use: Option,
@@ -159,7 +162,7 @@ impl From for types::UserOptions {
}
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct UserCreateRequest {
/// User object
pub user: UserCreate,
@@ -167,6 +170,20 @@ pub struct UserCreateRequest {
impl From for User {
fn from(value: types::User) -> Self {
+ let opts: UserOptions = value.options.clone().into();
+ // We only want to see user options if there is at least 1 option set
+ let opts = if opts.ignore_change_password_upon_first_use.is_some()
+ || opts.ignore_password_expiry.is_some()
+ || opts.ignore_lockout_failure_attempts.is_some()
+ || opts.lock_password.is_some()
+ || opts.ignore_user_inactivity.is_some()
+ || opts.multi_factor_auth_rules.is_some()
+ || opts.multi_factor_auth_enabled.is_some()
+ {
+ Some(opts)
+ } else {
+ None
+ };
Self {
id: value.id,
domain_id: value.domain_id,
@@ -175,7 +192,7 @@ impl From for User {
default_project_id: value.default_project_id,
extra: value.extra,
password_expires_at: value.password_expires_at,
- options: Some(value.options.into()),
+ options: opts,
}
}
}
@@ -216,7 +233,7 @@ impl IntoResponse for types::User {
}
/// Users
-#[derive(Clone, Debug, Default, Deserialize, Serialize, ToSchema)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, ToSchema)]
pub struct UserList {
/// Collection of user objects
pub users: Vec,
@@ -235,7 +252,7 @@ impl IntoResponse for UserList {
}
}
-#[derive(Clone, Debug, Default, Deserialize, Serialize, IntoParams)]
+#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize, IntoParams)]
pub struct UserListParameters {
/// Filter users by Domain ID
pub domain_id: Option,
diff --git a/src/bin/keystone.rs b/src/bin/keystone.rs
index cf1978cb..889c6b85 100644
--- a/src/bin/keystone.rs
+++ b/src/bin/keystone.rs
@@ -12,10 +12,7 @@
//
// SPDX-License-Identifier: Apache-2.0
-use axum::{
- http::{self, header, HeaderName, Request},
- middleware::{self},
-};
+use axum::http::{self, HeaderName, Request, header};
use clap::Parser;
use color_eyre::eyre::{Report, Result};
use sea_orm::ConnectOptions;
@@ -23,15 +20,14 @@ use sea_orm::Database;
use std::io;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
-use tokio::net::TcpListener;
-use tokio::signal;
+use tokio::{net::TcpListener, signal};
use tower::ServiceBuilder;
use tower_http::{
+ LatencyUnit, ServiceBuilderExt,
request_id::{MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer},
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
- LatencyUnit, ServiceBuilderExt,
};
-use tracing::{info_span, Level};
+use tracing::{Level, info_span};
use tracing_subscriber::{filter::LevelFilter, prelude::*};
use utoipa::OpenApi;
use utoipa_axum::router::OpenApiRouter;
@@ -39,11 +35,10 @@ use utoipa_swagger_ui::SwaggerUi;
use uuid::Uuid;
use openstack_keystone::api;
-use openstack_keystone::api::auth::auth;
use openstack_keystone::config::Config;
-use openstack_keystone::keystone::ServiceState;
+use openstack_keystone::keystone::{Service, ServiceState};
use openstack_keystone::plugin_manager::PluginManager;
-use openstack_keystone::provider::{Provider, ProviderApi};
+use openstack_keystone::provider::Provider;
/// Simple program to greet a person
#[derive(Parser, Debug)]
@@ -101,12 +96,12 @@ async fn main() -> Result<(), Report> {
let plugin_manager = PluginManager::default();
- let provider = ProviderApi::new(cfg.clone(), plugin_manager)?;
+ let provider = Provider::new(cfg.clone(), plugin_manager)?;
- let shared_state = Arc::new(ServiceState::new(cfg, conn, provider).unwrap());
+ let shared_state = Arc::new(Service::new(cfg, conn, provider).unwrap());
let (router, api) = OpenApiRouter::with_openapi(api::ApiDoc::openapi())
- .merge(api::openapi_router::())
+ .merge(api::openapi_router())
.split_for_parts();
let x_request_id = HeaderName::from_static("x-openstack-request-id");
@@ -147,21 +142,18 @@ async fn main() -> Result<(), Report> {
.layer(PropagateRequestIdLayer::new(x_request_id));
let app = router
- // Router::new()
- //.merge(api::router())
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", api))
- .route_layer(middleware::from_fn_with_state(shared_state.clone(), auth))
.layer(middleware)
.with_state(shared_state.clone());
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080));
let listener = TcpListener::bind(&address).await?;
Ok(axum::serve(listener, app.into_make_service())
- .with_graceful_shutdown(shutdown_signal(shared_state.clone()))
+ .with_graceful_shutdown(shutdown_signal(shared_state))
.await?)
}
-async fn shutdown_signal(state: Arc>) {
+async fn shutdown_signal(state: ServiceState) {
let ctrl_c = async {
signal::ctrl_c()
.await
diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs
index 85672869..56000e69 100644
--- a/src/identity/backends/sql.rs
+++ b/src/identity/backends/sql.rs
@@ -13,9 +13,9 @@
// SPDX-License-Identifier: Apache-2.0
use async_trait::async_trait;
+use sea_orm::DatabaseConnection;
use sea_orm::entity::*;
use sea_orm::query::*;
-use sea_orm::DatabaseConnection;
mod common;
mod federated_user;
@@ -33,9 +33,9 @@ use crate::db::entity::{
prelude::{FederatedUser, LocalUser, NonlocalUser, User as DbUser, UserOption},
user as db_user, user_option as db_user_option,
};
+use crate::identity::IdentityProviderError;
use crate::identity::backends::error::IdentityDatabaseError;
use crate::identity::password_hashing;
-use crate::identity::IdentityProviderError;
#[derive(Clone, Debug, Default)]
pub struct SqlBackend {
@@ -231,26 +231,29 @@ pub async fn get_user(
if let Some(user) = &user_entry {
let user_opts: Vec = user.find_related(UserOption).all(db).await?;
- let user_builder: UserBuilder = if let Some(local_user_with_passwords) =
- local_user::load_local_user_with_passwords(db, &user_id).await?
- {
- common::get_local_user_builder(
- conf,
- user,
- local_user_with_passwords.0,
- Some(local_user_with_passwords.1),
- user_opts,
- )
- } else if let Some(nonlocal_user) = user.find_related(NonlocalUser).one(db).await? {
- common::get_nonlocal_user_builder(user, nonlocal_user, user_opts)
- } else {
- let federated_user = user.find_related(FederatedUser).all(db).await?;
- if !federated_user.is_empty() {
- common::get_federated_user_builder(user, federated_user, user_opts)
- } else {
- return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?;
- }
- };
+ let user_builder: UserBuilder =
+ match local_user::load_local_user_with_passwords(db, &user_id).await? {
+ Some(local_user_with_passwords) => common::get_local_user_builder(
+ conf,
+ user,
+ local_user_with_passwords.0,
+ Some(local_user_with_passwords.1),
+ user_opts,
+ ),
+ _ => match user.find_related(NonlocalUser).one(db).await? {
+ Some(nonlocal_user) => {
+ common::get_nonlocal_user_builder(user, nonlocal_user, user_opts)
+ }
+ _ => {
+ let federated_user = user.find_related(FederatedUser).all(db).await?;
+ if !federated_user.is_empty() {
+ common::get_federated_user_builder(user, federated_user, user_opts)
+ } else {
+ return Err(IdentityDatabaseError::MalformedUser(user_id.clone()))?;
+ }
+ }
+ },
+ };
return Ok(Some(user_builder.build()?));
}
diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs
index 5c9b8c08..f5bde82c 100644
--- a/src/identity/backends/sql/group.rs
+++ b/src/identity/backends/sql/group.rs
@@ -12,15 +12,16 @@
//
// SPDX-License-Identifier: Apache-2.0
+use sea_orm::DatabaseConnection;
use sea_orm::entity::*;
use sea_orm::query::*;
-use sea_orm::DatabaseConnection;
use serde_json::Value;
+use serde_json::json;
use crate::db::entity::{group, prelude::Group as DbGroup};
+use crate::identity::Config;
use crate::identity::backends::sql::IdentityDatabaseError;
use crate::identity::types::{Group, GroupCreate, GroupListParameters};
-use crate::identity::Config;
pub async fn list(
_conf: &Config,
@@ -60,7 +61,7 @@ pub async fn create(
group: GroupCreate,
) -> Result {
let entry = group::ActiveModel {
- id: Set(group.id.clone()),
+ id: Set(group.id.clone().unwrap_or_default()),
domain_id: Set(group.domain_id.clone()),
name: Set(group.name.clone()),
description: Set(group.description.clone()),
@@ -94,7 +95,7 @@ impl From for Group {
domain_id: value.domain_id.clone(),
extra: value
.extra
- .map(|x| serde_json::from_str::(&x).unwrap_or(Value::Null)),
+ .map(|x| serde_json::from_str::(&x).unwrap_or(json!(true))),
}
}
}
@@ -107,8 +108,8 @@ mod tests {
use serde_json::json;
use crate::db::entity::group;
- use crate::identity::types::group::GroupListParametersBuilder;
use crate::identity::Config;
+ use crate::identity::types::group::GroupListParametersBuilder;
use super::*;
@@ -236,7 +237,7 @@ mod tests {
let config = Config::default();
let req = GroupCreate {
- id: "1".into(),
+ id: Some("1".into()),
domain_id: "foo_domain".into(),
name: "group".into(),
description: Some("fake".into()),
diff --git a/src/identity/backends/sql/local_user.rs b/src/identity/backends/sql/local_user.rs
index d63fcc49..b1f3cbc6 100644
--- a/src/identity/backends/sql/local_user.rs
+++ b/src/identity/backends/sql/local_user.rs
@@ -12,9 +12,9 @@
//
// SPDX-License-Identifier: Apache-2.0
+use sea_orm::DatabaseConnection;
use sea_orm::entity::*;
use sea_orm::query::*;
-use sea_orm::DatabaseConnection;
use std::collections::HashMap;
use crate::config::Config;
@@ -30,7 +30,10 @@ pub async fn load_local_user_with_passwords>(
db: &DatabaseConnection,
user_id: S,
) -> Result<
- Option<(local_user::Model, impl IntoIterator- )>,
+ Option<(
+ local_user::Model,
+ impl IntoIterator
- + use
,
+ )>,
IdentityDatabaseError,
> {
let results: Vec<(local_user::Model, Vec)> = LocalUser::find()
@@ -48,11 +51,7 @@ pub async fn load_local_user_with_passwords>(
pub async fn load_local_users_passwords>>(
db: &DatabaseConnection,
user_ids: L,
-) -> Result<
- //impl IntoIterator- >>,
- Vec