axum_extra/extract/
multipart.rs1use axum::{
6 async_trait,
7 body::{Body, Bytes},
8 extract::FromRequest,
9 response::{IntoResponse, Response},
10 RequestExt,
11};
12use futures_util::stream::Stream;
13use http::{
14 header::{HeaderMap, CONTENT_TYPE},
15 Request, StatusCode,
16};
17use std::{
18 error::Error,
19 fmt,
20 pin::Pin,
21 task::{Context, Poll},
22};
23
24#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
88#[derive(Debug)]
89pub struct Multipart {
90 inner: multer::Multipart<'static>,
91}
92
93#[async_trait]
94impl<S> FromRequest<S> for Multipart
95where
96 S: Send + Sync,
97{
98 type Rejection = MultipartRejection;
99
100 async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> {
101 let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
102 let stream = req.with_limited_body().into_body();
103 let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
104 Ok(Self { inner: multipart })
105 }
106}
107
108impl Multipart {
109 pub async fn next_field(&mut self) -> Result<Option<Field>, MultipartError> {
111 let field = self
112 .inner
113 .next_field()
114 .await
115 .map_err(MultipartError::from_multer)?;
116
117 if let Some(field) = field {
118 Ok(Some(Field { inner: field }))
119 } else {
120 Ok(None)
121 }
122 }
123
124 pub fn into_stream(self) -> impl Stream<Item = Result<Field, MultipartError>> + Send + 'static {
126 futures_util::stream::try_unfold(self, |mut multipart| async move {
127 let field = multipart.next_field().await?;
128 Ok(field.map(|field| (field, multipart)))
129 })
130 }
131}
132
133#[derive(Debug)]
135pub struct Field {
136 inner: multer::Field<'static>,
137}
138
139impl Stream for Field {
140 type Item = Result<Bytes, MultipartError>;
141
142 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143 Pin::new(&mut self.inner)
144 .poll_next(cx)
145 .map_err(MultipartError::from_multer)
146 }
147}
148
149impl Field {
150 pub fn name(&self) -> Option<&str> {
154 self.inner.name()
155 }
156
157 pub fn file_name(&self) -> Option<&str> {
161 self.inner.file_name()
162 }
163
164 pub fn content_type(&self) -> Option<&str> {
166 self.inner.content_type().map(|m| m.as_ref())
167 }
168
169 pub fn headers(&self) -> &HeaderMap {
171 self.inner.headers()
172 }
173
174 pub async fn bytes(self) -> Result<Bytes, MultipartError> {
176 self.inner
177 .bytes()
178 .await
179 .map_err(MultipartError::from_multer)
180 }
181
182 pub async fn text(self) -> Result<String, MultipartError> {
184 self.inner.text().await.map_err(MultipartError::from_multer)
185 }
186
187 pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
226 self.inner
227 .chunk()
228 .await
229 .map_err(MultipartError::from_multer)
230 }
231}
232
233#[derive(Debug)]
235pub struct MultipartError {
236 source: multer::Error,
237}
238
239impl MultipartError {
240 fn from_multer(multer: multer::Error) -> Self {
241 Self { source: multer }
242 }
243
244 pub fn body_text(&self) -> String {
246 axum_core::__log_rejection!(
247 rejection_type = Self,
248 body_text = self.body_text(),
249 status = self.status(),
250 );
251 self.source.to_string()
252 }
253
254 pub fn status(&self) -> http::StatusCode {
256 status_code_from_multer_error(&self.source)
257 }
258}
259
260fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
261 match err {
262 multer::Error::UnknownField { .. }
263 | multer::Error::IncompleteFieldData { .. }
264 | multer::Error::IncompleteHeaders
265 | multer::Error::ReadHeaderFailed(..)
266 | multer::Error::DecodeHeaderName { .. }
267 | multer::Error::DecodeContentType(..)
268 | multer::Error::NoBoundary
269 | multer::Error::DecodeHeaderValue { .. }
270 | multer::Error::NoMultipart
271 | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
272 multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
273 StatusCode::PAYLOAD_TOO_LARGE
274 }
275 multer::Error::StreamReadFailed(err) => {
276 if let Some(err) = err.downcast_ref::<multer::Error>() {
277 return status_code_from_multer_error(err);
278 }
279
280 if err
281 .downcast_ref::<axum::Error>()
282 .and_then(|err| err.source())
283 .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
284 .is_some()
285 {
286 return StatusCode::PAYLOAD_TOO_LARGE;
287 }
288
289 StatusCode::INTERNAL_SERVER_ERROR
290 }
291 _ => StatusCode::INTERNAL_SERVER_ERROR,
292 }
293}
294
295impl IntoResponse for MultipartError {
296 fn into_response(self) -> Response {
297 (self.status(), self.body_text()).into_response()
298 }
299}
300
301impl fmt::Display for MultipartError {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 write!(f, "Error parsing `multipart/form-data` request")
304 }
305}
306
307impl std::error::Error for MultipartError {
308 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
309 Some(&self.source)
310 }
311}
312
313fn parse_boundary(headers: &HeaderMap) -> Option<String> {
314 let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
315 multer::parse_boundary(content_type).ok()
316}
317
318#[derive(Debug)]
322#[non_exhaustive]
323pub enum MultipartRejection {
324 #[allow(missing_docs)]
325 InvalidBoundary(InvalidBoundary),
326}
327
328impl IntoResponse for MultipartRejection {
329 fn into_response(self) -> Response {
330 match self {
331 Self::InvalidBoundary(inner) => inner.into_response(),
332 }
333 }
334}
335
336impl MultipartRejection {
337 pub fn body_text(&self) -> String {
339 match self {
340 Self::InvalidBoundary(inner) => inner.body_text(),
341 }
342 }
343
344 pub fn status(&self) -> http::StatusCode {
346 match self {
347 Self::InvalidBoundary(inner) => inner.status(),
348 }
349 }
350}
351
352impl From<InvalidBoundary> for MultipartRejection {
353 fn from(inner: InvalidBoundary) -> Self {
354 Self::InvalidBoundary(inner)
355 }
356}
357
358impl std::fmt::Display for MultipartRejection {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 match self {
361 Self::InvalidBoundary(inner) => write!(f, "{}", inner.body_text()),
362 }
363 }
364}
365
366impl std::error::Error for MultipartRejection {
367 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
368 match self {
369 Self::InvalidBoundary(inner) => Some(inner),
370 }
371 }
372}
373
/// Rejection returned when a `multipart/form-data` request has a missing or invalid
/// `boundary` in its `Content-Type` header.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct InvalidBoundary;
379
380impl IntoResponse for InvalidBoundary {
381 fn into_response(self) -> Response {
382 let body = self.body_text();
383 axum_core::__log_rejection!(
384 rejection_type = Self,
385 body_text = body,
386 status = self.status(),
387 );
388 (self.status(), body).into_response()
389 }
390}
391
392impl InvalidBoundary {
393 pub fn body_text(&self) -> String {
395 "Invalid `boundary` for `multipart/form-data` request".into()
396 }
397
398 pub fn status(&self) -> http::StatusCode {
400 http::StatusCode::BAD_REQUEST
401 }
402}
403
404impl std::fmt::Display for InvalidBoundary {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 write!(f, "{}", self.body_text())
407 }
408}
409
410impl std::error::Error for InvalidBoundary {}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::DefaultBodyLimit, routing::post, Router};

    /// A file part keeps its file name, its full content type (parameters included) and its
    /// bytes when read through the extractor.
    #[tokio::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap(),
        );

        let res = client.post("/").multipart(form).await;
        // The handler's assertions only run if extraction succeeds. Without this check, a
        // rejection from the `Multipart` extractor would let the test pass silently.
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Compile-time check: `Multipart` is usable as a handler argument on a `Router<()>`.
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router<()> = Router::new().route("/", post(handler));
    }

    /// A body larger than the configured `DefaultBodyLimit` is rejected with
    /// `413 Payload Too Large` while reading the field.
    #[tokio::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}