|
1 | 1 | use async_trait::async_trait; |
2 | | -use axum::{ |
3 | | - extract::FromRequestParts, |
4 | | - http::{request::Parts, StatusCode}, |
5 | | -}; |
6 | | -use shield::{Session, Shield, User}; |
| 2 | +use axum::{extract::FromRequestParts, http::request::Parts}; |
| 3 | +use shield::{ConfigurationError, Session, Shield, ShieldError, User}; |
| 4 | + |
| 5 | +use crate::error::RouteError; |
7 | 6 |
|
8 | 7 | pub struct ExtractShield<U: User>(pub Shield<U>); |
9 | 8 |
|
10 | 9 | #[async_trait] |
11 | 10 | impl<S: Send + Sync, U: User + Clone + 'static> FromRequestParts<S> for ExtractShield<U> { |
12 | | - type Rejection = (StatusCode, &'static str); |
| 11 | + type Rejection = RouteError; |
13 | 12 |
|
14 | 13 | async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { |
15 | 14 | parts |
16 | 15 | .extensions |
17 | 16 | .get::<Shield<U>>() |
18 | 17 | .cloned() |
19 | 18 | .map(ExtractShield) |
20 | | - .ok_or(( |
21 | | - StatusCode::INTERNAL_SERVER_ERROR, |
22 | | - "Can't extract Shield. Is `ShieldLayer` enabled?", |
23 | | - )) |
| 19 | + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( |
| 20 | + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), |
| 21 | + ))) |
| 22 | + .map_err(RouteError::from) |
24 | 23 | } |
25 | 24 | } |
26 | 25 |
|
27 | 26 | pub struct ExtractSession(pub Session); |
28 | 27 |
|
29 | 28 | #[async_trait] |
30 | 29 | impl<S: Send + Sync> FromRequestParts<S> for ExtractSession { |
31 | | - type Rejection = (StatusCode, &'static str); |
| 30 | + type Rejection = RouteError; |
32 | 31 |
|
33 | 32 | async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { |
34 | 33 | parts |
35 | 34 | .extensions |
36 | 35 | .get::<Session>() |
37 | 36 | .cloned() |
38 | 37 | .map(ExtractSession) |
39 | | - .ok_or(( |
40 | | - StatusCode::INTERNAL_SERVER_ERROR, |
41 | | - "Can't extract Shield session. Is `ShieldLayer` enabled?", |
42 | | - )) |
| 38 | + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( |
| 39 | + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), |
| 40 | + ))) |
| 41 | + .map_err(RouteError::from) |
43 | 42 | } |
44 | 43 | } |
45 | 44 |
|
46 | 45 | pub struct ExtractUser<U: User>(pub Option<U>); |
47 | 46 |
|
48 | 47 | #[async_trait] |
49 | 48 | impl<S: Send + Sync, U: User + Clone + 'static> FromRequestParts<S> for ExtractUser<U> { |
50 | | - type Rejection = (StatusCode, &'static str); |
| 49 | + type Rejection = RouteError; |
51 | 50 |
|
52 | 51 | async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { |
53 | 52 | parts |
54 | 53 | .extensions |
55 | 54 | .get::<Option<U>>() |
56 | 55 | .cloned() |
57 | 56 | .map(ExtractUser) |
58 | | - .ok_or(( |
59 | | - StatusCode::INTERNAL_SERVER_ERROR, |
60 | | - "Can't extract Shield user. Is `ShieldLayer` enabled?", |
61 | | - )) |
| 57 | + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( |
| 58 | + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), |
| 59 | + ))) |
| 60 | + .map_err(RouteError::from) |
62 | 61 | } |
63 | 62 | } |
64 | 63 |
|
65 | 64 | pub struct UserRequired<U: User>(pub U); |
66 | 65 |
|
67 | 66 | #[async_trait] |
68 | 67 | impl<S: Send + Sync, U: User + Clone + 'static> FromRequestParts<S> for UserRequired<U> { |
69 | | - type Rejection = (StatusCode, &'static str); |
| 68 | + type Rejection = RouteError; |
70 | 69 |
|
71 | 70 | async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { |
72 | 71 | parts |
73 | 72 | .extensions |
74 | 73 | .get::<Option<U>>() |
75 | 74 | .cloned() |
76 | | - .ok_or(( |
77 | | - StatusCode::INTERNAL_SERVER_ERROR, |
78 | | - "Can't extract Shield user. Is `ShieldLayer` enabled?", |
79 | | - )) |
80 | | - .and_then(|user| user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized"))) |
| 75 | + .ok_or(ShieldError::Configuration(ConfigurationError::Invalid( |
| 76 | + "Can't extract Shield. Is `ShieldLayer` enabled?".to_owned(), |
| 77 | + ))) |
| 78 | + .and_then(|user| user.ok_or(ShieldError::Unauthorized)) |
81 | 79 | .map(UserRequired) |
| 80 | + .map_err(RouteError::from) |
82 | 81 | } |
83 | 82 | } |
0 commit comments