diff --git a/src/api/common.rs b/src/api/common.rs index b0641d2e..5ad22635 100644 --- a/src/api/common.rs +++ b/src/api/common.rs @@ -39,7 +39,7 @@ pub async fn get_domain, N: AsRef>( state .provider .get_resource_provider() - .get_domain(&state.db, did.as_ref()) + .get_domain(state, did.as_ref()) .await? .ok_or_else(|| KeystoneApiError::NotFound { resource: "domain".into(), @@ -49,7 +49,7 @@ pub async fn get_domain, N: AsRef>( state .provider .get_resource_provider() - .find_domain_by_name(&state.db, name.as_ref()) + .find_domain_by_name(state, name.as_ref()) .await? .ok_or_else(|| KeystoneApiError::NotFound { resource: "domain".into(), @@ -76,7 +76,7 @@ pub async fn find_project_from_scope( state .provider .get_resource_provider() - .get_project(&state.db, pid) + .get_project(state, pid) .await? } else if let Some(name) = &scope.name { if let Some(domain) = &scope.domain { @@ -87,7 +87,7 @@ pub async fn find_project_from_scope( .provider .get_resource_provider() .find_domain_by_name( - &state.db, + state, &domain .name .clone() @@ -107,7 +107,7 @@ pub async fn find_project_from_scope( state .provider .get_resource_provider() - .get_project_by_name(&state.db, name, &domain_id) + .get_project_by_name(state, name, &domain_id) .await? } else { return Err(KeystoneApiError::ProjectDomain); @@ -137,7 +137,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "domain_id") + .withf(|_, id: &'_ str| id == "domain_id") .returning(|_, _| { Ok(Some(Domain { id: "domain_id".into(), @@ -147,7 +147,7 @@ mod tests { }); resource_mock .expect_find_domain_by_name() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "domain_name") + .withf(|_, id: &'_ str| id == "domain_name") .returning(|_, _| { Ok(Some(Domain { id: "domain_id".into(), diff --git a/src/api/v3/auth/token/common.rs b/src/api/v3/auth/token/common.rs index 2299bd6d..b0c4887f 100644 --- a/src/api/v3/auth/token/common.rs +++ b/src/api/v3/auth/token/common.rs @@ -43,7 +43,7 @@ impl Token { &state .provider .get_identity_provider() - .get_user(&state.db, token.user_id()) + .get_user(state, token.user_id()) .await .map_err(KeystoneApiError::identity)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -88,7 +88,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -104,7 +104,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -128,7 +128,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -144,7 +144,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -219,7 +219,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(UserResponse { id: "bar".into(), @@ -231,7 +231,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .withf(|_, id: &'_ str| id == "user_domain_id") .returning(|_, _| { Ok(Some(Domain { id: "user_domain_id".into(), @@ -274,7 +274,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(UserResponse { id: "bar".into(), @@ -333,7 +333,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(UserResponse { id: "bar".into(), diff --git a/src/api/v3/auth/token/mod.rs b/src/api/v3/auth/token/mod.rs index 3773b98e..0c191106 100644 --- a/src/api/v3/auth/token/mod.rs +++ b/src/api/v3/auth/token/mod.rs @@ -66,7 +66,7 @@ async fn authenticate_request( state .provider .get_identity_provider() - .authenticate_by_password(&state.db, &state.provider, req) + .authenticate_by_password(state, req) .await?, ); } @@ -83,7 +83,7 @@ async fn authenticate_request( state .provider .get_identity_provider() - .get_user(&state.db, &authz.user_id) + .get_user(state, &authz.user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -196,7 +196,7 @@ async fn post( let catalog: Catalog = state .provider .get_catalog_provider() - .get_catalog(&state.db, true) + .get_catalog(&state, true) .await? .into(); api_token.token.catalog = Some(catalog); @@ -277,7 +277,7 @@ async fn show( let catalog: Catalog = state .provider .get_catalog_provider() - .get_catalog(&state.db, true) + .get_catalog(&state, true) .await? .into(); response_token.catalog = Some(catalog); @@ -356,12 +356,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_authenticate_by_password() - .withf(|_, _, req: &UserPasswordAuthRequest| { + .withf(|_, req: &UserPasswordAuthRequest| { req.id == Some("uid".to_string()) && req.password == "pwd" && req.name == Some("uname".to_string()) }) - .returning(move |_, _, _| Ok(auth_clone.clone())); + .returning(move |_, _| Ok(auth_clone.clone())); let provider = Provider::mocked_builder() .config(config.clone()) @@ -535,7 +535,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .withf(|_, id: &'_ str| id == "user_domain_id") .returning(|_, _| { Ok(Some(Domain { id: "user_domain_id".into(), @@ -634,7 +634,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .withf(|_, id: &'_ str| id == "user_domain_id") .returning(|_, _| { Ok(Some(Domain { id: "user_domain_id".into(), @@ -823,12 +823,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_authenticate_by_password() - .withf(|_, _, req: &UserPasswordAuthRequest| { + .withf(|_, req: &UserPasswordAuthRequest| { req.id == Some("uid".to_string()) && req.password == "pass" && req.name == Some("uname".to_string()) }) - .returning(|_, _, _| { + .returning(|_, _| { Ok(AuthenticatedInfo::builder() .user_id("uid") .user(UserResponse { @@ -979,7 +979,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_authenticate_by_password() - .returning(|_, _, _| { + .returning(|_, _| { Ok(AuthenticatedInfo::builder() .user_id("uid") .user(UserResponse { diff --git a/src/api/v3/group/mod.rs b/src/api/v3/group/mod.rs index 44dc338a..c01002c3 100644 --- a/src/api/v3/group/mod.rs +++ b/src/api/v3/group/mod.rs @@ -55,7 +55,7 @@ async fn list( let groups: Vec = state .provider .get_identity_provider() - .list_groups(&state.db, &query.into()) + .list_groups(&state, &query.into()) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -85,7 +85,7 @@ async fn show( state .provider .get_identity_provider() - .get_group(&state.db, &group_id) + .get_group(&state, &group_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -115,7 +115,7 @@ async fn create( let res = state .provider .get_identity_provider() - .create_group(&state.db, req.into()) + .create_group(&state, req.into()) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::CREATED, res).into_response()) @@ -142,7 +142,7 @@ async fn remove( state .provider .get_identity_provider() - .delete_group(&state.db, &group_id) + .delete_group(&state, &group_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -155,7 +155,7 @@ mod tests { http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use serde_json::json; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` @@ -179,7 +179,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_groups() - .withf(|_: &DatabaseConnection, _: &GroupListParameters| true) + .withf(|_, _: &GroupListParameters| true) .returning(|_, _| { Ok(vec![Group { id: "1".into(), @@ -228,7 +228,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_groups() - .withf(|_: &DatabaseConnection, qp: &GroupListParameters| { + .withf(|_, qp: &GroupListParameters| { GroupListParameters { domain_id: Some("domain".into()), name: Some("name".into()), @@ -282,12 +282,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); identity_mock .expect_get_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(Group { id: "bar".into(), @@ -346,9 +346,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_create_group() - .withf(|_: &DatabaseConnection, req: &GroupCreate| { - req.domain_id == "domain" && req.name == "name" - }) + .withf(|_, req: &GroupCreate| req.domain_id == "domain" && req.name == "name") .returning(|_, req| { Ok(Group { id: "bar".into(), @@ -399,12 +397,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_delete_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Err(IdentityProviderError::GroupNotFound("foo".into()))); identity_mock .expect_delete_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(identity_mock); diff --git a/src/api/v3/user/mod.rs b/src/api/v3/user/mod.rs index f529bf74..b8e5b27f 100644 --- a/src/api/v3/user/mod.rs +++ b/src/api/v3/user/mod.rs @@ -57,7 +57,7 @@ async fn list( let users: Vec = state .provider .get_identity_provider() - .list_users(&state.db, &query.into()) + .list_users(&state, &query.into()) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -86,7 +86,7 @@ async fn show( state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state, &user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -116,7 +116,7 @@ async fn create( let user = state .provider .get_identity_provider() - .create_user(&state.db, req.into()) + .create_user(&state, req.into()) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::CREATED, user).into_response()) @@ -143,7 +143,7 @@ async fn remove( state .provider .get_identity_provider() - .delete_user(&state.db, &user_id) + .delete_user(&state, &user_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -169,7 +169,7 @@ async fn groups( let groups: Vec = state .provider .get_identity_provider() - .list_groups_of_user(&state.db, &user_id) + .list_groups_of_user(&state, &user_id) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -185,7 +185,7 @@ mod tests { http::{self, Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use serde_json::json; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` @@ -210,7 +210,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_users() - .withf(|_: &DatabaseConnection, _: &UserListParameters| true) + .withf(|_, _: &UserListParameters| true) .returning(|_, _| { Ok(vec![UserResponse { id: "1".into(), @@ -258,7 +258,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_users() - .withf(|_: &DatabaseConnection, qp: &UserListParameters| { + .withf(|_, qp: &UserListParameters| { UserListParameters { domain_id: Some("domain".into()), name: Some("name".into()), @@ -312,9 +312,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_create_user() - .withf(|_: &DatabaseConnection, req: &UserCreate| { - req.domain_id == "domain" && req.name == "name" - }) + .withf(|_, req: &UserCreate| req.domain_id == "domain" && req.name == "name") .returning(|_, req| { Ok(UserResponse { id: "bar".into(), @@ -364,12 +362,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(UserResponse { id: "bar".into(), @@ -428,12 +426,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_delete_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Err(IdentityProviderError::UserNotFound("foo".into()))); identity_mock .expect_delete_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(identity_mock); @@ -478,7 +476,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_groups_of_user() - .withf(|_: &DatabaseConnection, uid: &str| uid == "foo") + .withf(|_, uid: &str| uid == "foo") .returning(|_, _| { Ok(vec![Group { id: "1".into(), diff --git a/src/api/v4/auth/passkey/finish.rs b/src/api/v4/auth/passkey/finish.rs index 93756b7b..98bf1c58 100644 --- a/src/api/v4/auth/passkey/finish.rs +++ b/src/api/v4/auth/passkey/finish.rs @@ -61,7 +61,7 @@ pub(super) async fn finish( if let Some(s) = state .provider .get_identity_provider() - .get_user_webauthn_credential_authentication_state(&state.db, &user_id) + .get_user_webauthn_credential_authentication_state(&state, &user_id) .await? { // We explicitly try to deserealize the request data directly into the underlying @@ -81,7 +81,7 @@ pub(super) async fn finish( state .provider .get_identity_provider() - .delete_user_webauthn_credential_authentication_state(&state.db, &user_id) + .delete_user_webauthn_credential_authentication_state(&state, &user_id) .await?; } let authed_info = AuthenticatedInfo::builder() @@ -90,7 +90,7 @@ pub(super) async fn finish( state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state, &user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { diff --git a/src/api/v4/auth/passkey/start.rs b/src/api/v4/auth/passkey/start.rs index 8d955346..c75c26cc 100644 --- a/src/api/v4/auth/passkey/start.rs +++ b/src/api/v4/auth/passkey/start.rs @@ -50,12 +50,12 @@ pub(super) async fn start( state .provider .get_identity_provider() - .delete_user_webauthn_credential_authentication_state(&state.db, &req.passkey.user_id) + .delete_user_webauthn_credential_authentication_state(&state, &req.passkey.user_id) .await?; let allow_credentials: Vec = state .provider .get_identity_provider() - .list_user_webauthn_credentials(&state.db, &req.passkey.user_id) + .list_user_webauthn_credentials(&state, &req.passkey.user_id) .await? .into_iter() .collect(); @@ -68,7 +68,7 @@ pub(super) async fn start( .provider .get_identity_provider() .save_user_webauthn_credential_authentication_state( - &state.db, + &state, &req.passkey.user_id, auth_state, ) diff --git a/src/api/v4/auth/token/common.rs b/src/api/v4/auth/token/common.rs index 70ead8e3..8b6b918a 100644 --- a/src/api/v4/auth/token/common.rs +++ b/src/api/v4/auth/token/common.rs @@ -43,7 +43,7 @@ impl Token { &state .provider .get_identity_provider() - .get_user(&state.db, token.user_id()) + .get_user(state, token.user_id()) .await .map_err(KeystoneApiError::identity)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -86,7 +86,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -102,7 +102,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -126,7 +126,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { @@ -142,7 +142,7 @@ impl Token { state .provider .get_resource_provider() - .get_project(&state.db, &token.project_id) + .get_project(state, &token.project_id) .await .map_err(KeystoneApiError::resource)? .ok_or_else(|| KeystoneApiError::NotFound { diff --git a/src/api/v4/auth/token/mod.rs b/src/api/v4/auth/token/mod.rs index 1f6beabf..10c5b93e 100644 --- a/src/api/v4/auth/token/mod.rs +++ b/src/api/v4/auth/token/mod.rs @@ -50,7 +50,7 @@ async fn authenticate_request( state .provider .get_identity_provider() - .authenticate_by_password(&state.db, &state.provider, req) + .authenticate_by_password(state, req) .await?, ); } @@ -67,7 +67,7 @@ async fn authenticate_request( state .provider .get_identity_provider() - .get_user(&state.db, &authz.user_id) + .get_user(state, &authz.user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -196,12 +196,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_authenticate_by_password() - .withf(|_, _, req: &UserPasswordAuthRequest| { + .withf(|_, req: &UserPasswordAuthRequest| { req.id == Some("uid".to_string()) && req.password == "pwd" && req.name == Some("uname".to_string()) }) - .returning(move |_, _, _| Ok(auth_clone.clone())); + .returning(move |_, _| Ok(auth_clone.clone())); let provider = Provider::mocked_builder() .config(config.clone()) @@ -375,7 +375,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .withf(|_, id: &'_ str| id == "user_domain_id") .returning(|_, _| { Ok(Some(Domain { id: "user_domain_id".into(), @@ -474,7 +474,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .withf(|_, id: &'_ str| id == "user_domain_id") .returning(|_, _| { Ok(Some(Domain { id: "user_domain_id".into(), @@ -663,12 +663,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_authenticate_by_password() - .withf(|_, _, req: &UserPasswordAuthRequest| { + .withf(|_, req: &UserPasswordAuthRequest| { req.id == Some("uid".to_string()) && req.password == "pass" && req.name == Some("uname".to_string()) }) - .returning(|_, _, _| { + .returning(|_, _| { Ok(AuthenticatedInfo::builder() .user_id("uid") .user(UserResponse { @@ -684,15 +684,15 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_project() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "pid") + .withf(|_, id: &'_ str| id == "pid") .returning(move |_, _| Ok(Some(project.clone()))); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "user_domain_id") + .withf(|_, id: &'_ str| id == "user_domain_id") .returning(move |_, _| Ok(Some(user_domain.clone()))); resource_mock .expect_get_domain() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "pdid") + .withf(|_, id: &'_ str| id == "pdid") .returning(move |_, _| Ok(Some(project_domain.clone()))); let mut token_mock = MockTokenProvider::default(); token_mock.expect_issue_token().returning(|_, _, _| { @@ -819,7 +819,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_authenticate_by_password() - .returning(|_, _, _| { + .returning(|_, _| { Ok(AuthenticatedInfo::builder() .user_id("uid") .user(UserResponse { @@ -835,7 +835,7 @@ mod tests { let mut resource_mock = MockResourceProvider::default(); resource_mock .expect_get_project() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "pid") + .withf(|_, id: &'_ str| id == "pid") .returning(move |_, _| { Ok(Some(Project { id: "pid".into(), diff --git a/src/api/v4/federation/auth.rs b/src/api/v4/federation/auth.rs index bdcf7bdc..8ca225c9 100644 --- a/src/api/v4/federation/auth.rs +++ b/src/api/v4/federation/auth.rs @@ -88,7 +88,7 @@ pub async fn post( let idp = state .provider .get_federation_provider() - .get_identity_provider(&state.db, &idp_id) + .get_identity_provider(&state, &idp_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -101,7 +101,7 @@ pub async fn post( state .provider .get_federation_provider() - .get_mapping(&state.db, &mapping_id) + .get_mapping(&state, &mapping_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -114,7 +114,7 @@ pub async fn post( .provider .get_federation_provider() .list_mappings( - &state.db, + &state, &ProviderMappingListParameters { idp_id: Some(idp.id.clone()), name: Some(mapping_name.clone()), @@ -182,7 +182,7 @@ pub async fn post( .provider .get_federation_provider() .create_auth_state( - &state.db, + &state, AuthState { state: csrf_token.secret().clone(), nonce: nonce.secret().clone(), diff --git a/src/api/v4/federation/identity_provider/create.rs b/src/api/v4/federation/identity_provider/create.rs index fbba51fc..68ac597b 100644 --- a/src/api/v4/federation/identity_provider/create.rs +++ b/src/api/v4/federation/identity_provider/create.rs @@ -64,7 +64,7 @@ pub(super) async fn create( let res = state .provider .get_federation_provider() - .create_identity_provider(&state.db, req.into()) + .create_identity_provider(&state, req.into()) .await .map_err(KeystoneApiError::federation)?; Ok((StatusCode::CREATED, res).into_response()) @@ -77,7 +77,7 @@ mod tests { http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; use tracing_test::traced_test; @@ -95,9 +95,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_create_identity_provider() - .withf( - |_: &DatabaseConnection, req: &provider_types::IdentityProvider| req.name == "name", - ) + .withf(|_, req: &provider_types::IdentityProvider| req.name == "name") .returning(|_, _| { Ok(provider_types::IdentityProvider { id: "bar".into(), diff --git a/src/api/v4/federation/identity_provider/delete.rs b/src/api/v4/federation/identity_provider/delete.rs index b75edd29..ac14bdb6 100644 --- a/src/api/v4/federation/identity_provider/delete.rs +++ b/src/api/v4/federation/identity_provider/delete.rs @@ -61,7 +61,7 @@ pub(super) async fn remove( let current = state .provider .get_federation_provider() - .get_identity_provider(&state.db, &id) + .get_identity_provider(&state, &id) .await?; policy @@ -79,7 +79,7 @@ pub(super) async fn remove( state .provider .get_federation_provider() - .delete_identity_provider(&state.db, &id) + .delete_identity_provider(&state, &id) .await .map_err(KeystoneApiError::federation)?; } else { @@ -98,7 +98,6 @@ mod tests { http::{Request, StatusCode}, }; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -116,11 +115,11 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); federation_mock .expect_get_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(provider_types::IdentityProvider { id: "bar".into(), @@ -131,7 +130,7 @@ mod tests { }); federation_mock .expect_delete_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| { Err(FederationProviderError::IdentityProviderNotFound( "foo".into(), @@ -140,7 +139,7 @@ mod tests { federation_mock .expect_delete_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(federation_mock, true, None); diff --git a/src/api/v4/federation/identity_provider/list.rs b/src/api/v4/federation/identity_provider/list.rs index 63e7674b..bae5944e 100644 --- a/src/api/v4/federation/identity_provider/list.rs +++ b/src/api/v4/federation/identity_provider/list.rs @@ -91,7 +91,7 @@ pub(super) async fn list( let identity_providers: Vec = state .provider .get_federation_provider() - .list_identity_providers(&state.db, &provider_list_params) + .list_identity_providers(&state, &provider_list_params) .await .map_err(KeystoneApiError::federation)? .into_iter() @@ -107,7 +107,7 @@ mod tests { http::{Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use std::collections::HashSet; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` @@ -126,9 +126,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_list_identity_providers() - .withf( - |_: &DatabaseConnection, _: &provider_types::IdentityProviderListParameters| true, - ) + .withf(|_, _: &provider_types::IdentityProviderListParameters| true) .returning(|_, _| { Ok(vec![provider_types::IdentityProvider { id: "id".into(), @@ -185,14 +183,12 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_list_identity_providers() - .withf( - |_: &DatabaseConnection, qp: &provider_types::IdentityProviderListParameters| { - provider_types::IdentityProviderListParameters { - name: Some("name".into()), - domain_ids: Some(HashSet::from([Some("did".into())])), - } == *qp - }, - ) + .withf(|_, qp: &provider_types::IdentityProviderListParameters| { + provider_types::IdentityProviderListParameters { + name: Some("name".into()), + domain_ids: Some(HashSet::from([Some("did".into())])), + } == *qp + }) .returning(|_, _| { Ok(vec![provider_types::IdentityProvider { id: "id".into(), @@ -257,14 +253,12 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_list_identity_providers() - .withf( - |_: &DatabaseConnection, qp: &provider_types::IdentityProviderListParameters| { - provider_types::IdentityProviderListParameters { - name: Some("name".into()), - domain_ids: Some(HashSet::from([None, Some("udid".into())])), - } == *qp - }, - ) + .withf(|_, qp: &provider_types::IdentityProviderListParameters| { + provider_types::IdentityProviderListParameters { + name: Some("name".into()), + domain_ids: Some(HashSet::from([None, Some("udid".into())])), + } == *qp + }) .returning(|_, _| { Ok(vec![provider_types::IdentityProvider { id: "id".into(), @@ -306,14 +300,12 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_list_identity_providers() - .withf( - |_: &DatabaseConnection, qp: &provider_types::IdentityProviderListParameters| { - provider_types::IdentityProviderListParameters { - name: Some("name".into()), - domain_ids: None, - } == *qp - }, - ) + .withf(|_, qp: &provider_types::IdentityProviderListParameters| { + provider_types::IdentityProviderListParameters { + name: Some("name".into()), + domain_ids: None, + } == *qp + }) .returning(|_, _| { Ok(vec![provider_types::IdentityProvider { id: "id".into(), diff --git a/src/api/v4/federation/identity_provider/show.rs b/src/api/v4/federation/identity_provider/show.rs index ca5dcf5d..f8b87962 100644 --- a/src/api/v4/federation/identity_provider/show.rs +++ b/src/api/v4/federation/identity_provider/show.rs @@ -59,7 +59,7 @@ pub(super) async fn show( let current = state .provider .get_federation_provider() - .get_identity_provider(&state.db, &idp_id) + .get_identity_provider(&state, &idp_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -86,7 +86,6 @@ mod tests { http::{Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -105,12 +104,12 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); federation_mock .expect_get_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(provider_types::IdentityProvider { id: "bar".into(), @@ -182,7 +181,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(provider_types::IdentityProvider { id: "bar".into(), diff --git a/src/api/v4/federation/identity_provider/update.rs b/src/api/v4/federation/identity_provider/update.rs index 96d45088..c0426bbc 100644 --- a/src/api/v4/federation/identity_provider/update.rs +++ b/src/api/v4/federation/identity_provider/update.rs @@ -62,7 +62,7 @@ pub(super) async fn update( let current = state .provider .get_federation_provider() - .get_identity_provider(&state.db, &idp_id) + .get_identity_provider(&state, &idp_id) .await?; policy @@ -77,7 +77,7 @@ pub(super) async fn update( let res = state .provider .get_federation_provider() - .update_identity_provider(&state.db, &idp_id, req.into()) + .update_identity_provider(&state, &idp_id, req.into()) .await .map_err(KeystoneApiError::federation)?; Ok(res.into_response()) @@ -90,7 +90,7 @@ mod tests { http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; use tracing_test::traced_test; @@ -108,7 +108,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_identity_provider() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "1") + .withf(|_, id: &'_ str| id == "1") .returning(|_, _| { Ok(Some(provider_types::IdentityProvider { id: "bar".into(), @@ -120,9 +120,7 @@ mod tests { federation_mock .expect_update_identity_provider() .withf( - |_: &DatabaseConnection, - id: &'_ str, - req: &provider_types::IdentityProviderUpdate| { + |_, id: &'_ str, req: &provider_types::IdentityProviderUpdate| { id == "1" && req.name == Some("name".to_string()) }, ) diff --git a/src/api/v4/federation/jwt.rs b/src/api/v4/federation/jwt.rs index b27b6c45..fd1e9325 100644 --- a/src/api/v4/federation/jwt.rs +++ b/src/api/v4/federation/jwt.rs @@ -132,7 +132,7 @@ pub async fn login( let idp = state .provider .get_federation_provider() - .get_identity_provider(&state.db, &idp_id) + .get_identity_provider(&state, &idp_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -145,7 +145,7 @@ pub async fn login( .provider .get_federation_provider() .list_mappings( - &state.db, + &state, &ProviderMappingListParameters { idp_id: Some(idp_id.clone()), name: Some(mapping.clone()), @@ -245,7 +245,7 @@ pub async fn login( let user = if let Some(existing_user) = state .provider .get_identity_provider() - .find_federated_user(&state.db, &idp.id, &mapped_user_data.unique_id) + .find_federated_user(&state, &idp.id, &mapped_user_data.unique_id) .await? { // The user exists already @@ -274,7 +274,7 @@ pub async fn login( .provider .get_identity_provider() .create_user( - &state.db, + &state, user_builder.build().map_err(IdentityProviderError::from)?, ) .await? @@ -326,7 +326,7 @@ pub async fn login( let catalog: Catalog = state .provider .get_catalog_provider() - .get_catalog(&state.db, true) + .get_catalog(&state, true) .await? .into(); api_token.token.catalog = Some(catalog); diff --git a/src/api/v4/federation/mapping/create.rs b/src/api/v4/federation/mapping/create.rs index 2f26ab09..723394c2 100644 --- a/src/api/v4/federation/mapping/create.rs +++ b/src/api/v4/federation/mapping/create.rs @@ -59,7 +59,7 @@ pub(super) async fn create( let res = state .provider .get_federation_provider() - .create_mapping(&state.db, req.into()) + .create_mapping(&state, req.into()) .await .map_err(KeystoneApiError::federation)?; Ok((StatusCode::CREATED, res).into_response()) @@ -72,7 +72,6 @@ mod tests { http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -91,7 +90,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_create_mapping() - .withf(|_: &DatabaseConnection, req: &provider_types::Mapping| req.name == "name") + .withf(|_, req: &provider_types::Mapping| req.name == "name") .returning(|_, _| { Ok(provider_types::Mapping { id: "bar".into(), diff --git a/src/api/v4/federation/mapping/delete.rs b/src/api/v4/federation/mapping/delete.rs index 88007332..e9da3f18 100644 --- a/src/api/v4/federation/mapping/delete.rs +++ b/src/api/v4/federation/mapping/delete.rs @@ -57,7 +57,7 @@ pub(super) async fn remove( let current = state .provider .get_federation_provider() - .get_mapping(&state.db, &id) + .get_mapping(&state, &id) .await?; policy @@ -72,7 +72,7 @@ pub(super) async fn remove( state .provider .get_federation_provider() - .delete_mapping(&state.db, &id) + .delete_mapping(&state, &id) .await .map_err(KeystoneApiError::federation)?; } else { @@ -91,7 +91,6 @@ mod tests { http::{Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -106,11 +105,11 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_mapping() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); federation_mock .expect_get_mapping() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(provider_types::Mapping { id: "bar".into(), @@ -121,7 +120,7 @@ mod tests { }); federation_mock .expect_delete_mapping() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(federation_mock, true); diff --git a/src/api/v4/federation/mapping/list.rs b/src/api/v4/federation/mapping/list.rs index c02d70dd..2c1a609d 100644 --- a/src/api/v4/federation/mapping/list.rs +++ b/src/api/v4/federation/mapping/list.rs @@ -67,7 +67,7 @@ pub(super) async fn list( let mappings: Vec = state .provider .get_federation_provider() - .list_mappings(&state.db, &query.try_into()?) + .list_mappings(&state, &query.try_into()?) .await .map_err(KeystoneApiError::federation)? .into_iter() @@ -83,7 +83,6 @@ mod tests { http::{Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -101,7 +100,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_list_mappings() - .withf(|_: &DatabaseConnection, _: &provider_types::MappingListParameters| true) + .withf(|_, _: &provider_types::MappingListParameters| true) .returning(|_, _| { Ok(vec![provider_types::Mapping { id: "id".into(), @@ -166,16 +165,14 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_list_mappings() - .withf( - |_: &DatabaseConnection, qp: &provider_types::MappingListParameters| { - provider_types::MappingListParameters { - name: Some("name".into()), - domain_id: Some("did".into()), - idp_id: Some("idp".into()), - ..Default::default() - } == *qp - }, - ) + .withf(|_, qp: &provider_types::MappingListParameters| { + provider_types::MappingListParameters { + name: Some("name".into()), + domain_id: Some("did".into()), + idp_id: Some("idp".into()), + ..Default::default() + } == *qp + }) .returning(|_, _| { Ok(vec![provider_types::Mapping { id: "id".into(), diff --git a/src/api/v4/federation/mapping/show.rs b/src/api/v4/federation/mapping/show.rs index 3bea05ec..27aaf8b3 100644 --- a/src/api/v4/federation/mapping/show.rs +++ b/src/api/v4/federation/mapping/show.rs @@ -60,7 +60,7 @@ pub(super) async fn show( let current = state .provider .get_federation_provider() - .get_mapping(&state.db, &id) + .get_mapping(&state, &id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -87,7 +87,6 @@ mod tests { http::{Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -105,12 +104,12 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_mapping() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); federation_mock .expect_get_mapping() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(provider_types::Mapping { id: "bar".into(), diff --git a/src/api/v4/federation/mapping/update.rs b/src/api/v4/federation/mapping/update.rs index eaa38810..40903b3a 100644 --- a/src/api/v4/federation/mapping/update.rs +++ b/src/api/v4/federation/mapping/update.rs @@ -61,7 +61,7 @@ pub(super) async fn update( let current = state .provider .get_federation_provider() - .get_mapping(&state.db, &id) + .get_mapping(&state, &id) .await?; policy @@ -76,7 +76,7 @@ pub(super) async fn update( let res = state .provider .get_federation_provider() - .update_mapping(&state.db, &id, req.into()) + .update_mapping(&state, &id, req.into()) .await .map_err(KeystoneApiError::federation)?; Ok(res.into_response()) @@ -89,7 +89,6 @@ mod tests { http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` use tower_http::trace::TraceLayer; @@ -107,7 +106,7 @@ mod tests { let mut federation_mock = MockFederationProvider::default(); federation_mock .expect_get_mapping() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "1") + .withf(|_, id: &'_ str| id == "1") .returning(|_, _| { Ok(Some(provider_types::Mapping { id: "bar".into(), @@ -119,11 +118,9 @@ mod tests { federation_mock .expect_update_mapping() - .withf( - |_: &DatabaseConnection, id: &'_ str, req: &provider_types::MappingUpdate| { - id == "1" && req.name == Some("name".to_string()) - }, - ) + .withf(|_, id: &'_ str, req: &provider_types::MappingUpdate| { + id == "1" && req.name == Some("name".to_string()) + }) .returning(|_, _, _| { Ok(provider_types::Mapping { id: "bar".into(), diff --git a/src/api/v4/federation/oidc.rs b/src/api/v4/federation/oidc.rs index b351cc03..c0c08235 100644 --- a/src/api/v4/federation/oidc.rs +++ b/src/api/v4/federation/oidc.rs @@ -84,7 +84,7 @@ pub async fn callback( let auth_state = state .provider .get_federation_provider() - .get_auth_state(&state.db, &query.state) + .get_auth_state(&state, &query.state) .await? .ok_or_else(|| KeystoneApiError::NotFound { resource: "auth state".into(), @@ -98,7 +98,7 @@ pub async fn callback( let idp = state .provider .get_federation_provider() - .get_identity_provider(&state.db, &auth_state.idp_id) + .get_identity_provider(&state, &auth_state.idp_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -110,7 +110,7 @@ pub async fn callback( let mapping = state .provider .get_federation_provider() - .get_mapping(&state.db, &auth_state.mapping_id) + .get_mapping(&state, &auth_state.mapping_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -190,7 +190,7 @@ pub async fn callback( let user = if let Some(existing_user) = state .provider .get_identity_provider() - .find_federated_user(&state.db, &idp.id, &mapped_user_data.unique_id) + .find_federated_user(&state, &idp.id, &mapped_user_data.unique_id) .await? { // The user exists already @@ -219,7 +219,7 @@ pub async fn callback( .provider .get_identity_provider() .create_user( - &state.db, + &state, user_builder.build().map_err(IdentityProviderError::from)?, ) .await? @@ -231,7 +231,7 @@ pub async fn callback( .provider .get_identity_provider() .list_groups( - &state.db, + &state, &GroupListParameters { domain_id: Some(user.domain_id.clone()), ..Default::default() @@ -251,7 +251,7 @@ pub async fn callback( .provider .get_identity_provider() .create_group( - &state.db, + &state, GroupCreate { domain_id: user.domain_id.clone(), name: group_name.clone(), @@ -268,7 +268,7 @@ pub async fn callback( .provider .get_identity_provider() .set_user_groups( - &state.db, + &state, &user.id, HashSet::from_iter(group_ids.iter().map(|i| i.as_str())), ) @@ -279,7 +279,7 @@ pub async fn callback( state .provider .get_identity_provider() - .list_groups_of_user(&state.db, &user.id) + .list_groups_of_user(&state, &user.id) .await?, ); @@ -316,7 +316,7 @@ pub async fn callback( let catalog: Catalog = state .provider .get_catalog_provider() - .get_catalog(&state.db, true) + .get_catalog(&state, true) .await? .into(); api_token.token.catalog = Some(catalog); diff --git a/src/api/v4/group/mod.rs b/src/api/v4/group/mod.rs index eb406b44..d45218b9 100644 --- a/src/api/v4/group/mod.rs +++ b/src/api/v4/group/mod.rs @@ -29,7 +29,7 @@ mod tests { http::{Request, StatusCode, header}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use serde_json::json; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` @@ -53,7 +53,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_groups() - .withf(|_: &DatabaseConnection, _: &GroupListParameters| true) + .withf(|_, _: &GroupListParameters| true) .returning(|_, _| { Ok(vec![Group { id: "1".into(), @@ -102,7 +102,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_groups() - .withf(|_: &DatabaseConnection, qp: &GroupListParameters| { + .withf(|_, qp: &GroupListParameters| { GroupListParameters { domain_id: Some("domain".into()), name: Some("name".into()), @@ -156,12 +156,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); identity_mock .expect_get_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(Group { id: "bar".into(), @@ -220,9 +220,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_create_group() - .withf(|_: &DatabaseConnection, req: &GroupCreate| { - req.domain_id == "domain" && req.name == "name" - }) + .withf(|_, req: &GroupCreate| req.domain_id == "domain" && req.name == "name") .returning(|_, req| { Ok(Group { id: "bar".into(), @@ -273,12 +271,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_delete_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Err(IdentityProviderError::GroupNotFound("foo".into()))); identity_mock .expect_delete_group() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(identity_mock); diff --git a/src/api/v4/user/mod.rs b/src/api/v4/user/mod.rs index 6c6992a8..f6048aed 100644 --- a/src/api/v4/user/mod.rs +++ b/src/api/v4/user/mod.rs @@ -59,7 +59,7 @@ async fn list( let users: Vec = state .provider .get_identity_provider() - .list_users(&state.db, &query.into()) + .list_users(&state, &query.into()) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -88,7 +88,7 @@ async fn show( state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state, &user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -118,7 +118,7 @@ async fn create( let user = state .provider .get_identity_provider() - .create_user(&state.db, req.into()) + .create_user(&state, req.into()) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::CREATED, user).into_response()) @@ -145,7 +145,7 @@ async fn remove( state .provider .get_identity_provider() - .delete_user(&state.db, &user_id) + .delete_user(&state, &user_id) .await .map_err(KeystoneApiError::identity)?; Ok((StatusCode::NO_CONTENT).into_response()) @@ -171,7 +171,7 @@ async fn groups( let groups: Vec = state .provider .get_identity_provider() - .list_groups_of_user(&state.db, &user_id) + .list_groups_of_user(&state, &user_id) .await .map_err(KeystoneApiError::identity)? .into_iter() @@ -187,7 +187,7 @@ mod tests { http::{self, Request, StatusCode}, }; use http_body_util::BodyExt; // for `collect` - use sea_orm::DatabaseConnection; + use serde_json::json; use tower::ServiceExt; // for `call`, `oneshot`, and `ready` @@ -212,7 +212,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_users() - .withf(|_: &DatabaseConnection, _: &UserListParameters| true) + .withf(|_, _: &UserListParameters| true) .returning(|_, _| { Ok(vec![UserResponse { id: "1".into(), @@ -260,7 +260,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_users() - .withf(|_: &DatabaseConnection, qp: &UserListParameters| { + .withf(|_, qp: &UserListParameters| { UserListParameters { domain_id: Some("domain".into()), name: Some("name".into()), @@ -314,9 +314,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_create_user() - .withf(|_: &DatabaseConnection, req: &UserCreate| { - req.domain_id == "domain" && req.name == "name" - }) + .withf(|_, req: &UserCreate| req.domain_id == "domain" && req.name == "name") .returning(|_, req| { Ok(UserResponse { id: "bar".into(), @@ -366,12 +364,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Ok(None)); identity_mock .expect_get_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| { Ok(Some(UserResponse { id: "bar".into(), @@ -430,12 +428,12 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_delete_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "foo") + .withf(|_, id: &'_ str| id == "foo") .returning(|_, _| Err(IdentityProviderError::UserNotFound("foo".into()))); identity_mock .expect_delete_user() - .withf(|_: &DatabaseConnection, id: &'_ str| id == "bar") + .withf(|_, id: &'_ str| id == "bar") .returning(|_, _| Ok(())); let state = get_mocked_state(identity_mock); @@ -480,7 +478,7 @@ mod tests { let mut identity_mock = MockIdentityProvider::default(); identity_mock .expect_list_groups_of_user() - .withf(|_: &DatabaseConnection, uid: &str| uid == "foo") + .withf(|_, uid: &str| uid == "foo") .returning(|_, _| { Ok(vec![Group { id: "1".into(), diff --git a/src/api/v4/user/passkey/register_finish.rs b/src/api/v4/user/passkey/register_finish.rs index 82e7dd7d..4d9d686d 100644 --- a/src/api/v4/user/passkey/register_finish.rs +++ b/src/api/v4/user/passkey/register_finish.rs @@ -64,7 +64,7 @@ pub(super) async fn finish( let user = state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state, &user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -85,7 +85,7 @@ pub(super) async fn finish( if let Some(s) = state .provider .get_identity_provider() - .get_user_webauthn_credential_registration_state(&state.db, &user_id) + .get_user_webauthn_credential_registration_state(&state, &user_id) .await? { let credential_description = req.description.clone(); @@ -98,7 +98,7 @@ pub(super) async fn finish( .provider .get_identity_provider() .create_user_webauthn_credential( - &state.db, + &state, &user_id, &sk, credential_description.as_deref(), @@ -113,7 +113,7 @@ pub(super) async fn finish( state .provider .get_identity_provider() - .delete_user_webauthn_credential_registration_state(&state.db, &user_id) + .delete_user_webauthn_credential_registration_state(&state, &user_id) .await?; Ok((StatusCode::CREATED, Json(PasskeyResponse::from(passkey))).into_response()) } else { diff --git a/src/api/v4/user/passkey/register_start.rs b/src/api/v4/user/passkey/register_start.rs index 00f12396..b37c40a3 100644 --- a/src/api/v4/user/passkey/register_start.rs +++ b/src/api/v4/user/passkey/register_start.rs @@ -72,7 +72,7 @@ pub(super) async fn start( let user = state .provider .get_identity_provider() - .get_user(&state.db, &user_id) + .get_user(&state, &user_id) .await .map(|x| { x.ok_or_else(|| KeystoneApiError::NotFound { @@ -93,7 +93,7 @@ pub(super) async fn start( state .provider .get_identity_provider() - .delete_user_webauthn_credential_registration_state(&state.db, &user_id) + .delete_user_webauthn_credential_registration_state(&state, &user_id) .await?; let res = match state.webauthn.start_passkey_registration( Uuid::parse_str(&user_id)?, @@ -107,7 +107,7 @@ pub(super) async fn start( state .provider .get_identity_provider() - .save_user_webauthn_credential_registration_state(&state.db, &user_id, reg_state) + .save_user_webauthn_credential_registration_state(&state, &user_id, reg_state) .await?; Json(UserPasskeyRegistrationStartResponse::try_from(ccr)?) } diff --git a/src/assignment/mod.rs b/src/assignment/mod.rs index 7b71ee9b..bf4d4456 100644 --- a/src/assignment/mod.rs +++ b/src/assignment/mod.rs @@ -160,7 +160,7 @@ impl AssignmentApi for AssignmentProvider { let users = state .provider .get_identity_provider() - .list_groups_of_user(&state.db, uid) + .list_groups_of_user(state, uid) .await?; actors.extend(users.into_iter().map(|x| x.id)); }; diff --git a/src/bin/keystone.rs b/src/bin/keystone.rs index 5be5f545..e686a4b2 100644 --- a/src/bin/keystone.rs +++ b/src/bin/keystone.rs @@ -242,7 +242,7 @@ async fn cleanup(cancel: CancellationToken, state: ServiceState) { tokio::select! { _ = interval.tick() => { trace!("cleanup job tick"); - if let Err(e) = state.provider.get_federation_provider().cleanup(&state.db).await { + if let Err(e) = state.provider.get_federation_provider().cleanup(&state).await { error!("Error during cleanup job: {}", e); } }, diff --git a/src/catalog/backends/sql.rs b/src/catalog/backends/sql.rs index 1ea4dff6..f1dc88ef 100644 --- a/src/catalog/backends/sql.rs +++ b/src/catalog/backends/sql.rs @@ -26,6 +26,7 @@ use crate::db::entity::{ prelude::{Endpoint as DbEndpoint, Service as DbService}, service as db_service, }; +use crate::keystone::ServiceState; mod endpoint; mod service; @@ -45,53 +46,53 @@ impl CatalogBackend for SqlBackend { } /// List Services - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_services( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &ServiceListParameters, ) -> Result, CatalogProviderError> { - Ok(service::list(&self.config, db, params).await?) + Ok(service::list(&state.db, params).await?) } /// Get single service by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_service<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError> { - Ok(service::get(&self.config, db, id).await?) + Ok(service::get(&state.db, id).await?) } /// List Endpoints - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_endpoints( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &EndpointListParameters, ) -> Result, CatalogProviderError> { - Ok(endpoint::list(&self.config, db, params).await?) + Ok(endpoint::list(&state.db, params).await?) } /// Get single endpoint by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_endpoint<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError> { - Ok(endpoint::get(&self.config, db, id).await?) + Ok(endpoint::get(&state.db, id).await?) } /// Get Catalog (Services with Endpoints) - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_catalog( &self, - db: &DatabaseConnection, + state: &ServiceState, enabled: bool, ) -> Result)>, CatalogProviderError> { - Ok(get_catalog(db, enabled).await?) + Ok(get_catalog(&state.db, enabled).await?) } } diff --git a/src/catalog/backends/sql/endpoint/get.rs b/src/catalog/backends/sql/endpoint/get.rs index 3e62fa0c..1ce56cdc 100644 --- a/src/catalog/backends/sql/endpoint/get.rs +++ b/src/catalog/backends/sql/endpoint/get.rs @@ -17,11 +17,9 @@ use sea_orm::entity::*; use crate::catalog::backends::error::{CatalogDatabaseError, db_err}; use crate::catalog::types::*; -use crate::config::Config; use crate::db::entity::{endpoint as db_endpoint, prelude::Endpoint as DbEndpoint}; pub async fn get>( - _conf: &Config, db: &DatabaseConnection, id: I, ) -> Result, CatalogDatabaseError> { @@ -38,8 +36,6 @@ pub async fn get>( mod tests { use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; - use crate::config::Config; - use super::super::tests::get_endpoint_mock; use super::*; @@ -52,9 +48,8 @@ mod tests { vec![get_endpoint_mock("1".into())], ]) .into_connection(); - let config = Config::default(); assert_eq!( - get(&config, &db, "1").await.unwrap().unwrap(), + get(&db, "1").await.unwrap().unwrap(), Endpoint { id: "1".into(), interface: "public".into(), diff --git a/src/catalog/backends/sql/endpoint/list.rs b/src/catalog/backends/sql/endpoint/list.rs index babee5ac..84735c28 100644 --- a/src/catalog/backends/sql/endpoint/list.rs +++ b/src/catalog/backends/sql/endpoint/list.rs @@ -18,11 +18,9 @@ use sea_orm::query::*; use crate::catalog::backends::error::{CatalogDatabaseError, db_err}; use crate::catalog::types::*; -use crate::config::Config; use crate::db::entity::{endpoint as db_endpoint, prelude::Endpoint as DbEndpoint}; pub async fn list( - _conf: &Config, db: &DatabaseConnection, params: &EndpointListParameters, ) -> Result, CatalogDatabaseError> { @@ -54,8 +52,6 @@ pub async fn list( mod tests { use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; - use crate::config::Config; - use super::super::tests::get_endpoint_mock; use super::*; @@ -66,15 +62,9 @@ mod tests { .append_query_results([vec![get_endpoint_mock("1".into())]]) .append_query_results([vec![get_endpoint_mock("1".into())]]) .into_connection(); - let config = Config::default(); - assert!( - list(&config, &db, &EndpointListParameters::default()) - .await - .is_ok() - ); + assert!(list(&db, &EndpointListParameters::default()).await.is_ok()); assert_eq!( list( - &config, &db, &EndpointListParameters { interface: Some("public".into()), diff --git a/src/catalog/backends/sql/service/get.rs b/src/catalog/backends/sql/service/get.rs index 182aaace..aff5f6fd 100644 --- a/src/catalog/backends/sql/service/get.rs +++ b/src/catalog/backends/sql/service/get.rs @@ -16,11 +16,9 @@ use sea_orm::entity::*; use crate::catalog::backends::error::{CatalogDatabaseError, db_err}; use crate::catalog::types::*; -use crate::config::Config; use crate::db::entity::{prelude::Service as DbService, service as db_service}; pub async fn get>( - _conf: &Config, db: &DatabaseConnection, id: I, ) -> Result, CatalogDatabaseError> { @@ -38,8 +36,6 @@ mod tests { use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; use serde_json::json; - use crate::config::Config; - use super::super::tests::get_service_mock; use super::*; @@ -52,9 +48,8 @@ mod tests { vec![get_service_mock("1".into())], ]) .into_connection(); - let config = Config::default(); assert_eq!( - get(&config, &db, "1").await.unwrap().unwrap(), + get(&db, "1").await.unwrap().unwrap(), Service { id: "1".into(), r#type: Some("type".into()), diff --git a/src/catalog/backends/sql/service/list.rs b/src/catalog/backends/sql/service/list.rs index d9ee5eaf..72fd24ff 100644 --- a/src/catalog/backends/sql/service/list.rs +++ b/src/catalog/backends/sql/service/list.rs @@ -18,11 +18,9 @@ use sea_orm::query::*; use crate::catalog::backends::error::{CatalogDatabaseError, db_err}; use crate::catalog::types::*; -use crate::config::Config; use crate::db::entity::{prelude::Service as DbService, service as db_service}; pub async fn list( - _conf: &Config, db: &DatabaseConnection, params: &ServiceListParameters, ) -> Result, CatalogDatabaseError> { @@ -49,8 +47,6 @@ mod tests { use sea_orm::{DatabaseBackend, MockDatabase, Transaction}; use serde_json::json; - use crate::config::Config; - use super::super::tests::get_service_mock; use super::*; @@ -61,15 +57,9 @@ mod tests { .append_query_results([vec![get_service_mock("1".into())]]) .append_query_results([vec![get_service_mock("1".into())]]) .into_connection(); - let config = Config::default(); - assert!( - list(&config, &db, &ServiceListParameters::default()) - .await - .is_ok() - ); + assert!(list(&db, &ServiceListParameters::default()).await.is_ok()); assert_eq!( list( - &config, &db, &ServiceListParameters { r#type: Some("type".into()), diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index e471b410..8cbe80c8 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -15,7 +15,6 @@ use async_trait::async_trait; #[cfg(test)] use mockall::mock; -use sea_orm::DatabaseConnection; pub mod backends; pub mod error; @@ -27,6 +26,7 @@ use crate::catalog::types::{ CatalogBackend, Endpoint, EndpointListParameters, Service, ServiceListParameters, }; use crate::config::Config; +use crate::keystone::ServiceState; use crate::plugin_manager::PluginManager; #[derive(Clone, Debug)] @@ -38,31 +38,31 @@ pub struct CatalogProvider { pub trait CatalogApi: Send + Sync + Clone { async fn list_services( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &ServiceListParameters, ) -> Result, CatalogProviderError>; async fn get_service<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError>; async fn list_endpoints( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &EndpointListParameters, ) -> Result, CatalogProviderError>; async fn get_endpoint<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError>; async fn get_catalog( &self, - db: &DatabaseConnection, + state: &ServiceState, enabled: bool, ) -> Result)>, CatalogProviderError>; } @@ -77,31 +77,31 @@ mock! { impl CatalogApi for CatalogProvider { async fn list_services( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &ServiceListParameters ) -> Result, CatalogProviderError>; async fn get_service<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError>; async fn list_endpoints( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &EndpointListParameters, ) -> Result, CatalogProviderError>; async fn get_endpoint<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError>; async fn get_catalog( &self, - db: &DatabaseConnection, + state: &ServiceState, enabled: bool, ) -> Result)>, CatalogProviderError>; @@ -139,52 +139,52 @@ impl CatalogProvider { #[async_trait] impl CatalogApi for CatalogProvider { /// List services - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_services( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &ServiceListParameters, ) -> Result, CatalogProviderError> { - self.backend_driver.list_services(db, params).await + self.backend_driver.list_services(state, params).await } /// Get single service by ID - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_service<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError> { - self.backend_driver.get_service(db, id).await + self.backend_driver.get_service(state, id).await } /// List Endpoints - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_endpoints( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &EndpointListParameters, ) -> Result, CatalogProviderError> { - self.backend_driver.list_endpoints(db, params).await + self.backend_driver.list_endpoints(state, params).await } /// Get single endpoint by ID - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_endpoint<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError> { - self.backend_driver.get_endpoint(db, id).await + self.backend_driver.get_endpoint(state, id).await } /// Get catalog - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_catalog( &self, - db: &DatabaseConnection, + state: &ServiceState, enabled: bool, ) -> Result)>, CatalogProviderError> { - self.backend_driver.get_catalog(db, enabled).await + self.backend_driver.get_catalog(state, enabled).await } } diff --git a/src/catalog/types.rs b/src/catalog/types.rs index ea00bcc9..b677dcc9 100644 --- a/src/catalog/types.rs +++ b/src/catalog/types.rs @@ -17,7 +17,6 @@ pub mod service; use async_trait::async_trait; use dyn_clone::DynClone; -use sea_orm::DatabaseConnection; use crate::catalog::CatalogProviderError; use crate::config::Config; @@ -28,6 +27,7 @@ pub use crate::catalog::types::endpoint::{ pub use crate::catalog::types::service::{ Service, ServiceBuilder, ServiceBuilderError, ServiceListParameters, }; +use crate::keystone::ServiceState; #[async_trait] pub trait CatalogBackend: DynClone + Send + Sync + std::fmt::Debug { @@ -37,35 +37,35 @@ pub trait CatalogBackend: DynClone + Send + Sync + std::fmt::Debug { /// List services async fn list_services( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &ServiceListParameters, ) -> Result, CatalogProviderError>; /// Get single service by ID async fn get_service<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError>; /// List Endpoints async fn list_endpoints( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &EndpointListParameters, ) -> Result, CatalogProviderError>; /// Get single endpoint by ID async fn get_endpoint<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, CatalogProviderError>; /// Get Catalog (Services with Endpoints) async fn get_catalog( &self, - db: &DatabaseConnection, + state: &ServiceState, enabled: bool, ) -> Result)>, CatalogProviderError>; } diff --git a/src/federation/backends/sql.rs b/src/federation/backends/sql.rs index 72d37315..5fb7c28d 100644 --- a/src/federation/backends/sql.rs +++ b/src/federation/backends/sql.rs @@ -13,11 +13,11 @@ // SPDX-License-Identifier: Apache-2.0 use async_trait::async_trait; -use sea_orm::DatabaseConnection; use super::super::types::*; use crate::config::Config; use crate::federation::FederationProviderError; +use crate::keystone::ServiceState; mod auth_state; mod identity_provider; @@ -36,141 +36,141 @@ impl FederationBackend for SqlBackend { } /// List IDPs - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_identity_providers( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &IdentityProviderListParameters, ) -> Result, FederationProviderError> { - Ok(identity_provider::list(&self.config, db, params).await?) + Ok(identity_provider::list(&self.config, &state.db, params).await?) } /// Get single IDP by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError> { - Ok(identity_provider::get(&self.config, db, id).await?) + Ok(identity_provider::get(&self.config, &state.db, id).await?) } /// Create Identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_identity_provider( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: IdentityProvider, ) -> Result { - Ok(identity_provider::create(&self.config, db, idp).await?) + Ok(identity_provider::create(&self.config, &state.db, idp).await?) } /// Update Identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn update_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: IdentityProviderUpdate, ) -> Result { - Ok(identity_provider::update(&self.config, db, id, idp).await?) + Ok(identity_provider::update(&self.config, &state.db, id, idp).await?) } /// Delete identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError> { - Ok(identity_provider::delete(&self.config, db, id).await?) + Ok(identity_provider::delete(&self.config, &state.db, id).await?) } /// List Mapping - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_mappings( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &MappingListParameters, ) -> Result, FederationProviderError> { - Ok(mapping::list(&self.config, db, params).await?) + Ok(mapping::list(&self.config, &state.db, params).await?) } /// Get single mapping by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError> { - Ok(mapping::get(&self.config, db, id).await?) + Ok(mapping::get(&self.config, &state.db, id).await?) } /// Create mapping - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_mapping( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: Mapping, ) -> Result { - Ok(mapping::create(&self.config, db, idp).await?) + Ok(mapping::create(&self.config, &state.db, idp).await?) } /// Update mapping - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn update_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: MappingUpdate, ) -> Result { - Ok(mapping::update(&self.config, db, id, idp).await?) + Ok(mapping::update(&self.config, &state.db, id, idp).await?) } /// Delete mapping - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError> { - Ok(mapping::delete(&self.config, db, id).await?) + Ok(mapping::delete(&self.config, &state.db, id).await?) } /// Get auth state by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError> { - Ok(auth_state::get(&self.config, db, id).await?) + Ok(auth_state::get(&self.config, &state.db, id).await?) } /// Create new auth state - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_auth_state( &self, - db: &DatabaseConnection, - state: AuthState, + state: &ServiceState, + auth_state: AuthState, ) -> Result { - Ok(auth_state::create(&self.config, db, state).await?) + Ok(auth_state::create(&self.config, &state.db, auth_state).await?) } /// Delete auth state - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError> { - Ok(auth_state::delete(&self.config, db, id).await?) + Ok(auth_state::delete(&self.config, &state.db, id).await?) } /// Cleanup expired resources - #[tracing::instrument(level = "debug", skip(self, db))] - async fn cleanup(&self, db: &DatabaseConnection) -> Result<(), FederationProviderError> { - Ok(auth_state::delete_expired(&self.config, db).await?) + #[tracing::instrument(level = "debug", skip(self, state))] + async fn cleanup(&self, state: &ServiceState) -> Result<(), FederationProviderError> { + Ok(auth_state::delete_expired(&self.config, &state.db).await?) } } diff --git a/src/federation/mod.rs b/src/federation/mod.rs index b24423e1..9a201309 100644 --- a/src/federation/mod.rs +++ b/src/federation/mod.rs @@ -15,7 +15,6 @@ use async_trait::async_trait; #[cfg(test)] use mockall::mock; -use sea_orm::DatabaseConnection; use uuid::Uuid; pub mod backends; @@ -26,6 +25,7 @@ use crate::config::Config; use crate::federation::backends::sql::SqlBackend; use crate::federation::error::FederationProviderError; use crate::federation::types::*; +use crate::keystone::ServiceState; use crate::plugin_manager::PluginManager; #[derive(Clone, Debug)] @@ -37,86 +37,86 @@ pub struct FederationProvider { pub trait FederationApi: Send + Sync + Clone { async fn list_identity_providers( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &IdentityProviderListParameters, ) -> Result, FederationProviderError>; async fn get_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; async fn create_identity_provider( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: IdentityProvider, ) -> Result; async fn update_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: IdentityProviderUpdate, ) -> Result; async fn delete_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; async fn list_mappings( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &MappingListParameters, ) -> Result, FederationProviderError>; async fn get_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; async fn create_mapping( &self, - db: &DatabaseConnection, + state: &ServiceState, mapping: Mapping, ) -> Result; async fn update_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, mapping: MappingUpdate, ) -> Result; async fn delete_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; async fn get_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; async fn create_auth_state( &self, - db: &DatabaseConnection, + state: &ServiceState, state: AuthState, ) -> Result; async fn delete_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; /// Cleanup expired resources - async fn cleanup(&self, db: &DatabaseConnection) -> Result<(), FederationProviderError>; + async fn cleanup(&self, state: &ServiceState) -> Result<(), FederationProviderError>; } #[cfg(test)] @@ -129,59 +129,59 @@ mock! { impl FederationApi for FederationProvider { async fn list_identity_providers( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &IdentityProviderListParameters, ) -> Result, FederationProviderError>; async fn get_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; async fn create_identity_provider( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: IdentityProvider, ) -> Result; async fn update_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: IdentityProviderUpdate, ) -> Result; async fn delete_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; async fn list_mappings( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &MappingListParameters, ) -> Result, FederationProviderError>; /// Get single mapping by ID async fn get_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; /// Create mapping async fn create_mapping( &self, - db: &DatabaseConnection, + state: &ServiceState, mapping: Mapping, ) -> Result; /// Update mapping async fn update_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, mapping: MappingUpdate, ) -> Result; @@ -189,31 +189,31 @@ mock! { /// Delete mapping async fn delete_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; async fn get_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; async fn create_auth_state( &self, - db: &DatabaseConnection, - state: AuthState, + state: &ServiceState, + auth_state: AuthState, ) -> Result; async fn delete_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; async fn cleanup( &self, - db: &DatabaseConnection, + state: &ServiceState, ) -> Result<(), FederationProviderError>; } @@ -250,32 +250,32 @@ impl FederationProvider { #[async_trait] impl FederationApi for FederationProvider { /// List IDP - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_identity_providers( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &IdentityProviderListParameters, ) -> Result, FederationProviderError> { self.backend_driver - .list_identity_providers(db, params) + .list_identity_providers(state, params) .await } /// Get single IDP by ID - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError> { - self.backend_driver.get_identity_provider(db, id).await + self.backend_driver.get_identity_provider(state, id).await } /// Create Identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_identity_provider( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: IdentityProvider, ) -> Result { let mut mod_idp = idp; @@ -284,58 +284,60 @@ impl FederationApi for FederationProvider { } self.backend_driver - .create_identity_provider(db, mod_idp) + .create_identity_provider(state, mod_idp) .await } /// Update Identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn update_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: IdentityProviderUpdate, ) -> Result { self.backend_driver - .update_identity_provider(db, id, idp) + .update_identity_provider(state, id, idp) .await } /// Delete identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError> { - self.backend_driver.delete_identity_provider(db, id).await + self.backend_driver + .delete_identity_provider(state, id) + .await } /// List mappings - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_mappings( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &MappingListParameters, ) -> Result, FederationProviderError> { - self.backend_driver.list_mappings(db, params).await + self.backend_driver.list_mappings(state, params).await } /// Get single mapping by ID - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError> { - self.backend_driver.get_mapping(db, id).await + self.backend_driver.get_mapping(state, id).await } /// Create mapping - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_mapping( &self, - db: &DatabaseConnection, + state: &ServiceState, mapping: Mapping, ) -> Result { let mut mod_mapping = mapping; @@ -350,20 +352,20 @@ impl FederationApi for FederationProvider { // TODO: ensure current user has access to the project } - self.backend_driver.create_mapping(db, mod_mapping).await + self.backend_driver.create_mapping(state, mod_mapping).await } /// Update mapping - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn update_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, mapping: MappingUpdate, ) -> Result { let current = self .backend_driver - .get_mapping(db, id) + .get_mapping(state, id) .await? .ok_or_else(|| FederationProviderError::MappingNotFound(id.to_string()))?; @@ -381,52 +383,54 @@ impl FederationApi for FederationProvider { // TODO: ensure current user has access to the project } // TODO: Pass current to the backend to skip re-fetching - self.backend_driver.update_mapping(db, id, mapping).await + self.backend_driver.update_mapping(state, id, mapping).await } /// Delete identity provider - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError> { - self.backend_driver.delete_mapping(db, id).await + self.backend_driver.delete_mapping(state, id).await } /// Get auth state by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError> { - self.backend_driver.get_auth_state(db, id).await + self.backend_driver.get_auth_state(state, id).await } /// Create new auth state - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_auth_state( &self, - db: &DatabaseConnection, - state: AuthState, + state: &ServiceState, + auth_state: AuthState, ) -> Result { - self.backend_driver.create_auth_state(db, state).await + self.backend_driver + .create_auth_state(state, auth_state) + .await } /// Delete auth state - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError> { - self.backend_driver.delete_auth_state(db, id).await + self.backend_driver.delete_auth_state(state, id).await } /// Cleanup expired resources - #[tracing::instrument(level = "info", skip(self, db))] - async fn cleanup(&self, db: &DatabaseConnection) -> Result<(), FederationProviderError> { - self.backend_driver.cleanup(db).await + #[tracing::instrument(level = "info", skip(self, state))] + async fn cleanup(&self, state: &ServiceState) -> Result<(), FederationProviderError> { + self.backend_driver.cleanup(state).await } } diff --git a/src/federation/types.rs b/src/federation/types.rs index af385418..578ff654 100644 --- a/src/federation/types.rs +++ b/src/federation/types.rs @@ -18,10 +18,10 @@ pub mod mapping; use async_trait::async_trait; use dyn_clone::DynClone; -use sea_orm::DatabaseConnection; use crate::config::Config; use crate::federation::FederationProviderError; +use crate::keystone::ServiceState; pub use auth_state::*; pub use identity_provider::*; @@ -35,28 +35,28 @@ pub trait FederationBackend: DynClone + Send + Sync + std::fmt::Debug { /// List Identity Providers async fn list_identity_providers( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &IdentityProviderListParameters, ) -> Result, FederationProviderError>; /// Get single identity provider by ID async fn get_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; /// Create Identity provider async fn create_identity_provider( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: IdentityProvider, ) -> Result; /// Update Identity provider async fn update_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: IdentityProviderUpdate, ) -> Result; @@ -64,35 +64,35 @@ pub trait FederationBackend: DynClone + Send + Sync + std::fmt::Debug { /// Delete identity provider async fn delete_identity_provider<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; /// List Identity Providers async fn list_mappings( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &MappingListParameters, ) -> Result, FederationProviderError>; /// Get single mapping by ID async fn get_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; /// Create mapping async fn create_mapping( &self, - db: &DatabaseConnection, + state: &ServiceState, idp: Mapping, ) -> Result; /// Update mapping async fn update_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, idp: MappingUpdate, ) -> Result; @@ -100,33 +100,33 @@ pub trait FederationBackend: DynClone + Send + Sync + std::fmt::Debug { /// Delete mapping async fn delete_mapping<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; /// Get authentication state async fn get_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result, FederationProviderError>; /// Create new authentication state async fn create_auth_state( &self, - db: &DatabaseConnection, - state: AuthState, + state: &ServiceState, + auth_state: AuthState, ) -> Result; /// Delete authentication state async fn delete_auth_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, id: &'a str, ) -> Result<(), FederationProviderError>; /// Cleanup expired resources - async fn cleanup(&self, db: &DatabaseConnection) -> Result<(), FederationProviderError>; + async fn cleanup(&self, state: &ServiceState) -> Result<(), FederationProviderError>; } dyn_clone::clone_trait_object!(FederationBackend); diff --git a/src/identity/backends/sql.rs b/src/identity/backends/sql.rs index c374a364..db1a02c7 100644 --- a/src/identity/backends/sql.rs +++ b/src/identity/backends/sql.rs @@ -41,6 +41,7 @@ use crate::db::entity::{ use crate::identity::IdentityProviderError; use crate::identity::backends::error::{IdentityDatabaseError, db_err}; use crate::identity::password_hashing; +use crate::keystone::ServiceState; #[derive(Clone, Debug, Default)] pub struct SqlBackend { @@ -59,11 +60,11 @@ impl IdentityBackend for SqlBackend { /// Authenticate a user by a password async fn authenticate_by_password( &self, - db: &DatabaseConnection, + state: &ServiceState, auth: UserPasswordAuthRequest, ) -> Result { let user_with_passwords = local_user::load_local_user_with_passwords( - db, + &state.db, auth.id, auth.name, auth.domain.and_then(|x| x.id), @@ -74,10 +75,10 @@ impl IdentityBackend for SqlBackend { if let Some(latest_password) = passwords.first() && let Some(expected_hash) = &latest_password.password_hash { - let user_opts = user_option::get(db, local_user.user_id.clone()).await?; + let user_opts = user_option::get(&state.db, local_user.user_id.clone()).await?; if password_hashing::verify_password(&self.config, auth.password, expected_hash)? { - if let Some(user) = user::get(db, &local_user.user_id).await? { + if let Some(user) = user::get(&state.db, &local_user.user_id).await? { // TODO: Check password is expired // TODO: reset failed login attempt let user_builder = common::get_local_user_builder( @@ -104,242 +105,242 @@ impl IdentityBackend for SqlBackend { } /// Fetch users from the database - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_users( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &UserListParameters, ) -> Result, IdentityProviderError> { - Ok(list_users(&self.config, db, params).await?) + Ok(list_users(&self.config, &state.db, params).await?) } /// Get single user by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(get_user(&self.config, db, user_id).await?) + Ok(get_user(&self.config, &state.db, user_id).await?) } /// Find federated user by IDP and Unique ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn find_federated_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, idp_id: &'a str, unique_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(find_federated_user(&self.config, db, idp_id, unique_id).await?) + Ok(find_federated_user(&self.config, &state.db, idp_id, unique_id).await?) } /// Create user - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_user( &self, - db: &DatabaseConnection, + state: &ServiceState, user: UserCreate, ) -> Result { - Ok(create_user(&self.config, db, user).await?) + Ok(create_user(&self.config, &state.db, user).await?) } /// Delete user - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError> { - Ok(user::delete(&self.config, db, user_id).await?) + Ok(user::delete(&self.config, &state.db, user_id).await?) } /// List groups - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_groups( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &GroupListParameters, ) -> Result, IdentityProviderError> { - Ok(group::list(&self.config, db, params).await?) + Ok(group::list(&self.config, &state.db, params).await?) } /// Get single group by ID - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(group::get(&self.config, db, group_id).await?) + Ok(group::get(&self.config, &state.db, group_id).await?) } /// Create group - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_group( &self, - db: &DatabaseConnection, + state: &ServiceState, group: GroupCreate, ) -> Result { - Ok(group::create(&self.config, db, group).await?) + Ok(group::create(&self.config, &state.db, group).await?) } /// Delete group - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result<(), IdentityProviderError> { - Ok(group::delete(&self.config, db, group_id).await?) + Ok(group::delete(&self.config, &state.db, group_id).await?) } /// List groups a user is member of. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_groups_of_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(user_group::list_user_groups(db, user_id).await?) + Ok(user_group::list_user_groups(&state.db, user_id).await?) } /// Add the user into the group. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn add_user_to_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError> { - Ok(user_group::add_user_to_group(db, user_id, group_id).await?) + Ok(user_group::add_user_to_group(&state.db, user_id, group_id).await?) } /// Add user group membership relations. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn add_users_to_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, memberships: Vec<(&'a str, &'a str)>, ) -> Result<(), IdentityProviderError> { - Ok(user_group::add_users_to_groups(db, memberships).await?) + Ok(user_group::add_users_to_groups(&state.db, memberships).await?) } /// Remove the user from the group. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn remove_user_from_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError> { - Ok(user_group::remove_user_from_group(db, user_id, group_id).await?) + Ok(user_group::remove_user_from_group(&state.db, user_id, group_id).await?) } /// Remove the user from multiple groups. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn remove_user_from_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError> { - Ok(user_group::remove_user_from_groups(db, user_id, group_ids).await?) + Ok(user_group::remove_user_from_groups(&state.db, user_id, group_ids).await?) } /// Set group memberships of the user. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn set_user_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError> { - Ok(user_group::set_user_groups(db, user_id, group_ids).await?) + Ok(user_group::set_user_groups(&state.db, user_id, group_ids).await?) } /// Create webauthn credential for the user. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_user_webauthn_credential<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, credential: &Passkey, description: Option<&'a str>, ) -> Result { - Ok(webauthn::credential::create(db, user_id, credential, description, None).await?) + Ok(webauthn::credential::create(&state.db, user_id, credential, description, None).await?) } /// List user webauthn credentials. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn list_user_webauthn_credentials<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(webauthn::credential::list(db, user_id).await?) + Ok(webauthn::credential::list(&state.db, user_id).await?) } /// Save webauthn credential registration state. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, - state: PasskeyRegistration, + reg_state: PasskeyRegistration, ) -> Result<(), IdentityProviderError> { - Ok(webauthn::state::create_register(db, user_id, state).await?) + Ok(webauthn::state::create_register(&state.db, user_id, reg_state).await?) } /// Save webauthn credential auth state. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn create_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, - state: PasskeyAuthentication, + auth_state: PasskeyAuthentication, ) -> Result<(), IdentityProviderError> { - Ok(webauthn::state::create_auth(db, user_id, state).await?) + Ok(webauthn::state::create_auth(&state.db, user_id, auth_state).await?) } /// Get webauthn credential registration state. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(webauthn::state::get_register(db, user_id).await?) + Ok(webauthn::state::get_register(&state.db, user_id).await?) } /// Get webauthn credential auth state. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn get_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - Ok(webauthn::state::get_auth(db, user_id).await?) + Ok(webauthn::state::get_auth(&state.db, user_id).await?) } /// Delete webauthn credential registration state for the user. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError> { - Ok(webauthn::state::delete(db, user_id).await?) + Ok(webauthn::state::delete(&state.db, user_id).await?) } /// Delete webauthn credential auth state for a user. - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn delete_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError> { - Ok(webauthn::state::delete(db, user_id).await?) + Ok(webauthn::state::delete(&state.db, user_id).await?) } } diff --git a/src/identity/mod.rs b/src/identity/mod.rs index 515ff1f9..0bd7f2cb 100644 --- a/src/identity/mod.rs +++ b/src/identity/mod.rs @@ -15,7 +15,6 @@ use async_trait::async_trait; #[cfg(test)] use mockall::mock; -use sea_orm::DatabaseConnection; use std::collections::HashSet; use uuid::Uuid; use webauthn_rs::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; @@ -33,8 +32,8 @@ use crate::identity::types::{ Group, GroupCreate, GroupListParameters, IdentityBackend, UserCreate, UserListParameters, UserPasswordAuthRequest, UserResponse, WebauthnCredential, }; +use crate::keystone::ServiceState; use crate::plugin_manager::PluginManager; -use crate::provider::Provider; use crate::resource::{ResourceApi, error::ResourceProviderError}; #[derive(Clone, Debug)] @@ -46,77 +45,76 @@ pub struct IdentityProvider { pub trait IdentityApi: Send + Sync + Clone { async fn authenticate_by_password( &self, - db: &DatabaseConnection, - provider: &Provider, + state: &ServiceState, auth: UserPasswordAuthRequest, ) -> Result; async fn list_users( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &UserListParameters, ) -> Result, IdentityProviderError>; async fn get_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn find_federated_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, idp_id: &'a str, unique_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_user( &self, - db: &DatabaseConnection, + state: &ServiceState, user: UserCreate, ) -> Result; async fn delete_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn list_groups( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &GroupListParameters, ) -> Result, IdentityProviderError>; async fn get_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_group( &self, - db: &DatabaseConnection, + state: &ServiceState, group: GroupCreate, ) -> Result; async fn delete_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result<(), IdentityProviderError>; /// List groups the user is a member of. async fn list_groups_of_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Add the user to the single group. async fn add_user_to_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError>; @@ -124,14 +122,14 @@ pub trait IdentityApi: Send + Sync + Clone { /// Add user group memberships as specified by (uid, gid) tuples. async fn add_users_to_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, memberships: Vec<(&'a str, &'a str)>, ) -> Result<(), IdentityProviderError>; /// Remove the user from the single group. async fn remove_user_from_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError>; @@ -139,7 +137,7 @@ pub trait IdentityApi: Send + Sync + Clone { /// Remove the user from specified groups. async fn remove_user_from_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError>; @@ -147,21 +145,21 @@ pub trait IdentityApi: Send + Sync + Clone { /// Set group memberships of the user. async fn set_user_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError>; async fn list_user_webauthn_credentials<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Create passkey. async fn create_user_webauthn_credential<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, passkey: &Passkey, description: Option<&'a str>, @@ -169,41 +167,41 @@ pub trait IdentityApi: Send + Sync + Clone { async fn save_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, state: PasskeyRegistration, ) -> Result<(), IdentityProviderError>; async fn save_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, state: PasskeyAuthentication, ) -> Result<(), IdentityProviderError>; async fn get_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn get_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Delete passkey registration state of a user async fn delete_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; /// Delete passkey registration state of a user async fn delete_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; } @@ -218,115 +216,114 @@ mock! { impl IdentityApi for IdentityProvider { async fn authenticate_by_password( &self, - db: &DatabaseConnection, - provider: &Provider, + state: &ServiceState, auth: UserPasswordAuthRequest, ) -> Result; async fn list_users( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &UserListParameters, ) -> Result, IdentityProviderError>; async fn get_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn find_federated_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, idp_id: &'a str, unique_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_user( &self, - db: &DatabaseConnection, + state: &ServiceState, user: UserCreate, ) -> Result; async fn delete_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn list_groups( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &GroupListParameters, ) -> Result, IdentityProviderError>; async fn get_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_group( &self, - db: &DatabaseConnection, + state: &ServiceState, group: GroupCreate, ) -> Result; async fn delete_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn list_groups_of_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn add_user_to_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn add_users_to_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, memberships: Vec<(&'a str, &'a str)> ) -> Result<(), IdentityProviderError>; async fn remove_user_from_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn remove_user_from_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError>; async fn set_user_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError>; async fn list_user_webauthn_credentials<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn create_user_webauthn_credential<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, passkey: &Passkey, description: Option<&'a str> @@ -334,39 +331,39 @@ mock! { async fn save_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, - state: PasskeyRegistration, + auth_state: PasskeyRegistration, ) -> Result<(), IdentityProviderError>; async fn save_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, - state: PasskeyAuthentication, + auth_state: PasskeyAuthentication, ) -> Result<(), IdentityProviderError>; async fn get_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn get_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; async fn delete_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; async fn delete_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; } @@ -404,11 +401,10 @@ impl IdentityProvider { #[async_trait] impl IdentityApi for IdentityProvider { /// Authenticate user with the password auth method - #[tracing::instrument(level = "info", skip(self, db, provider, auth))] + #[tracing::instrument(level = "info", skip(self, state, auth))] async fn authenticate_by_password( &self, - db: &DatabaseConnection, - provider: &Provider, + state: &ServiceState, auth: UserPasswordAuthRequest, ) -> Result { let mut auth = auth; @@ -419,9 +415,10 @@ impl IdentityApi for IdentityProvider { if let Some(ref mut domain) = auth.domain { if let Some(dname) = &domain.name { - let d = provider + let d = state + .provider .get_resource_provider() - .find_domain_by_name(db, dname) + .find_domain_by_name(state, dname) .await? .ok_or(ResourceProviderError::DomainNotFound(dname.clone()))?; domain.id = Some(d.id); @@ -433,47 +430,49 @@ impl IdentityApi for IdentityProvider { } } - self.backend_driver.authenticate_by_password(db, auth).await + self.backend_driver + .authenticate_by_password(state, auth) + .await } /// List users - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_users( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &UserListParameters, ) -> Result, IdentityProviderError> { - self.backend_driver.list_users(db, params).await + self.backend_driver.list_users(state, params).await } /// Get single user - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - self.backend_driver.get_user(db, user_id).await + self.backend_driver.get_user(state, user_id).await } /// Find federated user by IDP and Unique ID - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn find_federated_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, idp_id: &'a str, unique_id: &'a str, ) -> Result, IdentityProviderError> { self.backend_driver - .find_federated_user(db, idp_id, unique_id) + .find_federated_user(state, idp_id, unique_id) .await } /// Create user - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn create_user( &self, - db: &DatabaseConnection, + state: &ServiceState, user: UserCreate, ) -> Result { let mut mod_user = user; @@ -481,227 +480,229 @@ impl IdentityApi for IdentityProvider { if mod_user.enabled.is_none() { mod_user.enabled = Some(true); } - self.backend_driver.create_user(db, mod_user).await + self.backend_driver.create_user(state, mod_user).await } /// Delete user - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn delete_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError> { - self.backend_driver.delete_user(db, user_id).await + self.backend_driver.delete_user(state, user_id).await } /// List groups - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_groups( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &GroupListParameters, ) -> Result, IdentityProviderError> { - self.backend_driver.list_groups(db, params).await + self.backend_driver.list_groups(state, params).await } /// Get single group - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result, IdentityProviderError> { - self.backend_driver.get_group(db, group_id).await + self.backend_driver.get_group(state, group_id).await } /// Create group - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn create_group( &self, - db: &DatabaseConnection, + state: &ServiceState, group: GroupCreate, ) -> Result { let mut res = group; res.id = Some(Uuid::new_v4().simple().to_string()); - self.backend_driver.create_group(db, res).await + self.backend_driver.create_group(state, res).await } /// Delete group - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn delete_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result<(), IdentityProviderError> { - self.backend_driver.delete_group(db, group_id).await + self.backend_driver.delete_group(state, group_id).await } /// List groups a user is a member of. - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_groups_of_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { - self.backend_driver.list_groups_of_user(db, user_id).await + self.backend_driver + .list_groups_of_user(state, user_id) + .await } - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn add_user_to_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError> { self.backend_driver - .add_user_to_group(db, user_id, group_id) + .add_user_to_group(state, user_id, group_id) .await } - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn add_users_to_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, memberships: Vec<(&'a str, &'a str)>, ) -> Result<(), IdentityProviderError> { self.backend_driver - .add_users_to_groups(db, memberships) + .add_users_to_groups(state, memberships) .await } - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn remove_user_from_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError> { self.backend_driver - .remove_user_from_group(db, user_id, group_id) + .remove_user_from_group(state, user_id, group_id) .await } - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn remove_user_from_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError> { self.backend_driver - .remove_user_from_groups(db, user_id, group_ids) + .remove_user_from_groups(state, user_id, group_ids) .await } - #[tracing::instrument(level = "debug", skip(self, db))] + #[tracing::instrument(level = "debug", skip(self, state))] async fn set_user_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError> { self.backend_driver - .set_user_groups(db, user_id, group_ids) + .set_user_groups(state, user_id, group_ids) .await } /// List user passkeys. - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn list_user_webauthn_credentials<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { self.backend_driver - .list_user_webauthn_credentials(db, user_id) + .list_user_webauthn_credentials(state, user_id) .await } /// Create passkey. - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn create_user_webauthn_credential<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, credential: &Passkey, description: Option<&'a str>, ) -> Result { self.backend_driver - .create_user_webauthn_credential(db, user_id, credential, description) + .create_user_webauthn_credential(state, user_id, credential, description) .await } /// Save passkey registration state - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn save_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, - state: PasskeyRegistration, + reg_state: PasskeyRegistration, ) -> Result<(), IdentityProviderError> { self.backend_driver - .create_user_webauthn_credential_registration_state(db, user_id, state) + .create_user_webauthn_credential_registration_state(state, user_id, reg_state) .await } /// Save passkey authentication state - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn save_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, - state: PasskeyAuthentication, + auth_state: PasskeyAuthentication, ) -> Result<(), IdentityProviderError> { self.backend_driver - .create_user_webauthn_credential_authentication_state(db, user_id, state) + .create_user_webauthn_credential_authentication_state(state, user_id, auth_state) .await } /// Get passkey registration state - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { self.backend_driver - .get_user_webauthn_credential_registration_state(db, user_id) + .get_user_webauthn_credential_registration_state(state, user_id) .await } /// Get passkey authentication state - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError> { self.backend_driver - .get_user_webauthn_credential_authentication_state(db, user_id) + .get_user_webauthn_credential_authentication_state(state, user_id) .await } /// Delete passkey registration state of a user - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn delete_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError> { self.backend_driver - .delete_user_webauthn_credential_authentication_state(db, user_id) + .delete_user_webauthn_credential_authentication_state(state, user_id) .await } /// Delete passkey authentication state of a user - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn delete_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError> { self.backend_driver - .delete_user_webauthn_credential_authentication_state(db, user_id) + .delete_user_webauthn_credential_authentication_state(state, user_id) .await } } diff --git a/src/identity/types.rs b/src/identity/types.rs index de25534c..911715b1 100644 --- a/src/identity/types.rs +++ b/src/identity/types.rs @@ -19,15 +19,14 @@ pub mod user; use async_trait::async_trait; use dyn_clone::DynClone; -use sea_orm::DatabaseConnection; use webauthn_rs::prelude::{Passkey, PasskeyAuthentication, PasskeyRegistration}; +use crate::auth::AuthenticatedInfo; use crate::config::Config; use crate::identity::IdentityProviderError; - -use crate::auth::AuthenticatedInfo; pub use crate::identity::types::group::{Group, GroupCreate, GroupListParameters}; pub use crate::identity::types::user::*; +use crate::keystone::ServiceState; #[async_trait] pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { @@ -37,28 +36,28 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Authenticate a user by a password. async fn authenticate_by_password( &self, - db: &DatabaseConnection, + state: &ServiceState, auth: UserPasswordAuthRequest, ) -> Result; /// List Users. async fn list_users( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &UserListParameters, ) -> Result, IdentityProviderError>; /// Get single user by ID. async fn get_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Find federated user by IDP and Unique ID. async fn find_federated_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, idp_id: &'a str, unique_id: &'a str, ) -> Result, IdentityProviderError>; @@ -66,56 +65,56 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Create user. async fn create_user( &self, - db: &DatabaseConnection, + state: &ServiceState, user: UserCreate, ) -> Result; /// Delete user. async fn delete_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; /// List groups. async fn list_groups( &self, - db: &DatabaseConnection, + state: &ServiceState, params: &GroupListParameters, ) -> Result, IdentityProviderError>; /// Get single group by ID. async fn get_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result, IdentityProviderError>; /// Create group. async fn create_group( &self, - db: &DatabaseConnection, + state: &ServiceState, group: GroupCreate, ) -> Result; /// Delete group by ID. async fn delete_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, group_id: &'a str, ) -> Result<(), IdentityProviderError>; /// List groups a user is member of. async fn list_groups_of_user<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Add the user to the group. async fn add_user_to_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError>; @@ -123,14 +122,14 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Add user group membership relations. async fn add_users_to_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, memberships: Vec<(&'a str, &'a str)>, ) -> Result<(), IdentityProviderError>; /// Remove the user from the group. async fn remove_user_from_group<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_id: &'a str, ) -> Result<(), IdentityProviderError>; @@ -138,7 +137,7 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Remove the user from multiple groups. async fn remove_user_from_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError>; @@ -146,7 +145,7 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Set group memberships for the user. async fn set_user_groups<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, group_ids: HashSet<&'a str>, ) -> Result<(), IdentityProviderError>; @@ -154,14 +153,14 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// List user passkeys. async fn list_user_webauthn_credentials<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Create passkey. async fn create_user_webauthn_credential<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, passkey: &Passkey, description: Option<&'a str>, @@ -170,7 +169,7 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Save passkey registration state. async fn create_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, state: PasskeyRegistration, ) -> Result<(), IdentityProviderError>; @@ -178,7 +177,7 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Save passkey auth state. async fn create_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, state: PasskeyAuthentication, ) -> Result<(), IdentityProviderError>; @@ -186,28 +185,28 @@ pub trait IdentityBackend: DynClone + Send + Sync + std::fmt::Debug { /// Get passkey registration state. async fn get_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Get passkey authentication state. async fn get_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result, IdentityProviderError>; /// Delete passkey registration state of a user. async fn delete_user_webauthn_credential_registration_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; /// Delete passkey authentication state of a user. async fn delete_user_webauthn_credential_authentication_state<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, user_id: &'a str, ) -> Result<(), IdentityProviderError>; } diff --git a/src/resource/backends/sql.rs b/src/resource/backends/sql.rs index e34cb3ee..7ce1abf6 100644 --- a/src/resource/backends/sql.rs +++ b/src/resource/backends/sql.rs @@ -22,6 +22,7 @@ use tracing::error; use super::super::types::*; use crate::config::Config; use crate::db::entity::{prelude::Project as DbProject, project as db_project}; +use crate::keystone::ServiceState; use crate::resource::ResourceProviderError; use crate::resource::backends::error::{ResourceDatabaseError, db_err}; @@ -42,38 +43,38 @@ impl ResourceBackend for SqlBackend { /// Get single domain by ID async fn get_domain<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_id: &'a str, ) -> Result, ResourceProviderError> { - Ok(get_domain_by_id(&self.config, db, domain_id).await?) + Ok(get_domain_by_id(&self.config, &state.db, domain_id).await?) } /// Get single domain by Name async fn get_domain_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_name: &'a str, ) -> Result, ResourceProviderError> { - Ok(get_domain_by_name(&self.config, db, domain_name).await?) + Ok(get_domain_by_name(&self.config, &state.db, domain_name).await?) } /// Get single project by ID async fn get_project<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, project_id: &'a str, ) -> Result, ResourceProviderError> { - Ok(get_project(&self.config, db, project_id).await?) + Ok(get_project(&self.config, &state.db, project_id).await?) } /// Get single project by Name and Domain ID async fn get_project_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, name: &'a str, domain_id: &'a str, ) -> Result, ResourceProviderError> { - Ok(get_project_by_name(&self.config, db, name, domain_id).await?) + Ok(get_project_by_name(&self.config, &state.db, name, domain_id).await?) } } diff --git a/src/resource/mod.rs b/src/resource/mod.rs index 945676fd..75419277 100644 --- a/src/resource/mod.rs +++ b/src/resource/mod.rs @@ -15,13 +15,13 @@ use async_trait::async_trait; #[cfg(test)] use mockall::mock; -use sea_orm::DatabaseConnection; pub mod backends; pub mod error; pub(crate) mod types; use crate::config::Config; +use crate::keystone::ServiceState; use crate::plugin_manager::PluginManager; use crate::resource::backends::sql::SqlBackend; use crate::resource::error::ResourceProviderError; @@ -36,25 +36,25 @@ pub struct ResourceProvider { pub trait ResourceApi: Send + Sync + Clone { async fn get_domain<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_id: &'a str, ) -> Result, ResourceProviderError>; async fn find_domain_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_name: &'a str, ) -> Result, ResourceProviderError>; async fn get_project<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, project_id: &'a str, ) -> Result, ResourceProviderError>; async fn get_project_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, name: &'a str, domain_id: &'a str, ) -> Result, ResourceProviderError>; @@ -70,25 +70,25 @@ mock! { impl ResourceApi for ResourceProvider { async fn get_domain<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_id: &'a str, ) -> Result, ResourceProviderError>; - async fn find_domain_by_name<'a>( - &self, - db: &DatabaseConnection, - domain_name: &'a str, - ) -> Result, ResourceProviderError>; + async fn find_domain_by_name<'a>( + &self, + state: &ServiceState, + domain_name: &'a str, + ) -> Result, ResourceProviderError>; async fn get_project<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, project_id: &'a str, ) -> Result, ResourceProviderError>; async fn get_project_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, name: &'a str, domain_id: &'a str, ) -> Result, ResourceProviderError>; @@ -127,47 +127,47 @@ impl ResourceProvider { #[async_trait] impl ResourceApi for ResourceProvider { /// Get single domain - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_domain<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_id: &'a str, ) -> Result, ResourceProviderError> { - self.backend_driver.get_domain(db, domain_id).await + self.backend_driver.get_domain(state, domain_id).await } /// Get single domain by its name - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn find_domain_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_name: &'a str, ) -> Result, ResourceProviderError> { self.backend_driver - .get_domain_by_name(db, domain_name) + .get_domain_by_name(state, domain_name) .await } /// Get single project - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_project<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, project_id: &'a str, ) -> Result, ResourceProviderError> { - self.backend_driver.get_project(db, project_id).await + self.backend_driver.get_project(state, project_id).await } /// Get single project by Name and Domain ID - #[tracing::instrument(level = "info", skip(self, db))] + #[tracing::instrument(level = "info", skip(self, state))] async fn get_project_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, name: &'a str, domain_id: &'a str, ) -> Result, ResourceProviderError> { self.backend_driver - .get_project_by_name(db, name, domain_id) + .get_project_by_name(state, name, domain_id) .await } } diff --git a/src/resource/types.rs b/src/resource/types.rs index 64c61217..3924e5e8 100644 --- a/src/resource/types.rs +++ b/src/resource/types.rs @@ -17,9 +17,9 @@ pub mod project; use async_trait::async_trait; use dyn_clone::DynClone; -use sea_orm::DatabaseConnection; use crate::config::Config; +use crate::keystone::ServiceState; use crate::resource::ResourceProviderError; pub use crate::resource::types::domain::{Domain, DomainBuilder, DomainBuilderError}; @@ -33,28 +33,28 @@ pub trait ResourceBackend: DynClone + Send + Sync + std::fmt::Debug { /// Get single domain by ID async fn get_domain<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_id: &'a str, ) -> Result, ResourceProviderError>; /// Get single domain by Name async fn get_domain_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, domain_name: &'a str, ) -> Result, ResourceProviderError>; /// Get single project by ID async fn get_project<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, project_id: &'a str, ) -> Result, ResourceProviderError>; /// Get single project by Name and Domain ID async fn get_project_by_name<'a>( &self, - db: &DatabaseConnection, + state: &ServiceState, name: &'a str, domain_id: &'a str, ) -> Result, ResourceProviderError>; diff --git a/src/token/mod.rs b/src/token/mod.rs index 786d72c1..92ff27ff 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -297,7 +297,7 @@ impl TokenProvider { let user = state .provider .get_identity_provider() - .get_user(&state.db, token.user_id()) + .get_user(state, token.user_id()) .await?; match token { Token::ApplicationCredential(data) => { @@ -618,7 +618,7 @@ impl TokenApi for TokenProvider { let project = state .provider .get_resource_provider() - .get_project(&state.db, &data.project_id) + .get_project(state, &data.project_id) .await?; data.project = project; @@ -629,7 +629,7 @@ impl TokenApi for TokenProvider { let project = state .provider .get_resource_provider() - .get_project(&state.db, &data.project_id) + .get_project(state, &data.project_id) .await?; data.project = project; @@ -640,7 +640,7 @@ impl TokenApi for TokenProvider { let project = state .provider .get_resource_provider() - .get_project(&state.db, &data.project_id) + .get_project(state, &data.project_id) .await?; data.project = project; @@ -651,7 +651,7 @@ impl TokenApi for TokenProvider { let domain = state .provider .get_resource_provider() - .get_domain(&state.db, &data.domain_id) + .get_domain(state, &data.domain_id) .await?; data.domain = domain; @@ -662,7 +662,7 @@ impl TokenApi for TokenProvider { let domain = state .provider .get_resource_provider() - .get_domain(&state.db, &data.domain_id) + .get_domain(state, &data.domain_id) .await?; data.domain = domain; @@ -673,7 +673,7 @@ impl TokenApi for TokenProvider { let project = state .provider .get_resource_provider() - .get_project(&state.db, &data.project_id) + .get_project(state, &data.project_id) .await?; data.project = project;