diff --git a/src/api/v3/group/mod.rs b/src/api/v3/group/mod.rs index 922c4e21..12f12a34 100644 --- a/src/api/v3/group/mod.rs +++ b/src/api/v3/group/mod.rs @@ -26,7 +26,7 @@ use crate::identity::IdentityApi; use crate::keystone::ServiceState; use types::{Group, GroupCreateRequest, GroupList, GroupListParameters, GroupResponse}; -mod types; +pub mod types; pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index d595d69a..55552b69 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -22,16 +22,18 @@ use utoipa_axum::{router::OpenApiRouter, routes}; use crate::api::auth::Auth; use crate::api::error::KeystoneApiError; +use crate::api::v3::group::types::{Group, GroupList}; use crate::identity::IdentityApi; use crate::keystone::ServiceState; use types::{User, UserCreateRequest, UserList, UserListParameters, UserResponse}; -mod types; +pub mod types; pub(super) fn openapi_router() -> OpenApiRouter { OpenApiRouter::new() .routes(routes!(list, create)) .routes(routes!(show, remove)) + .routes(routes!(groups)) } /// List users @@ -147,6 +149,35 @@ async fn remove( Ok((StatusCode::NO_CONTENT).into_response()) } +/// List groups a user is member of +#[utoipa::path( + get, + path = "/{user_id}/groups", + description = "List groups a user is member of", + responses( + (status = OK, description = "List of user groups", body = GroupList), + (status = 500, description = "Internal error", example = json!(KeystoneApiError::InternalError(String::from("id = 1")))) + ), + tag="users" +)] +#[tracing::instrument(name = "api::user_list", level = "debug", skip(state))] +async fn groups( + Auth(user_auth): Auth, + Path(user_id): Path, + State(state): State, +) -> Result { + let groups: Vec = state + .provider + .get_identity_provider() + .list_groups_for_user(&state.db, &user_id) + .await + .map_err(KeystoneApiError::identity)? + .into_iter() + .map(Into::into) + .collect(); + Ok(GroupList { groups }) +} + #[cfg(test)] mod tests { use axum::{ @@ -161,13 +192,14 @@ mod tests { use tower_http::trace::TraceLayer; use super::openapi_router; + use crate::api::v3::group::types::{Group as ApiGroup, GroupList}; use crate::api::v3::user::types::{ User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, UserResponse, }; use crate::identity::{ MockIdentityProvider, error::IdentityProviderError, - types::{User, UserCreate, UserListParameters}, + types::{Group, User, UserCreate, UserListParameters}, }; use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed}; @@ -439,4 +471,51 @@ mod tests { assert_eq!(response.status(), StatusCode::NO_CONTENT); } + + #[tokio::test] + async fn test_groups() { + let mut identity_mock = MockIdentityProvider::default(); + identity_mock + .expect_list_groups_for_user() + .withf(|_: &DatabaseConnection, uid: &str| uid == "foo") + .returning(|_, _| { + Ok(vec![Group { + id: "1".into(), + name: "2".into(), + ..Default::default() + }]) + }); + + let state = get_mocked_state(identity_mock); + + let mut api = openapi_router() + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let response = api + .as_service() + .oneshot( + Request::builder() + .uri("/foo/groups") + .header("x-auth-token", "foo") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + let res: GroupList = serde_json::from_slice(&body).unwrap(); + assert_eq!( + vec![ApiGroup { + id: "1".into(), + name: "2".into(), + extra: Some(json!({})), + ..Default::default() + }], + res.groups + ); + } } diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index 65ce6658..71835dc3 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -134,6 +134,16 @@ impl IdentityBackend for SqlBackend { .await .map_err(IdentityProviderError::database) } + + /// List groups a user is member of + #[tracing::instrument(level = "debug", skip(self, db))] + async fn list_groups_for_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result, IdentityProviderError> { + Ok(group::list_for_user(&self.config, db, user_id).await?) + } } async fn list_users( diff --git a/src/identity/backends/sql/group.rs b/src/identity/backends/sql/group.rs index 7b693d49..4d2f01bc 100644 --- a/src/identity/backends/sql/group.rs +++ b/src/identity/backends/sql/group.rs @@ -18,7 +18,11 @@ use sea_orm::query::*; use serde_json::Value; use serde_json::json; -use crate::db::entity::{group, prelude::Group as DbGroup}; +use crate::db::entity::{ + group, + prelude::{Group as DbGroup, UserGroupMembership as DbUserGroupMembership}, + user_group_membership, +}; use crate::identity::Config; use crate::identity::backends::sql::IdentityDatabaseError; use crate::identity::types::{Group, GroupCreate, GroupListParameters}; @@ -97,6 +101,27 @@ impl From for Group { } } +pub async fn list_for_user( + _conf: &Config, + db: &DatabaseConnection, + user_id: &str, +) -> Result, IdentityDatabaseError> { + let groups: Vec<(user_group_membership::Model, Vec)> = + DbUserGroupMembership::find() + .filter(user_group_membership::Column::UserId.eq(user_id)) + .find_with_related(DbGroup) + .all(db) + .await?; + + let results: Vec = groups + .into_iter() + .map(|(_, x)| x.into_iter()) + .flatten() + .map(Into::into) + .collect(); + Ok(results) +} + #[cfg(test)] mod tests { #![allow(clippy::derivable_impls)] @@ -110,9 +135,9 @@ mod tests { use super::*; - fn get_group_mock(id: String) -> group::Model { + fn get_group_mock>(id: S) -> group::Model { group::Model { - id: id.clone(), + id: id.as_ref().to_string(), domain_id: "foo_domain".into(), name: "group".into(), description: Some("fake".into()), @@ -126,7 +151,7 @@ mod tests { let db = MockDatabase::new(DatabaseBackend::Postgres) .append_query_results([ // First query result - select user itself - vec![get_group_mock("1".into())], + vec![get_group_mock("1")], ]) .into_connection(); let config = Config::default(); @@ -191,7 +216,7 @@ mod tests { async fn test_get() { // Create MockDatabase with mock query results let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([vec![get_group_mock("1".into())], vec![]]) + .append_query_results([vec![get_group_mock("1")], vec![]]) .into_connection(); let config = Config::default(); @@ -229,7 +254,7 @@ mod tests { async fn test_create() { // Create MockDatabase with mock query results let db = MockDatabase::new(DatabaseBackend::Postgres) - .append_query_results([vec![get_group_mock("1".into())], vec![]]) + .append_query_results([vec![get_group_mock("1")], vec![]]) .into_connection(); let config = Config::default(); @@ -242,7 +267,7 @@ mod tests { }; assert_eq!( create(&config, &db, req).await.unwrap(), - get_group_mock("1".into()).into() + get_group_mock("1").into() ); // Checking transaction log assert_eq!( @@ -283,4 +308,23 @@ mod tests { ),] ); } + + #[tokio::test] + async fn test_list_for_user() { + let db = MockDatabase::new(DatabaseBackend::Postgres) + .append_query_results([vec![], vec![get_group_mock("1"), get_group_mock("2")]]) + .into_connection(); + let config = Config::default(); + assert_eq!(list_for_user(&config, &db, "foo").await.unwrap(), vec![]); + + // Checking transaction log + assert_eq!( + db.into_transaction_log(), + [Transaction::from_sql_and_values( + DatabaseBackend::Postgres, + r#"SELECT "user_group_membership"."user_id" AS "A_user_id", "user_group_membership"."group_id" AS "A_group_id", "group"."id" AS "B_id", "group"."domain_id" AS "B_domain_id", "group"."name" AS "B_name", "group"."description" AS "B_description", "group"."extra" AS "B_extra" FROM "user_group_membership" LEFT JOIN "group" ON "user_group_membership"."group_id" = "group"."id" WHERE "user_group_membership"."user_id" = $1 ORDER BY "user_group_membership"."user_id" ASC, "user_group_membership"."group_id" ASC"#, + ["foo".into()] + ),] + ); + } } diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 5c8a5e4f..c0ffc057 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -86,6 +86,12 @@ pub trait IdentityApi: Send + Sync + Clone { db: &DatabaseConnection, group_id: &'a str, ) -> Result<(), IdentityProviderError>; + + async fn list_groups_for_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result, IdentityProviderError>; } #[cfg(test)] @@ -96,53 +102,59 @@ mock! { #[async_trait] impl IdentityApi for IdentityProvider { - async fn list_users( - &self, - db: &DatabaseConnection, - params: &UserListParameters, - ) -> Result, IdentityProviderError>; - - async fn get_user<'a>( - &self, - db: &DatabaseConnection, - user_id: &'a str, - ) -> Result, IdentityProviderError>; - - async fn create_user( - &self, - db: &DatabaseConnection, - user: UserCreate, - ) -> Result; - - async fn delete_user<'a>( - &self, - db: &DatabaseConnection, - user_id: &'a str, - ) -> Result<(), IdentityProviderError>; - - async fn list_groups( - &self, - db: &DatabaseConnection, - params: &GroupListParameters, - ) -> Result, IdentityProviderError>; - - async fn get_group<'a>( - &self, - db: &DatabaseConnection, - group_id: &'a str, - ) -> Result, IdentityProviderError>; - - async fn create_group( - &self, - db: &DatabaseConnection, - group: GroupCreate, - ) -> Result; - - async fn delete_group<'a>( - &self, - db: &DatabaseConnection, - group_id: &'a str, - ) -> Result<(), IdentityProviderError>; + async fn list_users( + &self, + db: &DatabaseConnection, + params: &UserListParameters, + ) -> Result, IdentityProviderError>; + + async fn get_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result, IdentityProviderError>; + + async fn create_user( + &self, + db: &DatabaseConnection, + user: UserCreate, + ) -> Result; + + async fn delete_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + async fn list_groups( + &self, + db: &DatabaseConnection, + params: &GroupListParameters, + ) -> Result, IdentityProviderError>; + + async fn get_group<'a>( + &self, + db: &DatabaseConnection, + group_id: &'a str, + ) -> Result, IdentityProviderError>; + + async fn create_group( + &self, + db: &DatabaseConnection, + group: GroupCreate, + ) -> Result; + + async fn delete_group<'a>( + &self, + db: &DatabaseConnection, + group_id: &'a str, + ) -> Result<(), IdentityProviderError>; + + async fn list_groups_for_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result, IdentityProviderError>; } impl Clone for IdentityProvider { @@ -263,4 +275,14 @@ impl IdentityApi for IdentityProvider { ) -> Result<(), IdentityProviderError> { self.backend_driver.delete_group(db, group_id).await } + + /// List groups a user is a member of + #[tracing::instrument(level = "info", skip(self, db))] + async fn list_groups_for_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result, IdentityProviderError> { + self.backend_driver.list_groups_for_user(db, user_id).await + } } diff --git a/src/identity/types.rs b/src/identity/types.rs index aff26f7a..677b39b3 100644 --- a/src/identity/types.rs +++ b/src/identity/types.rs @@ -87,6 +87,13 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { db: &DatabaseConnection, group_id: &'a str, ) -> Result<(), IdentityProviderError>; + + /// List groups a user is member of + async fn list_groups_for_user<'a>( + &self, + db: &DatabaseConnection, + user_id: &'a str, + ) -> Result, IdentityProviderError>; } dyn_clone::clone_trait_object!(IdentityBackend);