Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/api/v3/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::identity::IdentityApi;
use crate::keystone::ServiceState;
use types::{Group, GroupCreateRequest, GroupList, GroupListParameters, GroupResponse};

mod types;
pub mod types;

pub(super) fn openapi_router() -> OpenApiRouter<ServiceState> {
OpenApiRouter::new()
Expand Down
83 changes: 81 additions & 2 deletions src/api/v3/user/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ use utoipa_axum::{router::OpenApiRouter, routes};

use crate::api::auth::Auth;
use crate::api::error::KeystoneApiError;
use crate::api::v3::group::types::{Group, GroupList};
use crate::identity::IdentityApi;
use crate::keystone::ServiceState;
use types::{User, UserCreateRequest, UserList, UserListParameters, UserResponse};

mod types;
pub mod types;

pub(super) fn openapi_router() -> OpenApiRouter<ServiceState> {
OpenApiRouter::new()
.routes(routes!(list, create))
.routes(routes!(show, remove))
.routes(routes!(groups))
}

/// List users
Expand Down Expand Up @@ -147,6 +149,35 @@ async fn remove(
Ok((StatusCode::NO_CONTENT).into_response())
}

/// List groups a user is member of
#[utoipa::path(
get,
path = "/{user_id}/groups",
description = "List groups a user is member of",
responses(
(status = OK, description = "List of user groups", body = GroupList),
(status = 500, description = "Internal error", example = json!(KeystoneApiError::InternalError(String::from("id = 1"))))
),
tag="users"
)]
#[tracing::instrument(name = "api::user_list", level = "debug", skip(state))]
async fn groups(
Auth(user_auth): Auth,
Path(user_id): Path<String>,
State(state): State<ServiceState>,
) -> Result<impl IntoResponse, KeystoneApiError> {
let groups: Vec<Group> = state
.provider
.get_identity_provider()
.list_groups_for_user(&state.db, &user_id)
.await
.map_err(KeystoneApiError::identity)?
.into_iter()
.map(Into::into)
.collect();
Ok(GroupList { groups })
}

#[cfg(test)]
mod tests {
use axum::{
Expand All @@ -161,13 +192,14 @@ mod tests {
use tower_http::trace::TraceLayer;

use super::openapi_router;
use crate::api::v3::group::types::{Group as ApiGroup, GroupList};
use crate::api::v3::user::types::{
User as ApiUser, UserCreate as ApiUserCreate, UserCreateRequest, UserList, UserResponse,
};
use crate::identity::{
MockIdentityProvider,
error::IdentityProviderError,
types::{User, UserCreate, UserListParameters},
types::{Group, User, UserCreate, UserListParameters},
};

use crate::tests::api::{get_mocked_state, get_mocked_state_unauthed};
Expand Down Expand Up @@ -439,4 +471,51 @@ mod tests {

assert_eq!(response.status(), StatusCode::NO_CONTENT);
}

#[tokio::test]
async fn test_groups() {
let mut identity_mock = MockIdentityProvider::default();
identity_mock
.expect_list_groups_for_user()
.withf(|_: &DatabaseConnection, uid: &str| uid == "foo")
.returning(|_, _| {
Ok(vec![Group {
id: "1".into(),
name: "2".into(),
..Default::default()
}])
});

let state = get_mocked_state(identity_mock);

let mut api = openapi_router()
.layer(TraceLayer::new_for_http())
.with_state(state);

let response = api
.as_service()
.oneshot(
Request::builder()
.uri("/foo/groups")
.header("x-auth-token", "foo")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::OK);

let body = response.into_body().collect().await.unwrap().to_bytes();
let res: GroupList = serde_json::from_slice(&body).unwrap();
assert_eq!(
vec![ApiGroup {
id: "1".into(),
name: "2".into(),
extra: Some(json!({})),
..Default::default()
}],
res.groups
);
}
}
10 changes: 10 additions & 0 deletions src/identity/backends/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ impl IdentityBackend for SqlBackend {
.await
.map_err(IdentityProviderError::database)
}

/// List groups a user is member of
#[tracing::instrument(level = "debug", skip(self, db))]
async fn list_groups_for_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<Vec<Group>, IdentityProviderError> {
Ok(group::list_for_user(&self.config, db, user_id).await?)
}
}

async fn list_users(
Expand Down
58 changes: 51 additions & 7 deletions src/identity/backends/sql/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ use sea_orm::query::*;
use serde_json::Value;
use serde_json::json;

use crate::db::entity::{group, prelude::Group as DbGroup};
use crate::db::entity::{
group,
prelude::{Group as DbGroup, UserGroupMembership as DbUserGroupMembership},
user_group_membership,
};
use crate::identity::Config;
use crate::identity::backends::sql::IdentityDatabaseError;
use crate::identity::types::{Group, GroupCreate, GroupListParameters};
Expand Down Expand Up @@ -97,6 +101,27 @@ impl From<group::Model> for Group {
}
}

pub async fn list_for_user(
_conf: &Config,
db: &DatabaseConnection,
user_id: &str,
) -> Result<Vec<Group>, IdentityDatabaseError> {
let groups: Vec<(user_group_membership::Model, Vec<group::Model>)> =
DbUserGroupMembership::find()
.filter(user_group_membership::Column::UserId.eq(user_id))
.find_with_related(DbGroup)
.all(db)
.await?;

let results: Vec<Group> = groups
.into_iter()
.map(|(_, x)| x.into_iter())
.flatten()
.map(Into::into)
.collect();
Ok(results)
}

#[cfg(test)]
mod tests {
#![allow(clippy::derivable_impls)]
Expand All @@ -110,9 +135,9 @@ mod tests {

use super::*;

fn get_group_mock(id: String) -> group::Model {
fn get_group_mock<S: AsRef<str>>(id: S) -> group::Model {
group::Model {
id: id.clone(),
id: id.as_ref().to_string(),
domain_id: "foo_domain".into(),
name: "group".into(),
description: Some("fake".into()),
Expand All @@ -126,7 +151,7 @@ mod tests {
let db = MockDatabase::new(DatabaseBackend::Postgres)
.append_query_results([
// First query result - select user itself
vec![get_group_mock("1".into())],
vec![get_group_mock("1")],
])
.into_connection();
let config = Config::default();
Expand Down Expand Up @@ -191,7 +216,7 @@ mod tests {
async fn test_get() {
// Create MockDatabase with mock query results
let db = MockDatabase::new(DatabaseBackend::Postgres)
.append_query_results([vec![get_group_mock("1".into())], vec![]])
.append_query_results([vec![get_group_mock("1")], vec![]])
.into_connection();
let config = Config::default();

Expand Down Expand Up @@ -229,7 +254,7 @@ mod tests {
async fn test_create() {
// Create MockDatabase with mock query results
let db = MockDatabase::new(DatabaseBackend::Postgres)
.append_query_results([vec![get_group_mock("1".into())], vec![]])
.append_query_results([vec![get_group_mock("1")], vec![]])
.into_connection();
let config = Config::default();

Expand All @@ -242,7 +267,7 @@ mod tests {
};
assert_eq!(
create(&config, &db, req).await.unwrap(),
get_group_mock("1".into()).into()
get_group_mock("1").into()
);
// Checking transaction log
assert_eq!(
Expand Down Expand Up @@ -283,4 +308,23 @@ mod tests {
),]
);
}

#[tokio::test]
async fn test_list_for_user() {
let db = MockDatabase::new(DatabaseBackend::Postgres)
.append_query_results([vec![], vec![get_group_mock("1"), get_group_mock("2")]])
.into_connection();
let config = Config::default();
assert_eq!(list_for_user(&config, &db, "foo").await.unwrap(), vec![]);

// Checking transaction log
assert_eq!(
db.into_transaction_log(),
[Transaction::from_sql_and_values(
DatabaseBackend::Postgres,
r#"SELECT "user_group_membership"."user_id" AS "A_user_id", "user_group_membership"."group_id" AS "A_group_id", "group"."id" AS "B_id", "group"."domain_id" AS "B_domain_id", "group"."name" AS "B_name", "group"."description" AS "B_description", "group"."extra" AS "B_extra" FROM "user_group_membership" LEFT JOIN "group" ON "user_group_membership"."group_id" = "group"."id" WHERE "user_group_membership"."user_id" = $1 ORDER BY "user_group_membership"."user_id" ASC, "user_group_membership"."group_id" ASC"#,
["foo".into()]
),]
);
}
}
116 changes: 69 additions & 47 deletions src/identity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ pub trait IdentityApi: Send + Sync + Clone {
db: &DatabaseConnection,
group_id: &'a str,
) -> Result<(), IdentityProviderError>;

async fn list_groups_for_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<impl IntoIterator<Item = Group>, IdentityProviderError>;
}

#[cfg(test)]
Expand All @@ -96,53 +102,59 @@ mock! {

#[async_trait]
impl IdentityApi for IdentityProvider {
async fn list_users(
&self,
db: &DatabaseConnection,
params: &UserListParameters,
) -> Result<Vec<User>, IdentityProviderError>;

async fn get_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<Option<User>, IdentityProviderError>;

async fn create_user(
&self,
db: &DatabaseConnection,
user: UserCreate,
) -> Result<User, IdentityProviderError>;

async fn delete_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<(), IdentityProviderError>;

async fn list_groups(
&self,
db: &DatabaseConnection,
params: &GroupListParameters,
) -> Result<Vec<Group>, IdentityProviderError>;

async fn get_group<'a>(
&self,
db: &DatabaseConnection,
group_id: &'a str,
) -> Result<Option<Group>, IdentityProviderError>;

async fn create_group(
&self,
db: &DatabaseConnection,
group: GroupCreate,
) -> Result<Group, IdentityProviderError>;

async fn delete_group<'a>(
&self,
db: &DatabaseConnection,
group_id: &'a str,
) -> Result<(), IdentityProviderError>;
async fn list_users(
&self,
db: &DatabaseConnection,
params: &UserListParameters,
) -> Result<Vec<User>, IdentityProviderError>;

async fn get_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<Option<User>, IdentityProviderError>;

async fn create_user(
&self,
db: &DatabaseConnection,
user: UserCreate,
) -> Result<User, IdentityProviderError>;

async fn delete_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<(), IdentityProviderError>;

async fn list_groups(
&self,
db: &DatabaseConnection,
params: &GroupListParameters,
) -> Result<Vec<Group>, IdentityProviderError>;

async fn get_group<'a>(
&self,
db: &DatabaseConnection,
group_id: &'a str,
) -> Result<Option<Group>, IdentityProviderError>;

async fn create_group(
&self,
db: &DatabaseConnection,
group: GroupCreate,
) -> Result<Group, IdentityProviderError>;

async fn delete_group<'a>(
&self,
db: &DatabaseConnection,
group_id: &'a str,
) -> Result<(), IdentityProviderError>;

async fn list_groups_for_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<Vec<Group>, IdentityProviderError>;
}

impl Clone for IdentityProvider {
Expand Down Expand Up @@ -263,4 +275,14 @@ impl IdentityApi for IdentityProvider {
) -> Result<(), IdentityProviderError> {
self.backend_driver.delete_group(db, group_id).await
}

/// List groups a user is a member of
#[tracing::instrument(level = "info", skip(self, db))]
async fn list_groups_for_user<'a>(
&self,
db: &DatabaseConnection,
user_id: &'a str,
) -> Result<impl IntoIterator<Item = Group>, IdentityProviderError> {
self.backend_driver.list_groups_for_user(db, user_id).await
}
}
Loading