1use self::attr::FromRequestContainerAttrs;
2use crate::{
3 attr_parsing::{parse_attrs, second},
4 from_request::attr::FromRequestFieldAttrs,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{quote, quote_spanned, ToTokens};
8use std::{collections::HashSet, fmt, iter};
9use syn::{
10 parse_quote, punctuated::Punctuated, spanned::Spanned, Fields, Ident, Path, Token, Type,
11};
12
13mod attr;
14
15#[derive(Clone, Copy)]
16pub(crate) enum Trait {
17 FromRequest,
18 FromRequestParts,
19}
20
21impl Trait {
22 fn via_marker_type(&self) -> Option<Type> {
23 match self {
24 Trait::FromRequest => Some(parse_quote!(M)),
25 Trait::FromRequestParts => None,
26 }
27 }
28}
29
30impl fmt::Display for Trait {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 match self {
33 Trait::FromRequest => f.write_str("FromRequest"),
34 Trait::FromRequestParts => f.write_str("FromRequestParts"),
35 }
36 }
37}
38
39#[derive(Debug)]
40enum State {
41 Custom(syn::Type),
42 Default(syn::Type),
43 CannotInfer,
44}
45
46impl State {
47 fn impl_generics(&self) -> impl Iterator<Item = Type> {
52 match self {
53 State::Default(inner) => Some(inner.clone()),
54 State::Custom(_) => None,
55 State::CannotInfer => Some(parse_quote!(S)),
56 }
57 .into_iter()
58 }
59
60 fn trait_generics(&self) -> impl Iterator<Item = Type> {
65 match self {
66 State::Default(inner) | State::Custom(inner) => iter::once(inner.clone()),
67 State::CannotInfer => iter::once(parse_quote!(S)),
68 }
69 }
70
71 fn bounds(&self) -> TokenStream {
72 match self {
73 State::Custom(_) => quote! {},
74 State::Default(inner) => quote! {
75 #inner: ::std::marker::Send + ::std::marker::Sync,
76 },
77 State::CannotInfer => quote! {
78 S: ::std::marker::Send + ::std::marker::Sync,
79 },
80 }
81 }
82}
83
84impl ToTokens for State {
85 fn to_tokens(&self, tokens: &mut TokenStream) {
86 match self {
87 State::Custom(inner) | State::Default(inner) => inner.to_tokens(tokens),
88 State::CannotInfer => quote! { S }.to_tokens(tokens),
89 }
90 }
91}
92
93pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
94 match item {
95 syn::Item::Struct(item) => {
96 let syn::ItemStruct {
97 attrs,
98 ident,
99 generics,
100 fields,
101 semi_token: _,
102 vis: _,
103 struct_token: _,
104 } = item;
105
106 let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
107
108 let FromRequestContainerAttrs {
109 via,
110 rejection,
111 state,
112 } = parse_attrs("from_request", &attrs)?;
113
114 let state = match state {
115 Some((_, state)) => State::Custom(state),
116 None => {
117 let mut inferred_state_types: HashSet<_> =
118 infer_state_type_from_field_types(&fields)
119 .chain(infer_state_type_from_field_attributes(&fields))
120 .collect();
121
122 if let Some((_, via)) = &via {
123 inferred_state_types.extend(state_from_via(&ident, via));
124 }
125
126 match inferred_state_types.len() {
127 0 => State::Default(syn::parse_quote!(S)),
128 1 => State::Custom(inferred_state_types.iter().next().unwrap().to_owned()),
129 _ => State::CannotInfer,
130 }
131 }
132 };
133
134 let trait_impl = match (via.map(second), rejection.map(second)) {
135 (Some(via), rejection) => impl_struct_by_extracting_all_at_once(
136 ident,
137 fields,
138 via,
139 rejection,
140 generic_ident,
141 &state,
142 tr,
143 )?,
144 (None, rejection) => {
145 error_on_generic_ident(generic_ident, tr)?;
146 impl_struct_by_extracting_each_field(ident, fields, rejection, &state, tr)?
147 }
148 };
149
150 if let State::CannotInfer = state {
151 let attr_name = match tr {
152 Trait::FromRequest => "from_request",
153 Trait::FromRequestParts => "from_request_parts",
154 };
155 let compile_error = syn::Error::new(
156 Span::call_site(),
157 format_args!(
158 "can't infer state type, please add \
159 `#[{attr_name}(state = MyStateType)]` attribute",
160 ),
161 )
162 .into_compile_error();
163
164 Ok(quote! {
165 #trait_impl
166 #compile_error
167 })
168 } else {
169 Ok(trait_impl)
170 }
171 }
172 syn::Item::Enum(item) => {
173 let syn::ItemEnum {
174 attrs,
175 vis: _,
176 enum_token: _,
177 ident,
178 generics,
179 brace_token: _,
180 variants,
181 } = item;
182
183 let generics_error = format!("`#[derive({tr})]` on enums don't support generics");
184
185 if !generics.params.is_empty() {
186 return Err(syn::Error::new_spanned(generics, generics_error));
187 }
188
189 if let Some(where_clause) = generics.where_clause {
190 return Err(syn::Error::new_spanned(where_clause, generics_error));
191 }
192
193 let FromRequestContainerAttrs {
194 via,
195 rejection,
196 state,
197 } = parse_attrs("from_request", &attrs)?;
198
199 let state = match state {
200 Some((_, state)) => State::Custom(state),
201 None => (|| {
202 let via = via.as_ref().map(|(_, via)| via)?;
203 state_from_via(&ident, via).map(State::Custom)
204 })()
205 .unwrap_or_else(|| State::Default(syn::parse_quote!(S))),
206 };
207
208 match (via.map(second), rejection) {
209 (Some(via), rejection) => impl_enum_by_extracting_all_at_once(
210 ident,
211 variants,
212 via,
213 rejection.map(second),
214 state,
215 tr,
216 ),
217 (None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned(
218 rejection_kw,
219 "cannot use `rejection` without `via`",
220 )),
221 (None, _) => Err(syn::Error::new(
222 Span::call_site(),
223 "missing `#[from_request(via(...))]`",
224 )),
225 }
226 }
227 _ => Err(syn::Error::new_spanned(item, "expected `struct` or `enum`")),
228 }
229}
230
231fn parse_single_generic_type_on_struct(
232 generics: syn::Generics,
233 fields: &syn::Fields,
234 tr: Trait,
235) -> syn::Result<Option<Ident>> {
236 if let Some(where_clause) = generics.where_clause {
237 return Err(syn::Error::new_spanned(
238 where_clause,
239 format_args!("#[derive({tr})] doesn't support structs with `where` clauses"),
240 ));
241 }
242
243 match generics.params.len() {
244 0 => Ok(None),
245 1 => {
246 let param = generics.params.first().unwrap();
247 let ty_ident = match param {
248 syn::GenericParam::Type(ty) => &ty.ident,
249 syn::GenericParam::Lifetime(lifetime) => {
250 return Err(syn::Error::new_spanned(
251 lifetime,
252 format_args!(
253 "#[derive({tr})] doesn't support structs \
254 that are generic over lifetimes"
255 ),
256 ));
257 }
258 syn::GenericParam::Const(konst) => {
259 return Err(syn::Error::new_spanned(
260 konst,
261 format_args!(
262 "#[derive({tr})] doesn't support structs \
263 that have const generics"
264 ),
265 ));
266 }
267 };
268
269 match fields {
270 syn::Fields::Named(fields_named) => {
271 return Err(syn::Error::new_spanned(
272 fields_named,
273 format_args!(
274 "#[derive({tr})] doesn't support named fields \
275 for generic structs. Use a tuple struct instead"
276 ),
277 ));
278 }
279 syn::Fields::Unnamed(fields_unnamed) => {
280 if fields_unnamed.unnamed.len() != 1 {
281 return Err(syn::Error::new_spanned(
282 fields_unnamed,
283 format_args!(
284 "#[derive({tr})] only supports generics on \
285 tuple structs that have exactly one field"
286 ),
287 ));
288 }
289
290 let field = fields_unnamed.unnamed.first().unwrap();
291
292 if let syn::Type::Path(type_path) = &field.ty {
293 if type_path
294 .path
295 .get_ident()
296 .map_or(true, |field_type_ident| field_type_ident != ty_ident)
297 {
298 return Err(syn::Error::new_spanned(
299 type_path,
300 format_args!(
301 "#[derive({tr})] only supports generics on \
302 tuple structs that have exactly one field of the generic type"
303 ),
304 ));
305 }
306 } else {
307 return Err(syn::Error::new_spanned(&field.ty, "Expected type path"));
308 }
309 }
310 syn::Fields::Unit => return Ok(None),
311 }
312
313 Ok(Some(ty_ident.clone()))
314 }
315 _ => Err(syn::Error::new_spanned(
316 generics,
317 format_args!("#[derive({tr})] only supports 0 or 1 generic type parameters"),
318 )),
319 }
320}
321
322fn error_on_generic_ident(generic_ident: Option<Ident>, tr: Trait) -> syn::Result<()> {
323 if let Some(generic_ident) = generic_ident {
324 Err(syn::Error::new_spanned(
325 generic_ident,
326 format_args!(
327 "#[derive({tr})] only supports generics when used with #[from_request(via)]"
328 ),
329 ))
330 } else {
331 Ok(())
332 }
333}
334
335fn impl_struct_by_extracting_each_field(
336 ident: syn::Ident,
337 fields: syn::Fields,
338 rejection: Option<syn::Path>,
339 state: &State,
340 tr: Trait,
341) -> syn::Result<TokenStream> {
342 let trait_fn_body = match state {
343 State::CannotInfer => quote! {
344 ::std::unimplemented!()
345 },
346 _ => {
347 let extract_fields = extract_fields(&fields, &rejection, tr)?;
348 quote! {
349 ::std::result::Result::Ok(Self {
350 #(#extract_fields)*
351 })
352 }
353 }
354 };
355
356 let rejection_ident = if let Some(rejection) = rejection {
357 quote!(#rejection)
358 } else if has_no_fields(&fields) {
359 quote!(::std::convert::Infallible)
360 } else {
361 quote!(::axum::response::Response)
362 };
363
364 let impl_generics = state
365 .impl_generics()
366 .collect::<Punctuated<Type, Token![,]>>();
367
368 let trait_generics = state
369 .trait_generics()
370 .collect::<Punctuated<Type, Token![,]>>();
371
372 let state_bounds = state.bounds();
373
374 Ok(match tr {
375 Trait::FromRequest => quote! {
376 #[::axum::async_trait]
377 #[automatically_derived]
378 impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident
379 where
380 #state_bounds
381 {
382 type Rejection = #rejection_ident;
383
384 async fn from_request(
385 mut req: ::axum::http::Request<::axum::body::Body>,
386 state: &#state,
387 ) -> ::std::result::Result<Self, Self::Rejection> {
388 #trait_fn_body
389 }
390 }
391 },
392 Trait::FromRequestParts => quote! {
393 #[::axum::async_trait]
394 #[automatically_derived]
395 impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident
396 where
397 #state_bounds
398 {
399 type Rejection = #rejection_ident;
400
401 async fn from_request_parts(
402 parts: &mut ::axum::http::request::Parts,
403 state: &#state,
404 ) -> ::std::result::Result<Self, Self::Rejection> {
405 #trait_fn_body
406 }
407 }
408 },
409 })
410}
411
412fn has_no_fields(fields: &syn::Fields) -> bool {
413 match fields {
414 syn::Fields::Named(fields) => fields.named.is_empty(),
415 syn::Fields::Unnamed(fields) => fields.unnamed.is_empty(),
416 syn::Fields::Unit => true,
417 }
418}
419
420fn extract_fields(
421 fields: &syn::Fields,
422 rejection: &Option<syn::Path>,
423 tr: Trait,
424) -> syn::Result<Vec<TokenStream>> {
425 fn member(field: &syn::Field, index: usize) -> TokenStream {
426 match &field.ident {
427 Some(ident) => quote! { #ident },
428 _ => {
429 let member = syn::Member::Unnamed(syn::Index {
430 index: index as u32,
431 span: field.span(),
432 });
433 quote! { #member }
434 }
435 }
436 }
437
438 fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream {
439 if let Some((_, path)) = via {
440 let span = path.span();
441 quote_spanned! {span=>
442 |#path(inner)| inner
443 }
444 } else {
445 quote_spanned! {ty_span=>
446 ::std::convert::identity
447 }
448 }
449 }
450
451 let mut fields_iter = fields.iter();
452
453 let last = match tr {
454 Trait::FromRequest => fields_iter.next_back(),
456 Trait::FromRequestParts => None,
458 };
459
460 let mut res: Vec<_> = fields_iter
461 .enumerate()
462 .map(|(index, field)| {
463 let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
464
465 let member = member(field, index);
466 let ty_span = field.ty.span();
467 let into_inner = into_inner(via, ty_span);
468
469 if peel_option(&field.ty).is_some() {
470 let tokens = match tr {
471 Trait::FromRequest => {
472 quote_spanned! {ty_span=>
473 #member: {
474 let (mut parts, body) = req.into_parts();
475 let value =
476 ::axum::extract::FromRequestParts::from_request_parts(
477 &mut parts,
478 state,
479 )
480 .await
481 .ok()
482 .map(#into_inner);
483 req = ::axum::http::Request::from_parts(parts, body);
484 value
485 },
486 }
487 }
488 Trait::FromRequestParts => {
489 quote_spanned! {ty_span=>
490 #member: {
491 ::axum::extract::FromRequestParts::from_request_parts(
492 parts,
493 state,
494 )
495 .await
496 .ok()
497 .map(#into_inner)
498 },
499 }
500 }
501 };
502 Ok(tokens)
503 } else if peel_result_ok(&field.ty).is_some() {
504 let tokens = match tr {
505 Trait::FromRequest => {
506 quote_spanned! {ty_span=>
507 #member: {
508 let (mut parts, body) = req.into_parts();
509 let value =
510 ::axum::extract::FromRequestParts::from_request_parts(
511 &mut parts,
512 state,
513 )
514 .await
515 .map(#into_inner);
516 req = ::axum::http::Request::from_parts(parts, body);
517 value
518 },
519 }
520 }
521 Trait::FromRequestParts => {
522 quote_spanned! {ty_span=>
523 #member: {
524 ::axum::extract::FromRequestParts::from_request_parts(
525 parts,
526 state,
527 )
528 .await
529 .map(#into_inner)
530 },
531 }
532 }
533 };
534 Ok(tokens)
535 } else {
536 let map_err = if let Some(rejection) = rejection {
537 quote! { <#rejection as ::std::convert::From<_>>::from }
538 } else {
539 quote! { ::axum::response::IntoResponse::into_response }
540 };
541
542 let tokens = match tr {
543 Trait::FromRequest => {
544 quote_spanned! {ty_span=>
545 #member: {
546 let (mut parts, body) = req.into_parts();
547 let value =
548 ::axum::extract::FromRequestParts::from_request_parts(
549 &mut parts,
550 state,
551 )
552 .await
553 .map(#into_inner)
554 .map_err(#map_err)?;
555 req = ::axum::http::Request::from_parts(parts, body);
556 value
557 },
558 }
559 }
560 Trait::FromRequestParts => {
561 quote_spanned! {ty_span=>
562 #member: {
563 ::axum::extract::FromRequestParts::from_request_parts(
564 parts,
565 state,
566 )
567 .await
568 .map(#into_inner)
569 .map_err(#map_err)?
570 },
571 }
572 }
573 };
574 Ok(tokens)
575 }
576 })
577 .collect::<syn::Result<_>>()?;
578
579 if let Some(field) = last {
581 let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
582
583 let member = member(field, fields.len() - 1);
584 let ty_span = field.ty.span();
585 let into_inner = into_inner(via, ty_span);
586
587 let item = if peel_option(&field.ty).is_some() {
588 quote_spanned! {ty_span=>
589 #member: {
590 ::axum::extract::FromRequest::from_request(req, state)
591 .await
592 .ok()
593 .map(#into_inner)
594 },
595 }
596 } else if peel_result_ok(&field.ty).is_some() {
597 quote_spanned! {ty_span=>
598 #member: {
599 ::axum::extract::FromRequest::from_request(req, state)
600 .await
601 .map(#into_inner)
602 },
603 }
604 } else {
605 let map_err = if let Some(rejection) = rejection {
606 quote! { <#rejection as ::std::convert::From<_>>::from }
607 } else {
608 quote! { ::axum::response::IntoResponse::into_response }
609 };
610
611 quote_spanned! {ty_span=>
612 #member: {
613 ::axum::extract::FromRequest::from_request(req, state)
614 .await
615 .map(#into_inner)
616 .map_err(#map_err)?
617 },
618 }
619 };
620
621 res.push(item);
622 }
623
624 Ok(res)
625}
626
627fn peel_option(ty: &syn::Type) -> Option<&syn::Type> {
628 let type_path = if let syn::Type::Path(type_path) = ty {
629 type_path
630 } else {
631 return None;
632 };
633
634 let segment = type_path.path.segments.last()?;
635
636 if segment.ident != "Option" {
637 return None;
638 }
639
640 let args = match &segment.arguments {
641 syn::PathArguments::AngleBracketed(args) => args,
642 syn::PathArguments::Parenthesized(_) | syn::PathArguments::None => return None,
643 };
644
645 let ty = if args.args.len() == 1 {
646 args.args.last().unwrap()
647 } else {
648 return None;
649 };
650
651 if let syn::GenericArgument::Type(ty) = ty {
652 Some(ty)
653 } else {
654 None
655 }
656}
657
658fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> {
659 let type_path = if let syn::Type::Path(type_path) = ty {
660 type_path
661 } else {
662 return None;
663 };
664
665 let segment = type_path.path.segments.last()?;
666
667 if segment.ident != "Result" {
668 return None;
669 }
670
671 let args = match &segment.arguments {
672 syn::PathArguments::AngleBracketed(args) => args,
673 syn::PathArguments::Parenthesized(_) | syn::PathArguments::None => return None,
674 };
675
676 let ty = if args.args.len() == 2 {
677 args.args.first().unwrap()
678 } else {
679 return None;
680 };
681
682 if let syn::GenericArgument::Type(ty) = ty {
683 Some(ty)
684 } else {
685 None
686 }
687}
688
689fn impl_struct_by_extracting_all_at_once(
690 ident: syn::Ident,
691 fields: syn::Fields,
692 via_path: syn::Path,
693 rejection: Option<syn::Path>,
694 generic_ident: Option<Ident>,
695 state: &State,
696 tr: Trait,
697) -> syn::Result<TokenStream> {
698 let fields = match fields {
699 syn::Fields::Named(fields) => fields.named.into_iter(),
700 syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(),
701 syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(),
702 };
703
704 for field in fields {
705 let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
706
707 if let Some((via, _)) = via {
708 return Err(syn::Error::new_spanned(
709 via,
710 "`#[from_request(via(...))]` on a field cannot be used \
711 together with `#[from_request(...)]` on the container",
712 ));
713 }
714 }
715
716 let path_span = via_path.span();
717
718 let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
719 let rejection = quote! { #rejection };
720 let map_err = quote! { ::std::convert::From::from };
721 (rejection, map_err)
722 } else {
723 let rejection = quote! {
724 ::axum::response::Response
725 };
726 let map_err = quote! { ::axum::response::IntoResponse::into_response };
727 (rejection, map_err)
728 };
729
730 let via_marker_type = if path_ident_is_state(&via_path) {
744 tr.via_marker_type()
745 } else {
746 None
747 };
748
749 let impl_generics = via_marker_type
750 .iter()
751 .cloned()
752 .chain(state.impl_generics())
753 .chain(generic_ident.is_some().then(|| parse_quote!(T)))
754 .collect::<Punctuated<Type, Token![,]>>();
755
756 let trait_generics = state
757 .trait_generics()
758 .chain(via_marker_type)
759 .collect::<Punctuated<Type, Token![,]>>();
760
761 let ident_generics = generic_ident
762 .is_some()
763 .then(|| quote! { <T> })
764 .unwrap_or_default();
765
766 let rejection_bound = rejection.as_ref().map(|rejection| {
767 match (tr, generic_ident.is_some()) {
768 (Trait::FromRequest, true) => {
769 quote! {
770 #rejection: ::std::convert::From<<#via_path<T> as ::axum::extract::FromRequest<#trait_generics>>::Rejection>,
771 }
772 },
773 (Trait::FromRequest, false) => {
774 quote! {
775 #rejection: ::std::convert::From<<#via_path<Self> as ::axum::extract::FromRequest<#trait_generics>>::Rejection>,
776 }
777 },
778 (Trait::FromRequestParts, true) => {
779 quote! {
780 #rejection: ::std::convert::From<<#via_path<T> as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>,
781 }
782 },
783 (Trait::FromRequestParts, false) => {
784 quote! {
785 #rejection: ::std::convert::From<<#via_path<Self> as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>,
786 }
787 }
788 }
789 }).unwrap_or_default();
790
791 let via_type_generics = if generic_ident.is_some() {
792 quote! { T }
793 } else {
794 quote! { Self }
795 };
796
797 let value_to_self = if generic_ident.is_some() {
798 quote! {
799 #ident(value)
800 }
801 } else {
802 quote! { value }
803 };
804
805 let state_bounds = state.bounds();
806
807 let tokens = match tr {
808 Trait::FromRequest => {
809 quote_spanned! {path_span=>
810 #[::axum::async_trait]
811 #[automatically_derived]
812 impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics
813 where
814 #via_path<#via_type_generics>: ::axum::extract::FromRequest<#trait_generics>,
815 #rejection_bound
816 #state_bounds
817 {
818 type Rejection = #associated_rejection_type;
819
820 async fn from_request(
821 req: ::axum::http::Request<::axum::body::Body>,
822 state: &#state,
823 ) -> ::std::result::Result<Self, Self::Rejection> {
824 ::axum::extract::FromRequest::from_request(req, state)
825 .await
826 .map(|#via_path(value)| #value_to_self)
827 .map_err(#map_err)
828 }
829 }
830 }
831 }
832 Trait::FromRequestParts => {
833 quote_spanned! {path_span=>
834 #[::axum::async_trait]
835 #[automatically_derived]
836 impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics
837 where
838 #via_path<#via_type_generics>: ::axum::extract::FromRequestParts<#trait_generics>,
839 #rejection_bound
840 #state_bounds
841 {
842 type Rejection = #associated_rejection_type;
843
844 async fn from_request_parts(
845 parts: &mut ::axum::http::request::Parts,
846 state: &#state,
847 ) -> ::std::result::Result<Self, Self::Rejection> {
848 ::axum::extract::FromRequestParts::from_request_parts(parts, state)
849 .await
850 .map(|#via_path(value)| #value_to_self)
851 .map_err(#map_err)
852 }
853 }
854 }
855 }
856 };
857
858 Ok(tokens)
859}
860
861fn impl_enum_by_extracting_all_at_once(
862 ident: syn::Ident,
863 variants: Punctuated<syn::Variant, Token![,]>,
864 path: syn::Path,
865 rejection: Option<syn::Path>,
866 state: State,
867 tr: Trait,
868) -> syn::Result<TokenStream> {
869 for variant in variants {
870 let FromRequestFieldAttrs { via } = parse_attrs("from_request", &variant.attrs)?;
871
872 if let Some((via, _)) = via {
873 return Err(syn::Error::new_spanned(
874 via,
875 "`#[from_request(via(...))]` cannot be used on variants",
876 ));
877 }
878
879 let fields = match variant.fields {
880 syn::Fields::Named(fields) => fields.named.into_iter(),
881 syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(),
882 syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(),
883 };
884
885 for field in fields {
886 let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
887 if let Some((via, _)) = via {
888 return Err(syn::Error::new_spanned(
889 via,
890 "`#[from_request(via(...))]` cannot be used inside variants",
891 ));
892 }
893 }
894 }
895
896 let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
897 let rejection = quote! { #rejection };
898 let map_err = quote! { ::std::convert::From::from };
899 (rejection, map_err)
900 } else {
901 let rejection = quote! {
902 ::axum::response::Response
903 };
904 let map_err = quote! { ::axum::response::IntoResponse::into_response };
905 (rejection, map_err)
906 };
907
908 let path_span = path.span();
909
910 let impl_generics = state
911 .impl_generics()
912 .collect::<Punctuated<Type, Token![,]>>();
913
914 let trait_generics = state
915 .trait_generics()
916 .collect::<Punctuated<Type, Token![,]>>();
917
918 let state_bounds = state.bounds();
919
920 let tokens = match tr {
921 Trait::FromRequest => {
922 quote_spanned! {path_span=>
923 #[::axum::async_trait]
924 #[automatically_derived]
925 impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident
926 where
927 #state_bounds
928 {
929 type Rejection = #associated_rejection_type;
930
931 async fn from_request(
932 req: ::axum::http::Request<::axum::body::Body>,
933 state: &#state,
934 ) -> ::std::result::Result<Self, Self::Rejection> {
935 ::axum::extract::FromRequest::from_request(req, state)
936 .await
937 .map(|#path(inner)| inner)
938 .map_err(#map_err)
939 }
940 }
941 }
942 }
943 Trait::FromRequestParts => {
944 quote_spanned! {path_span=>
945 #[::axum::async_trait]
946 #[automatically_derived]
947 impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident
948 where
949 #state_bounds
950 {
951 type Rejection = #associated_rejection_type;
952
953 async fn from_request_parts(
954 parts: &mut ::axum::http::request::Parts,
955 state: &#state,
956 ) -> ::std::result::Result<Self, Self::Rejection> {
957 ::axum::extract::FromRequestParts::from_request_parts(parts, state)
958 .await
959 .map(|#path(inner)| inner)
960 .map_err(#map_err)
961 }
962 }
963 }
964 }
965 };
966
967 Ok(tokens)
968}
969
970fn infer_state_type_from_field_types(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
980 match fields {
981 Fields::Named(fields_named) => Box::new(crate::infer_state_types(
982 fields_named.named.iter().map(|field| &field.ty),
983 )) as Box<dyn Iterator<Item = Type>>,
984 Fields::Unnamed(fields_unnamed) => Box::new(crate::infer_state_types(
985 fields_unnamed.unnamed.iter().map(|field| &field.ty),
986 )),
987 Fields::Unit => Box::new(iter::empty()),
988 }
989}
990
991fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
1003 match fields {
1004 Fields::Named(fields_named) => {
1005 Box::new(fields_named.named.iter().filter_map(|field| {
1006 let FromRequestFieldAttrs { via } =
1009 parse_attrs("from_request", &field.attrs).ok()?;
1010 let (_, via_path) = via?;
1011 path_ident_is_state(&via_path).then(|| field.ty.clone())
1012 })) as Box<dyn Iterator<Item = Type>>
1013 }
1014 Fields::Unnamed(fields_unnamed) => {
1015 Box::new(fields_unnamed.unnamed.iter().filter_map(|field| {
1016 let FromRequestFieldAttrs { via } =
1019 parse_attrs("from_request", &field.attrs).ok()?;
1020 let (_, via_path) = via?;
1021 path_ident_is_state(&via_path).then(|| field.ty.clone())
1022 }))
1023 }
1024 Fields::Unit => Box::new(iter::empty()),
1025 }
1026}
1027
1028fn path_ident_is_state(path: &Path) -> bool {
1029 if let Some(last_segment) = path.segments.last() {
1030 last_segment.ident == "State"
1031 } else {
1032 false
1033 }
1034}
1035
1036fn state_from_via(ident: &Ident, via: &Path) -> Option<Type> {
1037 path_ident_is_state(via).then(|| parse_quote!(#ident))
1038}
1039
// Runs this derive's UI test suite through `crate::run_ui_tests`. The
// `"from_request"` argument presumably selects the matching directory of test
// cases — confirm against `run_ui_tests`.
#[test]
fn ui() {
    crate::run_ui_tests("from_request");
}
1044
// NOTE(review): empty function that nothing in this module calls;
// `#[allow(dead_code)]` silences the unused-function lint. Judging by the name,
// it presumably exists only to anchor documentation or a doc-test about a field
// type that doesn't implement the extractor traits — confirm against history.
#[allow(dead_code)]
fn test_field_doesnt_impl_from_request() {}