axum_extra/
typed_header.rs1use axum::{
4 async_trait,
5 extract::FromRequestParts,
6 response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
7};
8use headers::{Header, HeaderMapExt};
9use http::{request::Parts, StatusCode};
10use std::convert::Infallible;
11
/// Extractor and response wrapper for a typed header value `T` from the [`headers`] crate.
///
/// As an extractor it decodes every request value stored under [`Header::name`] into `T`.
/// On failure it rejects with a [`TypedHeaderRejection`].
/// As a response (or response part) it encodes `T` into the response headers.
#[cfg(feature = "typed-header")]
#[must_use]
#[derive(Debug, Clone, Copy)]
pub struct TypedHeader<T>(pub T);
57
58#[async_trait]
59impl<T, S> FromRequestParts<S> for TypedHeader<T>
60where
61 T: Header,
62 S: Send + Sync,
63{
64 type Rejection = TypedHeaderRejection;
65
66 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
67 let mut values = parts.headers.get_all(T::name()).iter();
68 let is_missing = values.size_hint() == (0, Some(0));
69 T::decode(&mut values)
70 .map(Self)
71 .map_err(|err| TypedHeaderRejection {
72 name: T::name(),
73 reason: if is_missing {
74 TypedHeaderRejectionReason::Missing
76 } else {
77 TypedHeaderRejectionReason::Error(err)
78 },
79 })
80 }
81}
82
83axum_core::__impl_deref!(TypedHeader);
84
85impl<T> IntoResponseParts for TypedHeader<T>
86where
87 T: Header,
88{
89 type Error = Infallible;
90
91 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
92 res.headers_mut().typed_insert(self.0);
93 Ok(res)
94 }
95}
96
97impl<T> IntoResponse for TypedHeader<T>
98where
99 T: Header,
100{
101 fn into_response(self) -> Response {
102 let mut res = ().into_response();
103 res.headers_mut().typed_insert(self.0);
104 res
105 }
106}
107
/// Rejection returned by the [`TypedHeader`] extractor.
///
/// Records which header could not be extracted and why.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header that could not be extracted.
    name: &'static http::HeaderName,
    /// Whether the header was absent or failed to decode.
    reason: TypedHeaderRejectionReason,
}
115
116impl TypedHeaderRejection {
117 pub fn name(&self) -> &http::header::HeaderName {
119 self.name
120 }
121
122 pub fn reason(&self) -> &TypedHeaderRejectionReason {
124 &self.reason
125 }
126
127 #[must_use]
131 pub fn is_missing(&self) -> bool {
132 self.reason.is_missing()
133 }
134}
135
/// Why a [`TypedHeaderRejection`] occurred.
#[cfg(feature = "typed-header")]
#[non_exhaustive]
#[derive(Debug)]
pub enum TypedHeaderRejectionReason {
    /// The request contained no value for the header.
    Missing,
    /// The header was present but decoding it into the typed value failed.
    Error(headers::Error),
}
146
147impl TypedHeaderRejectionReason {
148 #[must_use]
152 pub fn is_missing(&self) -> bool {
153 matches!(self, Self::Missing)
154 }
155}
156
157impl IntoResponse for TypedHeaderRejection {
158 fn into_response(self) -> Response {
159 let status = StatusCode::BAD_REQUEST;
160 let body = self.to_string();
161 axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
162 (status, body).into_response()
163 }
164}
165
166impl std::fmt::Display for TypedHeaderRejection {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 match &self.reason {
169 TypedHeaderRejectionReason::Missing => {
170 write!(f, "Header of type `{}` was missing", self.name)
171 }
172 TypedHeaderRejectionReason::Error(err) => {
173 write!(f, "{} ({})", err, self.name)
174 }
175 }
176 }
177}
178
179impl std::error::Error for TypedHeaderRejection {
180 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
181 match &self.reason {
182 TypedHeaderRejectionReason::Error(err) => Some(err),
183 TypedHeaderRejectionReason::Missing => None,
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    #[tokio::test]
    async fn typed_header() {
        // Echoes both decoded headers back so the assertions can inspect them.
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let (user_agent, cookies) = (
                user_agent.as_str(),
                cookies.iter().collect::<Vec<_>>(),
            );
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let client = TestClient::new(Router::new().route("/", get(handle)));

        // Multiple `cookie` headers are combined into one `Cookie` value.
        let response = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await;
        assert_eq!(
            response.text().await,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        // Without a `cookie` header, `Cookie` still decodes, to an empty list.
        let response = client.get("/").header("user-agent", "foobar").await;
        assert_eq!(response.text().await, r#"User-Agent="foobar", Cookie=[]"#);

        // Without a `user-agent` header, the request is rejected as missing.
        let response = client.get("/").header("cookie", "a=1").await;
        assert_eq!(
            response.text().await,
            "Header of type `user-agent` was missing"
        );
    }
}
229}