Skip to content

Commit cb8ae90

Browse files
feat(shield-credentials): add credentials method
1 parent ec6b966 commit cb8ae90

File tree

13 files changed

+518
-9
lines changed

13 files changed

+518
-9
lines changed

Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/core/shield/src/form.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub struct Form {
1818
#[derive(Clone, Debug)]
1919
pub struct Input {
2020
pub name: String,
21+
pub label: Option<String>,
2122
pub r#type: InputType,
2223
pub value: Option<String>,
2324
pub attributes: Option<HashMap<String, Attribute>>,

packages/core/shield/src/method.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub trait Method: Send + Sync {
3939
request: SignOutRequest,
4040
session: Session,
4141
options: &ShieldOptions,
42-
) -> Result<Response, ShieldError>;
42+
) -> Result<Option<Response>, ShieldError>;
4343
}
4444

4545
#[cfg(test)]
@@ -111,8 +111,8 @@ pub(crate) mod tests {
111111
_request: SignOutRequest,
112112
_session: Session,
113113
_options: &ShieldOptions,
114-
) -> Result<Response, ShieldError> {
115-
todo!("redirect back?")
114+
) -> Result<Option<Response>, ShieldError> {
115+
Ok(None)
116116
}
117117
}
118118
}

packages/core/shield/src/shield.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,12 @@ impl<U: User> Shield<U> {
185185
)
186186
.await?
187187
} else {
188-
Response::Redirect(self.options.sign_out_redirect.clone())
188+
None
189189
};
190190

191+
let response =
192+
response.unwrap_or_else(|| Response::Redirect(self.options.sign_out_redirect.clone()));
193+
191194
session.purge().await?;
192195

193196
Ok(response)

packages/methods/shield-credentials/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,10 @@ repository.workspace = true
99
version.workspace = true
1010

1111
[dependencies]
12+
async-trait.workspace = true
13+
serde.workspace = true
14+
serde_json.workspace = true
1215
shield.workspace = true
16+
17+
[dev-dependencies]
18+
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
use async_trait::async_trait;
2+
use serde::de::DeserializeOwned;
3+
use shield::{Form, ShieldError, User};
4+
5+
#[async_trait]
6+
pub trait Credentials<U: User, D: DeserializeOwned>: Send + Sync {
7+
fn form(&self) -> Form;
8+
9+
async fn sign_in(&self, data: D) -> Result<U, ShieldError>;
10+
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use std::{pin::Pin, sync::Arc};
2+
3+
use async_trait::async_trait;
4+
use serde::Deserialize;
5+
use shield::{Form, Input, InputType, ShieldError, User};
6+
7+
use crate::Credentials;
8+
9+
#[derive(Debug, Deserialize)]
10+
pub struct EmailPasswordData {
11+
pub email: String,
12+
pub password: String,
13+
}
14+
15+
type SignInFn<U> = dyn Fn(EmailPasswordData) -> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
16+
+ Send
17+
+ Sync;
18+
19+
pub struct EmailPasswordCredentials<U: User> {
20+
sign_in_fn: Arc<SignInFn<U>>,
21+
}
22+
23+
impl<U: User> EmailPasswordCredentials<U> {
24+
pub fn new(
25+
sign_in_fn: impl Fn(
26+
EmailPasswordData,
27+
)
28+
-> Pin<Box<dyn Future<Output = Result<U, ShieldError>> + Send + Sync>>
29+
+ Send
30+
+ Sync
31+
+ 'static,
32+
) -> Self {
33+
Self {
34+
sign_in_fn: Arc::new(sign_in_fn),
35+
}
36+
}
37+
}
38+
39+
#[async_trait]
40+
impl<U: User> Credentials<U, EmailPasswordData> for EmailPasswordCredentials<U> {
41+
fn form(&self) -> Form {
42+
Form {
43+
inputs: vec![
44+
Input {
45+
name: "email".to_owned(),
46+
label: Some("Email address".to_owned()),
47+
r#type: InputType::Email {
48+
autocomplete: Some("email".to_owned()),
49+
dirname: None,
50+
list: None,
51+
maxlength: None,
52+
minlength: None,
53+
multiple: None,
54+
pattern: None,
55+
placeholder: Some("Email address".to_owned()),
56+
readonly: None,
57+
required: Some(true),
58+
size: None,
59+
},
60+
value: None,
61+
attributes: None,
62+
},
63+
Input {
64+
name: "password".to_owned(),
65+
label: Some("Password".to_owned()),
66+
r#type: InputType::Password {
67+
autocomplete: Some("current-password".to_owned()),
68+
dirname: None,
69+
maxlength: None,
70+
minlength: None,
71+
pattern: None,
72+
placeholder: Some("Password".to_owned()),
73+
readonly: None,
74+
required: Some(true),
75+
size: None,
76+
},
77+
value: None,
78+
attributes: None,
79+
},
80+
],
81+
attributes: None,
82+
}
83+
}
84+
85+
async fn sign_in(&self, data: EmailPasswordData) -> Result<U, ShieldError> {
86+
(self.sign_in_fn)(data).await
87+
}
88+
}
89+
90+
#[cfg(test)]
91+
mod tests {
92+
use async_trait::async_trait;
93+
use serde::{Deserialize, Serialize};
94+
use shield::{EmailAddress, ShieldError, StorageError, User};
95+
96+
use crate::Credentials;
97+
98+
use super::{EmailPasswordCredentials, EmailPasswordData};
99+
100+
#[derive(Clone, Debug, Deserialize, Serialize)]
101+
pub struct TestUser {
102+
id: String,
103+
name: Option<String>,
104+
}
105+
106+
#[async_trait]
107+
impl User for TestUser {
108+
fn id(&self) -> String {
109+
self.id.clone()
110+
}
111+
112+
fn name(&self) -> Option<String> {
113+
self.name.clone()
114+
}
115+
116+
async fn email_addresses(&self) -> Result<Vec<EmailAddress>, StorageError> {
117+
Ok(vec![])
118+
}
119+
120+
fn additional(&self) -> Option<impl Serialize> {
121+
None::<()>
122+
}
123+
}
124+
125+
#[tokio::test]
126+
async fn email_password_credentials() -> Result<(), ShieldError> {
127+
let credentials = EmailPasswordCredentials::new(|data: EmailPasswordData| {
128+
Box::pin(async move {
129+
if data.email == "test@example.com" && data.password == "test" {
130+
Ok(TestUser {
131+
id: "1".to_owned(),
132+
name: Some("Test".to_owned()),
133+
})
134+
} else {
135+
Err(ShieldError::Validation(
136+
"Incorrect email and password combination.".to_owned(),
137+
))
138+
}
139+
})
140+
});
141+
142+
assert!(
143+
credentials
144+
.sign_in(EmailPasswordData {
145+
email: "test@example.com".to_owned(),
146+
password: "incorrect".to_owned(),
147+
})
148+
.await
149+
.is_err_and(|err| err
150+
.to_string()
151+
.contains("Incorrect email and password combination."))
152+
);
153+
154+
let user = credentials
155+
.sign_in(EmailPasswordData {
156+
email: "test@example.com".to_owned(),
157+
password: "test".to_owned(),
158+
})
159+
.await?;
160+
161+
assert_eq!(user.name, Some("Test".to_owned()));
162+
163+
Ok(())
164+
}
165+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
1+
mod credentials;
2+
mod email_password;
3+
mod method;
4+
mod provider;
5+
mod username_password;
16

7+
pub use credentials::*;
8+
pub use email_password::*;
9+
pub use method::*;
10+
pub use username_password::*;
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
use std::sync::Arc;
2+
3+
use async_trait::async_trait;
4+
use serde::de::DeserializeOwned;
5+
use shield::{
6+
Authentication, Method, Provider, Response, Session, SessionError, ShieldError, ShieldOptions,
7+
SignInCallbackRequest, SignInRequest, SignOutRequest, User,
8+
};
9+
10+
use crate::{Credentials, provider::CredentialsProvider};
11+
12+
pub const CREDENTIALS_METHOD_ID: &str = "credentials";
13+
14+
pub struct CredentialsMethod<U: User, D: DeserializeOwned> {
15+
credentials: Arc<dyn Credentials<U, D>>,
16+
}
17+
18+
impl<U: User, D: DeserializeOwned> CredentialsMethod<U, D> {
19+
pub fn new<C: Credentials<U, D> + 'static>(credentials: C) -> Self {
20+
Self {
21+
credentials: Arc::new(credentials),
22+
}
23+
}
24+
}
25+
26+
#[async_trait]
27+
impl<U: User + 'static, D: DeserializeOwned + 'static> Method for CredentialsMethod<U, D> {
28+
fn id(&self) -> String {
29+
CREDENTIALS_METHOD_ID.to_owned()
30+
}
31+
32+
async fn providers(&self) -> Result<Vec<Box<dyn Provider>>, ShieldError> {
33+
Ok(vec![Box::new(CredentialsProvider::new(
34+
self.credentials.clone(),
35+
))])
36+
}
37+
38+
async fn provider_by_id(
39+
&self,
40+
_provider_id: &str,
41+
) -> Result<Option<Box<dyn Provider>>, ShieldError> {
42+
Ok(None)
43+
}
44+
45+
async fn sign_in(
46+
&self,
47+
request: SignInRequest,
48+
session: Session,
49+
options: &ShieldOptions,
50+
) -> Result<Response, ShieldError> {
51+
if request.provider_id.is_some() {
52+
return Err(ShieldError::Validation(
53+
"Provider should be none.".to_owned(),
54+
));
55+
}
56+
57+
let Some(form_data) = request.form_data else {
58+
return Err(ShieldError::Validation("Missing form data.".to_owned()));
59+
};
60+
61+
let data = serde_json::from_value(form_data)
62+
.map_err(|err| ShieldError::Validation(err.to_string()))?;
63+
64+
let user = self.credentials.sign_in(data).await?;
65+
66+
session.renew().await?;
67+
68+
{
69+
let session_data = session.data();
70+
let mut session_data = session_data
71+
.lock()
72+
.map_err(|err| SessionError::Lock(err.to_string()))?;
73+
74+
session_data.authentication = Some(Authentication {
75+
method_id: self.id(),
76+
provider_id: None,
77+
user_id: user.id(),
78+
});
79+
}
80+
81+
Ok(Response::Redirect(
82+
request
83+
.redirect_url
84+
.unwrap_or(options.sign_in_redirect.clone()),
85+
))
86+
}
87+
88+
async fn sign_in_callback(
89+
&self,
90+
_request: SignInCallbackRequest,
91+
_session: Session,
92+
_options: &ShieldOptions,
93+
) -> Result<Response, ShieldError> {
94+
Err(ShieldError::Validation(
95+
"Credentials method does not have a sign in callback.".to_owned(),
96+
))
97+
}
98+
99+
async fn sign_out(
100+
&self,
101+
_request: SignOutRequest,
102+
_session: Session,
103+
_options: &ShieldOptions,
104+
) -> Result<Option<Response>, ShieldError> {
105+
Ok(None)
106+
}
107+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use std::sync::Arc;
2+
3+
use serde::de::DeserializeOwned;
4+
use shield::{Form, Provider, User};
5+
6+
use crate::{CREDENTIALS_METHOD_ID, Credentials};
7+
8+
pub struct CredentialsProvider<U: User, D: DeserializeOwned> {
9+
credentials: Arc<dyn Credentials<U, D>>,
10+
}
11+
12+
impl<U: User, D: DeserializeOwned> CredentialsProvider<U, D> {
13+
pub(crate) fn new(credentials: Arc<dyn Credentials<U, D>>) -> Self {
14+
Self { credentials }
15+
}
16+
}
17+
18+
impl<U: User, D: DeserializeOwned> Provider for CredentialsProvider<U, D> {
19+
fn method_id(&self) -> String {
20+
CREDENTIALS_METHOD_ID.to_owned()
21+
}
22+
23+
fn id(&self) -> Option<String> {
24+
None
25+
}
26+
27+
fn name(&self) -> String {
28+
"Credentials".to_owned()
29+
}
30+
31+
fn icon_url(&self) -> Option<String> {
32+
None
33+
}
34+
35+
fn form(&self) -> Option<Form> {
36+
Some(self.credentials.form())
37+
}
38+
}

0 commit comments

Comments
 (0)