axum_extra/
typed_header.rs

1//! Extractor and response for typed headers.
2
3use axum::{
4    async_trait,
5    extract::FromRequestParts,
6    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
7};
8use headers::{Header, HeaderMapExt};
9use http::{request::Parts, StatusCode};
10use std::convert::Infallible;
11
/// Extractor and response that works with typed header values from [`headers`].
///
/// # As extractor
///
/// In general, it's recommended to extract only the needed headers via `TypedHeader` rather than
/// cloning the entire header map with the `HeaderMap` extractor.
///
/// Extraction fails with a [`TypedHeaderRejection`] (rendered as `400 Bad Request`) when:
///
/// - no header named `T::name()` is present ([`TypedHeaderRejectionReason::Missing`]), or
/// - the header is present but cannot be decoded as `T` ([`TypedHeaderRejectionReason::Error`]).
///
/// ```rust,no_run
/// use axum::{
///     routing::get,
///     Router,
/// };
/// use headers::UserAgent;
/// use axum_extra::TypedHeader;
///
/// async fn users_teams_show(
///     TypedHeader(user_agent): TypedHeader<UserAgent>,
/// ) {
///     // ...
/// }
///
/// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
/// # let _: Router = app;
/// ```
///
/// # As response
///
/// `TypedHeader` can be returned on its own or as one part of a response tuple; in both cases the
/// wrapped value is encoded and inserted into the response headers.
///
/// ```rust
/// use axum::{
///     response::IntoResponse,
/// };
/// use headers::ContentType;
/// use axum_extra::TypedHeader;
///
/// async fn handler() -> (TypedHeader<ContentType>, &'static str) {
///     (
///         TypedHeader(ContentType::text_utf8()),
///         "Hello, World!",
///     )
/// }
/// ```
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
57
58#[async_trait]
59impl<T, S> FromRequestParts<S> for TypedHeader<T>
60where
61    T: Header,
62    S: Send + Sync,
63{
64    type Rejection = TypedHeaderRejection;
65
66    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
67        let mut values = parts.headers.get_all(T::name()).iter();
68        let is_missing = values.size_hint() == (0, Some(0));
69        T::decode(&mut values)
70            .map(Self)
71            .map_err(|err| TypedHeaderRejection {
72                name: T::name(),
73                reason: if is_missing {
74                    // Report a more precise rejection for the missing header case.
75                    TypedHeaderRejectionReason::Missing
76                } else {
77                    TypedHeaderRejectionReason::Error(err)
78                },
79            })
80    }
81}
82
83axum_core::__impl_deref!(TypedHeader);
84
85impl<T> IntoResponseParts for TypedHeader<T>
86where
87    T: Header,
88{
89    type Error = Infallible;
90
91    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
92        res.headers_mut().typed_insert(self.0);
93        Ok(res)
94    }
95}
96
97impl<T> IntoResponse for TypedHeader<T>
98where
99    T: Header,
100{
101    fn into_response(self) -> Response {
102        let mut res = ().into_response();
103        res.headers_mut().typed_insert(self.0);
104        res
105    }
106}
107
/// Rejection returned when the [`TypedHeader`] extractor cannot produce a value.
///
/// Records which header was being extracted and why extraction failed.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    // Name of the header the extractor was looking for.
    name: &'static http::header::HeaderName,
    // Whether the header was absent or failed to decode.
    reason: TypedHeaderRejectionReason,
}
115
116impl TypedHeaderRejection {
117    /// Name of the header that caused the rejection
118    pub fn name(&self) -> &http::header::HeaderName {
119        self.name
120    }
121
122    /// Reason why the header extraction has failed
123    pub fn reason(&self) -> &TypedHeaderRejectionReason {
124        &self.reason
125    }
126
127    /// Returns `true` if the typed header rejection reason is [`Missing`].
128    ///
129    /// [`Missing`]: TypedHeaderRejectionReason::Missing
130    #[must_use]
131    pub fn is_missing(&self) -> bool {
132        self.reason.is_missing()
133    }
134}
135
/// Describes why a [`TypedHeaderRejection`] was produced.
///
/// This enum is `#[non_exhaustive]`, so matches on it outside this crate need a wildcard arm.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// The request carried no header with the expected name.
    Missing,
    /// The header was present, but decoding it into the typed value failed.
    Error(headers::Error),
}
146
147impl TypedHeaderRejectionReason {
148    /// Returns `true` if the typed header rejection reason is [`Missing`].
149    ///
150    /// [`Missing`]: TypedHeaderRejectionReason::Missing
151    #[must_use]
152    pub fn is_missing(&self) -> bool {
153        matches!(self, Self::Missing)
154    }
155}
156
157impl IntoResponse for TypedHeaderRejection {
158    fn into_response(self) -> Response {
159        let status = StatusCode::BAD_REQUEST;
160        let body = self.to_string();
161        axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
162        (status, body).into_response()
163    }
164}
165
166impl std::fmt::Display for TypedHeaderRejection {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        match &self.reason {
169            TypedHeaderRejectionReason::Missing => {
170                write!(f, "Header of type `{}` was missing", self.name)
171            }
172            TypedHeaderRejectionReason::Error(err) => {
173                write!(f, "{} ({})", err, self.name)
174            }
175        }
176    }
177}
178
179impl std::error::Error for TypedHeaderRejection {
180    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
181        match &self.reason {
182            TypedHeaderRejectionReason::Error(err) => Some(err),
183            TypedHeaderRejectionReason::Missing => None,
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    #[tokio::test]
    async fn typed_header() {
        // Echoes the decoded values so the assertions can check how header values are combined.
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies = cookies.iter().collect::<Vec<_>>();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let app = Router::new().route("/", get(handle));

        let client = TestClient::new(app);

        // Several `Cookie` header lines are merged into a single typed value.
        let res = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.text().await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        // An absent `Cookie` header decodes to an empty cookie set rather than a rejection.
        let res = client.get("/").header("user-agent", "foobar").await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.text().await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        // A missing `User-Agent` is rejected with 400 and the `Missing` message.
        let res = client.get("/").header("cookie", "a=1").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        let body = res.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }
}