axum_macros/
from_request.rs

1use self::attr::FromRequestContainerAttrs;
2use crate::{
3    attr_parsing::{parse_attrs, second},
4    from_request::attr::FromRequestFieldAttrs,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{quote, quote_spanned, ToTokens};
8use std::{collections::HashSet, fmt, iter};
9use syn::{
10    parse_quote, punctuated::Punctuated, spanned::Spanned, Fields, Ident, Path, Token, Type,
11};
12
13mod attr;
14
/// Which axum extractor trait a derive invocation is generating an impl for.
#[derive(Copy, Clone)]
pub(crate) enum Trait {
    /// `#[derive(FromRequest)]`: the last field may consume the request body.
    FromRequest,
    /// `#[derive(FromRequestParts)]`: every field is extracted from the request parts.
    FromRequestParts,
}
20
21impl Trait {
22    fn via_marker_type(&self) -> Option<Type> {
23        match self {
24            Trait::FromRequest => Some(parse_quote!(M)),
25            Trait::FromRequestParts => None,
26        }
27    }
28}
29
30impl fmt::Display for Trait {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        match self {
33            Trait::FromRequest => f.write_str("FromRequest"),
34            Trait::FromRequestParts => f.write_str("FromRequestParts"),
35        }
36    }
37}
38
39#[derive(Debug)]
40enum State {
41    Custom(syn::Type),
42    Default(syn::Type),
43    CannotInfer,
44}
45
46impl State {
47    /// ```not_rust
48    /// impl<T> A for B {}
49    ///      ^ this type
50    /// ```
51    fn impl_generics(&self) -> impl Iterator<Item = Type> {
52        match self {
53            State::Default(inner) => Some(inner.clone()),
54            State::Custom(_) => None,
55            State::CannotInfer => Some(parse_quote!(S)),
56        }
57        .into_iter()
58    }
59
60    /// ```not_rust
61    /// impl<T> A<T> for B {}
62    ///           ^ this type
63    /// ```
64    fn trait_generics(&self) -> impl Iterator<Item = Type> {
65        match self {
66            State::Default(inner) | State::Custom(inner) => iter::once(inner.clone()),
67            State::CannotInfer => iter::once(parse_quote!(S)),
68        }
69    }
70
71    fn bounds(&self) -> TokenStream {
72        match self {
73            State::Custom(_) => quote! {},
74            State::Default(inner) => quote! {
75                #inner: ::std::marker::Send + ::std::marker::Sync,
76            },
77            State::CannotInfer => quote! {
78                S: ::std::marker::Send + ::std::marker::Sync,
79            },
80        }
81    }
82}
83
84impl ToTokens for State {
85    fn to_tokens(&self, tokens: &mut TokenStream) {
86        match self {
87            State::Custom(inner) | State::Default(inner) => inner.to_tokens(tokens),
88            State::CannotInfer => quote! { S }.to_tokens(tokens),
89        }
90    }
91}
92
93pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
94    match item {
95        syn::Item::Struct(item) => {
96            let syn::ItemStruct {
97                attrs,
98                ident,
99                generics,
100                fields,
101                semi_token: _,
102                vis: _,
103                struct_token: _,
104            } = item;
105
106            let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
107
108            let FromRequestContainerAttrs {
109                via,
110                rejection,
111                state,
112            } = parse_attrs("from_request", &attrs)?;
113
114            let state = match state {
115                Some((_, state)) => State::Custom(state),
116                None => {
117                    let mut inferred_state_types: HashSet<_> =
118                        infer_state_type_from_field_types(&fields)
119                            .chain(infer_state_type_from_field_attributes(&fields))
120                            .collect();
121
122                    if let Some((_, via)) = &via {
123                        inferred_state_types.extend(state_from_via(&ident, via));
124                    }
125
126                    match inferred_state_types.len() {
127                        0 => State::Default(syn::parse_quote!(S)),
128                        1 => State::Custom(inferred_state_types.iter().next().unwrap().to_owned()),
129                        _ => State::CannotInfer,
130                    }
131                }
132            };
133
134            let trait_impl = match (via.map(second), rejection.map(second)) {
135                (Some(via), rejection) => impl_struct_by_extracting_all_at_once(
136                    ident,
137                    fields,
138                    via,
139                    rejection,
140                    generic_ident,
141                    &state,
142                    tr,
143                )?,
144                (None, rejection) => {
145                    error_on_generic_ident(generic_ident, tr)?;
146                    impl_struct_by_extracting_each_field(ident, fields, rejection, &state, tr)?
147                }
148            };
149
150            if let State::CannotInfer = state {
151                let attr_name = match tr {
152                    Trait::FromRequest => "from_request",
153                    Trait::FromRequestParts => "from_request_parts",
154                };
155                let compile_error = syn::Error::new(
156                    Span::call_site(),
157                    format_args!(
158                        "can't infer state type, please add \
159                         `#[{attr_name}(state = MyStateType)]` attribute",
160                    ),
161                )
162                .into_compile_error();
163
164                Ok(quote! {
165                    #trait_impl
166                    #compile_error
167                })
168            } else {
169                Ok(trait_impl)
170            }
171        }
172        syn::Item::Enum(item) => {
173            let syn::ItemEnum {
174                attrs,
175                vis: _,
176                enum_token: _,
177                ident,
178                generics,
179                brace_token: _,
180                variants,
181            } = item;
182
183            let generics_error = format!("`#[derive({tr})]` on enums don't support generics");
184
185            if !generics.params.is_empty() {
186                return Err(syn::Error::new_spanned(generics, generics_error));
187            }
188
189            if let Some(where_clause) = generics.where_clause {
190                return Err(syn::Error::new_spanned(where_clause, generics_error));
191            }
192
193            let FromRequestContainerAttrs {
194                via,
195                rejection,
196                state,
197            } = parse_attrs("from_request", &attrs)?;
198
199            let state = match state {
200                Some((_, state)) => State::Custom(state),
201                None => (|| {
202                    let via = via.as_ref().map(|(_, via)| via)?;
203                    state_from_via(&ident, via).map(State::Custom)
204                })()
205                .unwrap_or_else(|| State::Default(syn::parse_quote!(S))),
206            };
207
208            match (via.map(second), rejection) {
209                (Some(via), rejection) => impl_enum_by_extracting_all_at_once(
210                    ident,
211                    variants,
212                    via,
213                    rejection.map(second),
214                    state,
215                    tr,
216                ),
217                (None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned(
218                    rejection_kw,
219                    "cannot use `rejection` without `via`",
220                )),
221                (None, _) => Err(syn::Error::new(
222                    Span::call_site(),
223                    "missing `#[from_request(via(...))]`",
224                )),
225            }
226        }
227        _ => Err(syn::Error::new_spanned(item, "expected `struct` or `enum`")),
228    }
229}
230
231fn parse_single_generic_type_on_struct(
232    generics: syn::Generics,
233    fields: &syn::Fields,
234    tr: Trait,
235) -> syn::Result<Option<Ident>> {
236    if let Some(where_clause) = generics.where_clause {
237        return Err(syn::Error::new_spanned(
238            where_clause,
239            format_args!("#[derive({tr})] doesn't support structs with `where` clauses"),
240        ));
241    }
242
243    match generics.params.len() {
244        0 => Ok(None),
245        1 => {
246            let param = generics.params.first().unwrap();
247            let ty_ident = match param {
248                syn::GenericParam::Type(ty) => &ty.ident,
249                syn::GenericParam::Lifetime(lifetime) => {
250                    return Err(syn::Error::new_spanned(
251                        lifetime,
252                        format_args!(
253                            "#[derive({tr})] doesn't support structs \
254                             that are generic over lifetimes"
255                        ),
256                    ));
257                }
258                syn::GenericParam::Const(konst) => {
259                    return Err(syn::Error::new_spanned(
260                        konst,
261                        format_args!(
262                            "#[derive({tr})] doesn't support structs \
263                             that have const generics"
264                        ),
265                    ));
266                }
267            };
268
269            match fields {
270                syn::Fields::Named(fields_named) => {
271                    return Err(syn::Error::new_spanned(
272                        fields_named,
273                        format_args!(
274                            "#[derive({tr})] doesn't support named fields \
275                             for generic structs. Use a tuple struct instead"
276                        ),
277                    ));
278                }
279                syn::Fields::Unnamed(fields_unnamed) => {
280                    if fields_unnamed.unnamed.len() != 1 {
281                        return Err(syn::Error::new_spanned(
282                            fields_unnamed,
283                            format_args!(
284                                "#[derive({tr})] only supports generics on \
285                                 tuple structs that have exactly one field"
286                            ),
287                        ));
288                    }
289
290                    let field = fields_unnamed.unnamed.first().unwrap();
291
292                    if let syn::Type::Path(type_path) = &field.ty {
293                        if type_path
294                            .path
295                            .get_ident()
296                            .map_or(true, |field_type_ident| field_type_ident != ty_ident)
297                        {
298                            return Err(syn::Error::new_spanned(
299                                type_path,
300                                format_args!(
301                                    "#[derive({tr})] only supports generics on \
302                                     tuple structs that have exactly one field of the generic type"
303                                ),
304                            ));
305                        }
306                    } else {
307                        return Err(syn::Error::new_spanned(&field.ty, "Expected type path"));
308                    }
309                }
310                syn::Fields::Unit => return Ok(None),
311            }
312
313            Ok(Some(ty_ident.clone()))
314        }
315        _ => Err(syn::Error::new_spanned(
316            generics,
317            format_args!("#[derive({tr})] only supports 0 or 1 generic type parameters"),
318        )),
319    }
320}
321
322fn error_on_generic_ident(generic_ident: Option<Ident>, tr: Trait) -> syn::Result<()> {
323    if let Some(generic_ident) = generic_ident {
324        Err(syn::Error::new_spanned(
325            generic_ident,
326            format_args!(
327                "#[derive({tr})] only supports generics when used with #[from_request(via)]"
328            ),
329        ))
330    } else {
331        Ok(())
332    }
333}
334
335fn impl_struct_by_extracting_each_field(
336    ident: syn::Ident,
337    fields: syn::Fields,
338    rejection: Option<syn::Path>,
339    state: &State,
340    tr: Trait,
341) -> syn::Result<TokenStream> {
342    let trait_fn_body = match state {
343        State::CannotInfer => quote! {
344            ::std::unimplemented!()
345        },
346        _ => {
347            let extract_fields = extract_fields(&fields, &rejection, tr)?;
348            quote! {
349                ::std::result::Result::Ok(Self {
350                    #(#extract_fields)*
351                })
352            }
353        }
354    };
355
356    let rejection_ident = if let Some(rejection) = rejection {
357        quote!(#rejection)
358    } else if has_no_fields(&fields) {
359        quote!(::std::convert::Infallible)
360    } else {
361        quote!(::axum::response::Response)
362    };
363
364    let impl_generics = state
365        .impl_generics()
366        .collect::<Punctuated<Type, Token![,]>>();
367
368    let trait_generics = state
369        .trait_generics()
370        .collect::<Punctuated<Type, Token![,]>>();
371
372    let state_bounds = state.bounds();
373
374    Ok(match tr {
375        Trait::FromRequest => quote! {
376            #[::axum::async_trait]
377            #[automatically_derived]
378            impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident
379            where
380                #state_bounds
381            {
382                type Rejection = #rejection_ident;
383
384                async fn from_request(
385                    mut req: ::axum::http::Request<::axum::body::Body>,
386                    state: &#state,
387                ) -> ::std::result::Result<Self, Self::Rejection> {
388                    #trait_fn_body
389                }
390            }
391        },
392        Trait::FromRequestParts => quote! {
393            #[::axum::async_trait]
394            #[automatically_derived]
395            impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident
396            where
397                #state_bounds
398            {
399                type Rejection = #rejection_ident;
400
401                async fn from_request_parts(
402                    parts: &mut ::axum::http::request::Parts,
403                    state: &#state,
404                ) -> ::std::result::Result<Self, Self::Rejection> {
405                    #trait_fn_body
406                }
407            }
408        },
409    })
410}
411
412fn has_no_fields(fields: &syn::Fields) -> bool {
413    match fields {
414        syn::Fields::Named(fields) => fields.named.is_empty(),
415        syn::Fields::Unnamed(fields) => fields.unnamed.is_empty(),
416        syn::Fields::Unit => true,
417    }
418}
419
420fn extract_fields(
421    fields: &syn::Fields,
422    rejection: &Option<syn::Path>,
423    tr: Trait,
424) -> syn::Result<Vec<TokenStream>> {
425    fn member(field: &syn::Field, index: usize) -> TokenStream {
426        match &field.ident {
427            Some(ident) => quote! { #ident },
428            _ => {
429                let member = syn::Member::Unnamed(syn::Index {
430                    index: index as u32,
431                    span: field.span(),
432                });
433                quote! { #member }
434            }
435        }
436    }
437
438    fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream {
439        if let Some((_, path)) = via {
440            let span = path.span();
441            quote_spanned! {span=>
442                |#path(inner)| inner
443            }
444        } else {
445            quote_spanned! {ty_span=>
446                ::std::convert::identity
447            }
448        }
449    }
450
451    let mut fields_iter = fields.iter();
452
453    let last = match tr {
454        // Use FromRequestParts for all elements except the last
455        Trait::FromRequest => fields_iter.next_back(),
456        // Use FromRequestParts for all elements
457        Trait::FromRequestParts => None,
458    };
459
460    let mut res: Vec<_> = fields_iter
461        .enumerate()
462        .map(|(index, field)| {
463            let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
464
465            let member = member(field, index);
466            let ty_span = field.ty.span();
467            let into_inner = into_inner(via, ty_span);
468
469            if peel_option(&field.ty).is_some() {
470                let tokens = match tr {
471                    Trait::FromRequest => {
472                        quote_spanned! {ty_span=>
473                            #member: {
474                                let (mut parts, body) = req.into_parts();
475                                let value =
476                                    ::axum::extract::FromRequestParts::from_request_parts(
477                                        &mut parts,
478                                        state,
479                                    )
480                                    .await
481                                    .ok()
482                                    .map(#into_inner);
483                                req = ::axum::http::Request::from_parts(parts, body);
484                                value
485                            },
486                        }
487                    }
488                    Trait::FromRequestParts => {
489                        quote_spanned! {ty_span=>
490                            #member: {
491                                ::axum::extract::FromRequestParts::from_request_parts(
492                                    parts,
493                                    state,
494                                )
495                                .await
496                                .ok()
497                                .map(#into_inner)
498                            },
499                        }
500                    }
501                };
502                Ok(tokens)
503            } else if peel_result_ok(&field.ty).is_some() {
504                let tokens = match tr {
505                    Trait::FromRequest => {
506                        quote_spanned! {ty_span=>
507                            #member: {
508                                let (mut parts, body) = req.into_parts();
509                                let value =
510                                    ::axum::extract::FromRequestParts::from_request_parts(
511                                        &mut parts,
512                                        state,
513                                    )
514                                    .await
515                                    .map(#into_inner);
516                                req = ::axum::http::Request::from_parts(parts, body);
517                                value
518                            },
519                        }
520                    }
521                    Trait::FromRequestParts => {
522                        quote_spanned! {ty_span=>
523                            #member: {
524                                ::axum::extract::FromRequestParts::from_request_parts(
525                                    parts,
526                                    state,
527                                )
528                                .await
529                                .map(#into_inner)
530                            },
531                        }
532                    }
533                };
534                Ok(tokens)
535            } else {
536                let map_err = if let Some(rejection) = rejection {
537                    quote! { <#rejection as ::std::convert::From<_>>::from }
538                } else {
539                    quote! { ::axum::response::IntoResponse::into_response }
540                };
541
542                let tokens = match tr {
543                    Trait::FromRequest => {
544                        quote_spanned! {ty_span=>
545                            #member: {
546                                let (mut parts, body) = req.into_parts();
547                                let value =
548                                    ::axum::extract::FromRequestParts::from_request_parts(
549                                        &mut parts,
550                                        state,
551                                    )
552                                    .await
553                                    .map(#into_inner)
554                                    .map_err(#map_err)?;
555                                req = ::axum::http::Request::from_parts(parts, body);
556                                value
557                            },
558                        }
559                    }
560                    Trait::FromRequestParts => {
561                        quote_spanned! {ty_span=>
562                            #member: {
563                                ::axum::extract::FromRequestParts::from_request_parts(
564                                    parts,
565                                    state,
566                                )
567                                .await
568                                .map(#into_inner)
569                                .map_err(#map_err)?
570                            },
571                        }
572                    }
573                };
574                Ok(tokens)
575            }
576        })
577        .collect::<syn::Result<_>>()?;
578
579    // Handle the last element, if deriving FromRequest
580    if let Some(field) = last {
581        let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
582
583        let member = member(field, fields.len() - 1);
584        let ty_span = field.ty.span();
585        let into_inner = into_inner(via, ty_span);
586
587        let item = if peel_option(&field.ty).is_some() {
588            quote_spanned! {ty_span=>
589                #member: {
590                    ::axum::extract::FromRequest::from_request(req, state)
591                        .await
592                        .ok()
593                        .map(#into_inner)
594                },
595            }
596        } else if peel_result_ok(&field.ty).is_some() {
597            quote_spanned! {ty_span=>
598                #member: {
599                    ::axum::extract::FromRequest::from_request(req, state)
600                        .await
601                        .map(#into_inner)
602                },
603            }
604        } else {
605            let map_err = if let Some(rejection) = rejection {
606                quote! { <#rejection as ::std::convert::From<_>>::from }
607            } else {
608                quote! { ::axum::response::IntoResponse::into_response }
609            };
610
611            quote_spanned! {ty_span=>
612                #member: {
613                    ::axum::extract::FromRequest::from_request(req, state)
614                        .await
615                        .map(#into_inner)
616                        .map_err(#map_err)?
617                },
618            }
619        };
620
621        res.push(item);
622    }
623
624    Ok(res)
625}
626
627fn peel_option(ty: &syn::Type) -> Option<&syn::Type> {
628    let type_path = if let syn::Type::Path(type_path) = ty {
629        type_path
630    } else {
631        return None;
632    };
633
634    let segment = type_path.path.segments.last()?;
635
636    if segment.ident != "Option" {
637        return None;
638    }
639
640    let args = match &segment.arguments {
641        syn::PathArguments::AngleBracketed(args) => args,
642        syn::PathArguments::Parenthesized(_) | syn::PathArguments::None => return None,
643    };
644
645    let ty = if args.args.len() == 1 {
646        args.args.last().unwrap()
647    } else {
648        return None;
649    };
650
651    if let syn::GenericArgument::Type(ty) = ty {
652        Some(ty)
653    } else {
654        None
655    }
656}
657
658fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> {
659    let type_path = if let syn::Type::Path(type_path) = ty {
660        type_path
661    } else {
662        return None;
663    };
664
665    let segment = type_path.path.segments.last()?;
666
667    if segment.ident != "Result" {
668        return None;
669    }
670
671    let args = match &segment.arguments {
672        syn::PathArguments::AngleBracketed(args) => args,
673        syn::PathArguments::Parenthesized(_) | syn::PathArguments::None => return None,
674    };
675
676    let ty = if args.args.len() == 2 {
677        args.args.first().unwrap()
678    } else {
679        return None;
680    };
681
682    if let syn::GenericArgument::Type(ty) = ty {
683        Some(ty)
684    } else {
685        None
686    }
687}
688
689fn impl_struct_by_extracting_all_at_once(
690    ident: syn::Ident,
691    fields: syn::Fields,
692    via_path: syn::Path,
693    rejection: Option<syn::Path>,
694    generic_ident: Option<Ident>,
695    state: &State,
696    tr: Trait,
697) -> syn::Result<TokenStream> {
698    let fields = match fields {
699        syn::Fields::Named(fields) => fields.named.into_iter(),
700        syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(),
701        syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(),
702    };
703
704    for field in fields {
705        let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
706
707        if let Some((via, _)) = via {
708            return Err(syn::Error::new_spanned(
709                via,
710                "`#[from_request(via(...))]` on a field cannot be used \
711                together with `#[from_request(...)]` on the container",
712            ));
713        }
714    }
715
716    let path_span = via_path.span();
717
718    let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
719        let rejection = quote! { #rejection };
720        let map_err = quote! { ::std::convert::From::from };
721        (rejection, map_err)
722    } else {
723        let rejection = quote! {
724            ::axum::response::Response
725        };
726        let map_err = quote! { ::axum::response::IntoResponse::into_response };
727        (rejection, map_err)
728    };
729
730    // for something like
731    //
732    // ```
733    // #[derive(Clone, Default, FromRequest)]
734    // #[from_request(via(State))]
735    // struct AppState {}
736    // ```
737    //
738    // we need to implement `impl<M> FromRequest<AppState, M>` but only for
739    // - `#[derive(FromRequest)]`, not `#[derive(FromRequestParts)]`
740    // - `State`, not other extractors
741    //
742    // honestly not sure why but the tests all pass
743    let via_marker_type = if path_ident_is_state(&via_path) {
744        tr.via_marker_type()
745    } else {
746        None
747    };
748
749    let impl_generics = via_marker_type
750        .iter()
751        .cloned()
752        .chain(state.impl_generics())
753        .chain(generic_ident.is_some().then(|| parse_quote!(T)))
754        .collect::<Punctuated<Type, Token![,]>>();
755
756    let trait_generics = state
757        .trait_generics()
758        .chain(via_marker_type)
759        .collect::<Punctuated<Type, Token![,]>>();
760
761    let ident_generics = generic_ident
762        .is_some()
763        .then(|| quote! { <T> })
764        .unwrap_or_default();
765
766    let rejection_bound = rejection.as_ref().map(|rejection| {
767        match (tr, generic_ident.is_some()) {
768            (Trait::FromRequest, true) => {
769                quote! {
770                    #rejection: ::std::convert::From<<#via_path<T> as ::axum::extract::FromRequest<#trait_generics>>::Rejection>,
771                }
772            },
773            (Trait::FromRequest, false) => {
774                quote! {
775                    #rejection: ::std::convert::From<<#via_path<Self> as ::axum::extract::FromRequest<#trait_generics>>::Rejection>,
776                }
777            },
778            (Trait::FromRequestParts, true) => {
779                quote! {
780                    #rejection: ::std::convert::From<<#via_path<T> as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>,
781                }
782            },
783            (Trait::FromRequestParts, false) => {
784                quote! {
785                    #rejection: ::std::convert::From<<#via_path<Self> as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>,
786                }
787            }
788        }
789    }).unwrap_or_default();
790
791    let via_type_generics = if generic_ident.is_some() {
792        quote! { T }
793    } else {
794        quote! { Self }
795    };
796
797    let value_to_self = if generic_ident.is_some() {
798        quote! {
799            #ident(value)
800        }
801    } else {
802        quote! { value }
803    };
804
805    let state_bounds = state.bounds();
806
807    let tokens = match tr {
808        Trait::FromRequest => {
809            quote_spanned! {path_span=>
810                #[::axum::async_trait]
811                #[automatically_derived]
812                impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics
813                where
814                    #via_path<#via_type_generics>: ::axum::extract::FromRequest<#trait_generics>,
815                    #rejection_bound
816                    #state_bounds
817                {
818                    type Rejection = #associated_rejection_type;
819
820                    async fn from_request(
821                        req: ::axum::http::Request<::axum::body::Body>,
822                        state: &#state,
823                    ) -> ::std::result::Result<Self, Self::Rejection> {
824                        ::axum::extract::FromRequest::from_request(req, state)
825                            .await
826                            .map(|#via_path(value)| #value_to_self)
827                            .map_err(#map_err)
828                    }
829                }
830            }
831        }
832        Trait::FromRequestParts => {
833            quote_spanned! {path_span=>
834                #[::axum::async_trait]
835                #[automatically_derived]
836                impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics
837                where
838                    #via_path<#via_type_generics>: ::axum::extract::FromRequestParts<#trait_generics>,
839                    #rejection_bound
840                    #state_bounds
841                {
842                    type Rejection = #associated_rejection_type;
843
844                    async fn from_request_parts(
845                        parts: &mut ::axum::http::request::Parts,
846                        state: &#state,
847                    ) -> ::std::result::Result<Self, Self::Rejection> {
848                        ::axum::extract::FromRequestParts::from_request_parts(parts, state)
849                            .await
850                            .map(|#via_path(value)| #value_to_self)
851                            .map_err(#map_err)
852                    }
853                }
854            }
855        }
856    };
857
858    Ok(tokens)
859}
860
861fn impl_enum_by_extracting_all_at_once(
862    ident: syn::Ident,
863    variants: Punctuated<syn::Variant, Token![,]>,
864    path: syn::Path,
865    rejection: Option<syn::Path>,
866    state: State,
867    tr: Trait,
868) -> syn::Result<TokenStream> {
869    for variant in variants {
870        let FromRequestFieldAttrs { via } = parse_attrs("from_request", &variant.attrs)?;
871
872        if let Some((via, _)) = via {
873            return Err(syn::Error::new_spanned(
874                via,
875                "`#[from_request(via(...))]` cannot be used on variants",
876            ));
877        }
878
879        let fields = match variant.fields {
880            syn::Fields::Named(fields) => fields.named.into_iter(),
881            syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(),
882            syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(),
883        };
884
885        for field in fields {
886            let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
887            if let Some((via, _)) = via {
888                return Err(syn::Error::new_spanned(
889                    via,
890                    "`#[from_request(via(...))]` cannot be used inside variants",
891                ));
892            }
893        }
894    }
895
896    let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
897        let rejection = quote! { #rejection };
898        let map_err = quote! { ::std::convert::From::from };
899        (rejection, map_err)
900    } else {
901        let rejection = quote! {
902            ::axum::response::Response
903        };
904        let map_err = quote! { ::axum::response::IntoResponse::into_response };
905        (rejection, map_err)
906    };
907
908    let path_span = path.span();
909
910    let impl_generics = state
911        .impl_generics()
912        .collect::<Punctuated<Type, Token![,]>>();
913
914    let trait_generics = state
915        .trait_generics()
916        .collect::<Punctuated<Type, Token![,]>>();
917
918    let state_bounds = state.bounds();
919
920    let tokens = match tr {
921        Trait::FromRequest => {
922            quote_spanned! {path_span=>
923                #[::axum::async_trait]
924                #[automatically_derived]
925                impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident
926                where
927                    #state_bounds
928                {
929                    type Rejection = #associated_rejection_type;
930
931                    async fn from_request(
932                        req: ::axum::http::Request<::axum::body::Body>,
933                        state: &#state,
934                    ) -> ::std::result::Result<Self, Self::Rejection> {
935                        ::axum::extract::FromRequest::from_request(req, state)
936                            .await
937                            .map(|#path(inner)| inner)
938                            .map_err(#map_err)
939                    }
940                }
941            }
942        }
943        Trait::FromRequestParts => {
944            quote_spanned! {path_span=>
945                #[::axum::async_trait]
946                #[automatically_derived]
947                impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident
948                where
949                    #state_bounds
950                {
951                    type Rejection = #associated_rejection_type;
952
953                    async fn from_request_parts(
954                        parts: &mut ::axum::http::request::Parts,
955                        state: &#state,
956                    ) -> ::std::result::Result<Self, Self::Rejection> {
957                        ::axum::extract::FromRequestParts::from_request_parts(parts, state)
958                            .await
959                            .map(|#path(inner)| inner)
960                            .map_err(#map_err)
961                    }
962                }
963            }
964        }
965    };
966
967    Ok(tokens)
968}
969
970/// For a struct like
971///
972/// ```skip
973/// struct Extractor {
974///     state: State<AppState>,
975/// }
976/// ```
977///
978/// We can infer the state type to be `AppState` because it appears inside a `State`
979fn infer_state_type_from_field_types(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
980    match fields {
981        Fields::Named(fields_named) => Box::new(crate::infer_state_types(
982            fields_named.named.iter().map(|field| &field.ty),
983        )) as Box<dyn Iterator<Item = Type>>,
984        Fields::Unnamed(fields_unnamed) => Box::new(crate::infer_state_types(
985            fields_unnamed.unnamed.iter().map(|field| &field.ty),
986        )),
987        Fields::Unit => Box::new(iter::empty()),
988    }
989}
990
991/// For a struct like
992///
993/// ```skip
994/// struct Extractor {
995///     #[from_request(via(State))]
996///     state: AppState,
997/// }
998/// ```
999///
1000/// We can infer the state type to be `AppState` because it has `via(State)` and thus can be
1001/// extracted with `State<AppState>`
1002fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
1003    match fields {
1004        Fields::Named(fields_named) => {
1005            Box::new(fields_named.named.iter().filter_map(|field| {
1006                // TODO(david): it's a little wasteful to parse the attributes again here
1007                // ideally we should parse things once and pass the data down
1008                let FromRequestFieldAttrs { via } =
1009                    parse_attrs("from_request", &field.attrs).ok()?;
1010                let (_, via_path) = via?;
1011                path_ident_is_state(&via_path).then(|| field.ty.clone())
1012            })) as Box<dyn Iterator<Item = Type>>
1013        }
1014        Fields::Unnamed(fields_unnamed) => {
1015            Box::new(fields_unnamed.unnamed.iter().filter_map(|field| {
1016                // TODO(david): it's a little wasteful to parse the attributes again here
1017                // ideally we should parse things once and pass the data down
1018                let FromRequestFieldAttrs { via } =
1019                    parse_attrs("from_request", &field.attrs).ok()?;
1020                let (_, via_path) = via?;
1021                path_ident_is_state(&via_path).then(|| field.ty.clone())
1022            }))
1023        }
1024        Fields::Unit => Box::new(iter::empty()),
1025    }
1026}
1027
1028fn path_ident_is_state(path: &Path) -> bool {
1029    if let Some(last_segment) = path.segments.last() {
1030        last_segment.ident == "State"
1031    } else {
1032        false
1033    }
1034}
1035
1036fn state_from_via(ident: &Ident, via: &Path) -> Option<Type> {
1037    path_ident_is_state(via).then(|| parse_quote!(#ident))
1038}
1039
/// Runs the UI tests registered under the `from_request` name via `crate::run_ui_tests`.
#[test]
fn ui() {
    let suite = "from_request";
    crate::run_ui_tests(suite);
}
1044
/// Checks that `#[derive(FromRequest)]` is rejected at compile time when a field's type
/// (here `bool`) cannot be used as an extractor.
///
/// This is a `compile_fail` doctest rather than a trybuild UI test because the exact
/// compiler error differs between local runs and CI, for reasons we have not tracked down.
/// A `compile_fail` doctest only asserts that compilation fails, not what the error says.
///
/// ```compile_fail
/// #[derive(axum_macros::FromRequest)]
/// struct Extractor {
///     thing: bool,
/// }
/// ```
// The function exists only to carry the doctest above and is never called, hence the allow.
#[allow(dead_code)]
fn test_field_doesnt_impl_from_request() {}