// axum_extra/extract/with_rejection.rs

use axum::async_trait;
use axum::extract::{FromRequest, FromRequestParts, Request};
use axum::response::IntoResponse;
use http::request::Parts;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

#[cfg(feature = "typed-routing")]
use crate::routing::TypedPath;

/// Extractor for customizing extractor rejections
///
/// `WithRejection` wraps another extractor `E` and gives you its result. If
/// the inner extraction fails, `E`'s rejection is converted into `R` (through
/// [`From<E::Rejection>`]) and `R` is used as the rejection instead.
///
/// `E` is expected to implement [`FromRequest`] or [`FromRequestParts`]
///
/// `R` is expected to implement [`IntoResponse`] and [`From<E::Rejection>`]
///
/// The second field is a [`PhantomData`] marker that only carries the type
/// `R`; it is normally ignored with `_` when destructuring.
///
/// # Example
///
/// ```rust
/// use axum::extract::rejection::JsonRejection;
/// use axum::response::{Response, IntoResponse};
/// use axum::Json;
/// use axum_extra::extract::WithRejection;
/// use serde::Deserialize;
///
/// struct MyRejection { /* ... */ }
///
/// impl From<JsonRejection> for MyRejection {
///     fn from(rejection: JsonRejection) -> MyRejection {
///         // ...
///         # todo!()
///     }
/// }
///
/// impl IntoResponse for MyRejection {
///     fn into_response(self) -> Response {
///         // ...
///         # todo!()
///     }
/// }
///
/// #[derive(Debug, Deserialize)]
/// struct Person { /* ... */ }
///
/// async fn handler(
///     // If the `Json` extractor ever fails, `MyRejection` will be sent to the
///     // client using the `IntoResponse` impl
///     WithRejection(Json(person), _): WithRejection<Json<Person>, MyRejection>
/// ) { /* ... */ }
/// # let _: axum::Router = axum::Router::new().route("/", axum::routing::get(handler));
/// ```
///
/// [`FromRequest`]: axum::extract::FromRequest
/// [`FromRequestParts`]: axum::extract::FromRequestParts
/// [`IntoResponse`]: axum::response::IntoResponse
/// [`From<E::Rejection>`]: std::convert::From
pub struct WithRejection<E, R>(pub E, pub PhantomData<R>);

63impl<E, R> WithRejection<E, R> {
64    /// Returns the wrapped extractor
65    pub fn into_inner(self) -> E {
66        self.0
67    }
68}
69
70impl<E, R> Debug for WithRejection<E, R>
71where
72    E: Debug,
73{
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_tuple("WithRejection")
76            .field(&self.0)
77            .field(&self.1)
78            .finish()
79    }
80}
81
82impl<E, R> Clone for WithRejection<E, R>
83where
84    E: Clone,
85{
86    fn clone(&self) -> Self {
87        Self(self.0.clone(), self.1)
88    }
89}
90
91impl<E, R> Copy for WithRejection<E, R> where E: Copy {}
92
93impl<E: Default, R> Default for WithRejection<E, R> {
94    fn default() -> Self {
95        Self(Default::default(), Default::default())
96    }
97}
98
99impl<E, R> Deref for WithRejection<E, R> {
100    type Target = E;
101
102    fn deref(&self) -> &Self::Target {
103        &self.0
104    }
105}
106
107impl<E, R> DerefMut for WithRejection<E, R> {
108    fn deref_mut(&mut self) -> &mut Self::Target {
109        &mut self.0
110    }
111}
112
113#[async_trait]
114impl<E, R, S> FromRequest<S> for WithRejection<E, R>
115where
116    S: Send + Sync,
117    E: FromRequest<S>,
118    R: From<E::Rejection> + IntoResponse,
119{
120    type Rejection = R;
121
122    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
123        let extractor = E::from_request(req, state).await?;
124        Ok(WithRejection(extractor, PhantomData))
125    }
126}
127
128#[async_trait]
129impl<E, R, S> FromRequestParts<S> for WithRejection<E, R>
130where
131    S: Send + Sync,
132    E: FromRequestParts<S>,
133    R: From<E::Rejection> + IntoResponse,
134{
135    type Rejection = R;
136
137    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
138        let extractor = E::from_request_parts(parts, state).await?;
139        Ok(WithRejection(extractor, PhantomData))
140    }
141}
142
// With the `typed-routing` feature enabled, the wrapper reuses the path of the
// wrapped `TypedPath` unchanged.
#[cfg(feature = "typed-routing")]
impl<E, R> TypedPath for WithRejection<E, R>
where
    E: TypedPath,
{
    const PATH: &'static str = E::PATH;
}

151impl<E, R> Display for WithRejection<E, R>
152where
153    E: Display,
154{
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(f, "{}", self.0)
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use axum::response::Response;

    // Checks that a failing inner extractor surfaces as the custom rejection
    // type through both the `FromRequest` and `FromRequestParts` entry points.
    #[tokio::test]
    async fn extractor_rejection_is_transformed() {
        struct TestExtractor;
        struct TestRejection;

        // An extractor that always fails with the unit rejection `()`.
        #[async_trait]
        impl<S> FromRequestParts<S> for TestExtractor
        where
            S: Send + Sync,
        {
            type Rejection = ();

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                Err(())
            }
        }

        impl IntoResponse for TestRejection {
            fn into_response(self) -> Response {
                ().into_response()
            }
        }

        impl From<()> for TestRejection {
            fn from(_: ()) -> Self {
                TestRejection
            }
        }

        // Whole-request path.
        let req = Request::new(Body::empty());
        let result = WithRejection::<TestExtractor, TestRejection>::from_request(req, &()).await;
        assert!(matches!(result, Err(TestRejection)));

        // Parts-only path.
        let (mut parts, _) = Request::new(()).into_parts();
        let result =
            WithRejection::<TestExtractor, TestRejection>::from_request_parts(&mut parts, &())
                .await;
        assert!(matches!(result, Err(TestRejection)));
    }
}