axum_extra/extract/
multipart.rs

1//! Extractor that parses `multipart/form-data` requests commonly used with file uploads.
2//!
3//! See [`Multipart`] for more details.
4
5use axum::{
6    async_trait,
7    body::{Body, Bytes},
8    extract::FromRequest,
9    response::{IntoResponse, Response},
10    RequestExt,
11};
12use futures_util::stream::Stream;
13use http::{
14    header::{HeaderMap, CONTENT_TYPE},
15    Request, StatusCode,
16};
17use std::{
18    error::Error,
19    fmt,
20    pin::Pin,
21    task::{Context, Poll},
22};
23
24/// Extractor that parses `multipart/form-data` requests (commonly used with file uploads).
25///
26/// ⚠️ Since extracting multipart form data from the request requires consuming the body, the
27/// `Multipart` extractor must be *last* if there are multiple extractors in a handler.
28/// See ["the order of extractors"][order-of-extractors]
29///
30/// [order-of-extractors]: crate::extract#the-order-of-extractors
31///
32/// # Example
33///
34/// ```
35/// use axum::{
36///     routing::post,
37///     Router,
38/// };
39/// use axum_extra::extract::Multipart;
40///
41/// async fn upload(mut multipart: Multipart) {
42///     while let Some(mut field) = multipart.next_field().await.unwrap() {
43///         let name = field.name().unwrap().to_string();
44///         let data = field.bytes().await.unwrap();
45///
46///         println!("Length of `{}` is {} bytes", name, data.len());
47///     }
48/// }
49///
50/// let app = Router::new().route("/upload", post(upload));
51/// # let _: Router = app;
52/// ```
53///
54/// # Field Exclusivity
55///
56/// A [`Field`] represents a raw, self-decoding stream into multipart data. As such, only one
57/// [`Field`] from a given Multipart instance may be live at once. That is, a [`Field`] emitted by
58/// [`next_field()`] must be dropped before calling [`next_field()`] again. Failure to do so will
59/// result in an error.
60///
61/// ```
62/// use axum_extra::extract::Multipart;
63///
64/// async fn handler(mut multipart: Multipart) {
65///     let field_1 = multipart.next_field().await;
66///
67///     // We cannot get the next field while `field_1` is still alive. Have to drop `field_1`
68///     // first.
69///     let field_2 = multipart.next_field().await;
70///     assert!(field_2.is_err());
71/// }
72/// ```
73///
74/// In general you should consume `Multipart` by looping over the fields in order and make sure not
75/// to keep `Field`s around from previous loop iterations. That will minimize the risk of runtime
76/// errors.
77///
78/// # Differences between this and `axum::extract::Multipart`
79///
80/// `axum::extract::Multipart` uses lifetimes to enforce field exclusivity at compile time, however
81/// that leads to significant usability issues such as `Field` not being `'static`.
82///
83/// `axum_extra::extract::Multipart` instead enforces field exclusivity at runtime which makes
84/// things easier to use at the cost of possible runtime errors.
85///
86/// [`next_field()`]: Multipart::next_field
87#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
88#[derive(Debug)]
89pub struct Multipart {
90    inner: multer::Multipart<'static>,
91}
92
93#[async_trait]
94impl<S> FromRequest<S> for Multipart
95where
96    S: Send + Sync,
97{
98    type Rejection = MultipartRejection;
99
100    async fn from_request(req: Request<Body>, _state: &S) -> Result<Self, Self::Rejection> {
101        let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
102        let stream = req.with_limited_body().into_body();
103        let multipart = multer::Multipart::new(stream.into_data_stream(), boundary);
104        Ok(Self { inner: multipart })
105    }
106}
107
108impl Multipart {
109    /// Yields the next [`Field`] if available.
110    pub async fn next_field(&mut self) -> Result<Option<Field>, MultipartError> {
111        let field = self
112            .inner
113            .next_field()
114            .await
115            .map_err(MultipartError::from_multer)?;
116
117        if let Some(field) = field {
118            Ok(Some(Field { inner: field }))
119        } else {
120            Ok(None)
121        }
122    }
123
124    /// Convert the `Multipart` into a stream of its fields.
125    pub fn into_stream(self) -> impl Stream<Item = Result<Field, MultipartError>> + Send + 'static {
126        futures_util::stream::try_unfold(self, |mut multipart| async move {
127            let field = multipart.next_field().await?;
128            Ok(field.map(|field| (field, multipart)))
129        })
130    }
131}
132
133/// A single field in a multipart stream.
134#[derive(Debug)]
135pub struct Field {
136    inner: multer::Field<'static>,
137}
138
139impl Stream for Field {
140    type Item = Result<Bytes, MultipartError>;
141
142    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
143        Pin::new(&mut self.inner)
144            .poll_next(cx)
145            .map_err(MultipartError::from_multer)
146    }
147}
148
149impl Field {
150    /// The field name found in the
151    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
152    /// header.
153    pub fn name(&self) -> Option<&str> {
154        self.inner.name()
155    }
156
157    /// The file name found in the
158    /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
159    /// header.
160    pub fn file_name(&self) -> Option<&str> {
161        self.inner.file_name()
162    }
163
164    /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field.
165    pub fn content_type(&self) -> Option<&str> {
166        self.inner.content_type().map(|m| m.as_ref())
167    }
168
169    /// Get a map of headers as [`HeaderMap`].
170    pub fn headers(&self) -> &HeaderMap {
171        self.inner.headers()
172    }
173
174    /// Get the full data of the field as [`Bytes`].
175    pub async fn bytes(self) -> Result<Bytes, MultipartError> {
176        self.inner
177            .bytes()
178            .await
179            .map_err(MultipartError::from_multer)
180    }
181
182    /// Get the full field data as text.
183    pub async fn text(self) -> Result<String, MultipartError> {
184        self.inner.text().await.map_err(MultipartError::from_multer)
185    }
186
187    /// Stream a chunk of the field data.
188    ///
189    /// When the field data has been exhausted, this will return [`None`].
190    ///
191    /// Note this does the same thing as `Field`'s [`Stream`] implementation.
192    ///
193    /// # Example
194    ///
195    /// ```
196    /// use axum::{
197    ///    routing::post,
198    ///    response::IntoResponse,
199    ///    http::StatusCode,
200    ///    Router,
201    /// };
202    /// use axum_extra::extract::Multipart;
203    ///
204    /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> {
205    ///     while let Some(mut field) = multipart
206    ///         .next_field()
207    ///         .await
208    ///         .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
209    ///     {
210    ///         while let Some(chunk) = field
211    ///             .chunk()
212    ///             .await
213    ///             .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?
214    ///         {
215    ///             println!("received {} bytes", chunk.len());
216    ///         }
217    ///     }
218    ///
219    ///     Ok(())
220    /// }
221    ///
222    /// let app = Router::new().route("/upload", post(upload));
223    /// # let _: Router = app;
224    /// ```
225    pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
226        self.inner
227            .chunk()
228            .await
229            .map_err(MultipartError::from_multer)
230    }
231}
232
233/// Errors associated with parsing `multipart/form-data` requests.
234#[derive(Debug)]
235pub struct MultipartError {
236    source: multer::Error,
237}
238
239impl MultipartError {
240    fn from_multer(multer: multer::Error) -> Self {
241        Self { source: multer }
242    }
243
244    /// Get the response body text used for this rejection.
245    pub fn body_text(&self) -> String {
246        axum_core::__log_rejection!(
247            rejection_type = Self,
248            body_text = self.body_text(),
249            status = self.status(),
250        );
251        self.source.to_string()
252    }
253
254    /// Get the status code used for this rejection.
255    pub fn status(&self) -> http::StatusCode {
256        status_code_from_multer_error(&self.source)
257    }
258}
259
260fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
261    match err {
262        multer::Error::UnknownField { .. }
263        | multer::Error::IncompleteFieldData { .. }
264        | multer::Error::IncompleteHeaders
265        | multer::Error::ReadHeaderFailed(..)
266        | multer::Error::DecodeHeaderName { .. }
267        | multer::Error::DecodeContentType(..)
268        | multer::Error::NoBoundary
269        | multer::Error::DecodeHeaderValue { .. }
270        | multer::Error::NoMultipart
271        | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
272        multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
273            StatusCode::PAYLOAD_TOO_LARGE
274        }
275        multer::Error::StreamReadFailed(err) => {
276            if let Some(err) = err.downcast_ref::<multer::Error>() {
277                return status_code_from_multer_error(err);
278            }
279
280            if err
281                .downcast_ref::<axum::Error>()
282                .and_then(|err| err.source())
283                .and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
284                .is_some()
285            {
286                return StatusCode::PAYLOAD_TOO_LARGE;
287            }
288
289            StatusCode::INTERNAL_SERVER_ERROR
290        }
291        _ => StatusCode::INTERNAL_SERVER_ERROR,
292    }
293}
294
295impl IntoResponse for MultipartError {
296    fn into_response(self) -> Response {
297        (self.status(), self.body_text()).into_response()
298    }
299}
300
301impl fmt::Display for MultipartError {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        write!(f, "Error parsing `multipart/form-data` request")
304    }
305}
306
307impl std::error::Error for MultipartError {
308    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
309        Some(&self.source)
310    }
311}
312
313fn parse_boundary(headers: &HeaderMap) -> Option<String> {
314    let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
315    multer::parse_boundary(content_type).ok()
316}
317
318/// Rejection used for [`Multipart`].
319///
320/// Contains one variant for each way the [`Multipart`] extractor can fail.
321#[derive(Debug)]
322#[non_exhaustive]
323pub enum MultipartRejection {
324    #[allow(missing_docs)]
325    InvalidBoundary(InvalidBoundary),
326}
327
328impl IntoResponse for MultipartRejection {
329    fn into_response(self) -> Response {
330        match self {
331            Self::InvalidBoundary(inner) => inner.into_response(),
332        }
333    }
334}
335
336impl MultipartRejection {
337    /// Get the response body text used for this rejection.
338    pub fn body_text(&self) -> String {
339        match self {
340            Self::InvalidBoundary(inner) => inner.body_text(),
341        }
342    }
343
344    /// Get the status code used for this rejection.
345    pub fn status(&self) -> http::StatusCode {
346        match self {
347            Self::InvalidBoundary(inner) => inner.status(),
348        }
349    }
350}
351
352impl From<InvalidBoundary> for MultipartRejection {
353    fn from(inner: InvalidBoundary) -> Self {
354        Self::InvalidBoundary(inner)
355    }
356}
357
358impl std::fmt::Display for MultipartRejection {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        match self {
361            Self::InvalidBoundary(inner) => write!(f, "{}", inner.body_text()),
362        }
363    }
364}
365
366impl std::error::Error for MultipartRejection {
367    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
368        match self {
369            Self::InvalidBoundary(inner) => Some(inner),
370        }
371    }
372}
373
/// Rejection type used when a `multipart/form-data` request's `boundary` is missing or
/// invalid.
///
/// This includes the case where the `Content-Type` header itself is absent or unreadable.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct InvalidBoundary;
379
380impl IntoResponse for InvalidBoundary {
381    fn into_response(self) -> Response {
382        let body = self.body_text();
383        axum_core::__log_rejection!(
384            rejection_type = Self,
385            body_text = body,
386            status = self.status(),
387        );
388        (self.status(), body).into_response()
389    }
390}
391
392impl InvalidBoundary {
393    /// Get the response body text used for this rejection.
394    pub fn body_text(&self) -> String {
395        "Invalid `boundary` for `multipart/form-data` request".into()
396    }
397
398    /// Get the status code used for this rejection.
399    pub fn status(&self) -> http::StatusCode {
400        http::StatusCode::BAD_REQUEST
401    }
402}
403
404impl std::fmt::Display for InvalidBoundary {
405    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        write!(f, "{}", self.body_text())
407    }
408}
409
410impl std::error::Error for InvalidBoundary {}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::DefaultBodyLimit, routing::post, Router};

    /// The file name, full content type (including its `charset` parameter) and data of a
    /// single uploaded file are exposed unchanged by `Field`.
    #[tokio::test]
    async fn content_type_with_encoding() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
        const FILE_NAME: &str = "index.html";
        const CONTENT_TYPE: &str = "text/html; charset=utf-8";

        // The assertions live in the handler because that is where the fields are read.
        async fn handle(mut multipart: Multipart) -> impl IntoResponse {
            let field = multipart.next_field().await.unwrap().unwrap();

            assert_eq!(field.file_name().unwrap(), FILE_NAME);
            assert_eq!(field.content_type().unwrap(), CONTENT_TYPE);
            assert_eq!(field.bytes().await.unwrap(), BYTES);

            assert!(multipart.next_field().await.unwrap().is_none());
        }

        let app = Router::new().route("/", post(handle));

        let client = TestClient::new(app);

        let form = reqwest::multipart::Form::new().part(
            "file",
            reqwest::multipart::Part::bytes(BYTES)
                .file_name(FILE_NAME)
                .mime_str(CONTENT_TYPE)
                .unwrap(),
        );

        let res = client.post("/").multipart(form).await;
        // If the extractor rejected the request, the handler and its assertions would never run.
        // Check that the request actually reached the handler and succeeded.
        assert_eq!(res.status(), StatusCode::OK);
    }

    // No need for this to be a #[test], we just want to make sure it compiles
    fn _multipart_from_request_limited() {
        async fn handler(_: Multipart) {}
        let _app: Router<()> = Router::new().route("/", post(handler));
    }

    /// A body exceeding the configured `DefaultBodyLimit` is reported as `413 Payload Too Large`.
    #[tokio::test]
    async fn body_too_large() {
        const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();

        async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
            while let Some(field) = multipart.next_field().await? {
                field.bytes().await?;
            }
            Ok(())
        }

        let app = Router::new()
            .route("/", post(handle))
            .layer(DefaultBodyLimit::max(BYTES.len() - 1));

        let client = TestClient::new(app);

        let form =
            reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));

        let res = client.post("/").multipart(form).await;
        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}