axum_macros/
debug_handler.rs

1use std::{collections::HashSet, fmt};
2
3use crate::{
4    attr_parsing::{parse_assignment_attribute, second},
5    with_position::{Position, WithPosition},
6};
7use proc_macro2::{Ident, Span, TokenStream};
8use quote::{format_ident, quote, quote_spanned};
9use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type};
10
11pub(crate) fn expand(attr: Attrs, item_fn: ItemFn, kind: FunctionKind) -> TokenStream {
12    let Attrs { state_ty } = attr;
13
14    let mut state_ty = state_ty.map(second);
15
16    let check_extractor_count = check_extractor_count(&item_fn, kind);
17    let check_path_extractor = check_path_extractor(&item_fn, kind);
18    let check_output_tuples = check_output_tuples(&item_fn);
19    let check_output_impls_into_response = if check_output_tuples.is_empty() {
20        check_output_impls_into_response(&item_fn)
21    } else {
22        check_output_tuples
23    };
24
25    // If the function is generic, we can't reliably check its inputs or whether the future it
26    // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
27    let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
28        let mut err = None;
29
30        if state_ty.is_none() {
31            let state_types_from_args = state_types_from_args(&item_fn);
32
33            #[allow(clippy::comparison_chain)]
34            if state_types_from_args.len() == 1 {
35                state_ty = state_types_from_args.into_iter().next();
36            } else if state_types_from_args.len() > 1 {
37                err = Some(
38                    syn::Error::new(
39                        Span::call_site(),
40                        format!(
41                            "can't infer state type, please add set it explicitly, as in \
42                            `#[axum_macros::debug_{kind}(state = MyStateType)]`"
43                        ),
44                    )
45                    .into_compile_error(),
46                );
47            }
48        }
49
50        err.unwrap_or_else(|| {
51            let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(()));
52
53            let check_future_send = check_future_send(&item_fn, kind);
54
55            if let Some(check_input_order) = check_input_order(&item_fn, kind) {
56                quote! {
57                    #check_input_order
58                    #check_future_send
59                }
60            } else {
61                let check_inputs_impls_from_request =
62                    check_inputs_impls_from_request(&item_fn, state_ty, kind);
63
64                quote! {
65                    #check_inputs_impls_from_request
66                    #check_future_send
67                }
68            }
69        })
70    } else {
71        syn::Error::new_spanned(
72            &item_fn.sig.generics,
73            format!("`#[axum_macros::debug_{kind}]` doesn't support generic functions"),
74        )
75        .into_compile_error()
76    };
77
78    let middleware_takes_next_as_last_arg =
79        matches!(kind, FunctionKind::Middleware).then(|| next_is_last_input(&item_fn));
80
81    quote! {
82        #item_fn
83        #check_extractor_count
84        #check_path_extractor
85        #check_output_impls_into_response
86        #check_inputs_and_future_send
87        #middleware_takes_next_as_last_arg
88    }
89}
90
/// Which attribute is being expanded. It controls the wording of error messages and whether the
/// `Next`-argument checks apply.
#[derive(Clone, Copy)]
pub(crate) enum FunctionKind {
    /// `#[debug_handler]`
    Handler,
    /// `#[debug_middleware]`
    Middleware,
}

impl fmt::Display for FunctionKind {
    // Lower-case singular name, interpolated into messages as `debug_{kind}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::Handler => "handler",
            Self::Middleware => "middleware",
        };
        f.write_str(name)
    }
}

impl FunctionKind {
    /// Capitalised plural noun used to start error messages (e.g. "Handlers must ...").
    fn name_uppercase_plural(&self) -> &'static str {
        match *self {
            Self::Handler => "Handlers",
            Self::Middleware => "Middleware",
        }
    }
}
114
115mod kw {
116    syn::custom_keyword!(body);
117    syn::custom_keyword!(state);
118}
119
120pub(crate) struct Attrs {
121    state_ty: Option<(kw::state, Type)>,
122}
123
124impl Parse for Attrs {
125    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
126        let mut state_ty = None;
127
128        while !input.is_empty() {
129            let lh = input.lookahead1();
130            if lh.peek(kw::state) {
131                parse_assignment_attribute(input, &mut state_ty)?;
132            } else {
133                return Err(lh.error());
134            }
135
136            let _ = input.parse::<Token![,]>();
137        }
138
139        Ok(Self { state_ty })
140    }
141}
142
143fn check_extractor_count(item_fn: &ItemFn, kind: FunctionKind) -> Option<TokenStream> {
144    let max_extractors = 16;
145    let inputs = item_fn
146        .sig
147        .inputs
148        .iter()
149        .filter(|arg| skip_next_arg(arg, kind))
150        .count();
151    if inputs <= max_extractors {
152        None
153    } else {
154        let error_message = format!(
155            "{} cannot take more than {max_extractors} arguments. \
156            Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors",
157            kind.name_uppercase_plural(),
158        );
159        let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error();
160        Some(error)
161    }
162}
163
164fn extractor_idents(
165    item_fn: &ItemFn,
166    kind: FunctionKind,
167) -> impl Iterator<Item = (usize, &syn::FnArg, &syn::Ident)> {
168    item_fn
169        .sig
170        .inputs
171        .iter()
172        .filter(move |arg| skip_next_arg(arg, kind))
173        .enumerate()
174        .filter_map(|(idx, fn_arg)| match fn_arg {
175            FnArg::Receiver(_) => None,
176            FnArg::Typed(pat_type) => {
177                if let Type::Path(type_path) = &*pat_type.ty {
178                    type_path
179                        .path
180                        .segments
181                        .last()
182                        .map(|segment| (idx, fn_arg, &segment.ident))
183                } else {
184                    None
185                }
186            }
187        })
188}
189
190fn check_path_extractor(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream {
191    let path_extractors = extractor_idents(item_fn, kind)
192        .filter(|(_, _, ident)| *ident == "Path")
193        .collect::<Vec<_>>();
194
195    if path_extractors.len() > 1 {
196        path_extractors
197            .into_iter()
198            .map(|(_, arg, _)| {
199                syn::Error::new_spanned(
200                    arg,
201                    "Multiple parameters must be extracted with a tuple \
202                    `Path<(_, _)>` or a struct `Path<YourParams>`, not by applying \
203                    multiple `Path<_>` extractors",
204                )
205                .to_compile_error()
206            })
207            .collect()
208    } else {
209        quote! {}
210    }
211}
212
213fn is_self_pat_type(typed: &syn::PatType) -> bool {
214    let ident = if let syn::Pat::Ident(ident) = &*typed.pat {
215        &ident.ident
216    } else {
217        return false;
218    };
219
220    ident == "self"
221}
222
223fn check_inputs_impls_from_request(
224    item_fn: &ItemFn,
225    state_ty: Type,
226    kind: FunctionKind,
227) -> TokenStream {
228    let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg {
229        FnArg::Receiver(_) => true,
230        FnArg::Typed(typed) => is_self_pat_type(typed),
231    });
232
233    WithPosition::new(
234        item_fn
235            .sig
236            .inputs
237            .iter()
238            .filter(|arg| skip_next_arg(arg, kind)),
239    )
240    .enumerate()
241    .map(|(idx, arg)| {
242        let must_impl_from_request_parts = match &arg {
243            Position::First(_) | Position::Middle(_) => true,
244            Position::Last(_) | Position::Only(_) => false,
245        };
246
247        let arg = arg.into_inner();
248
249        let (span, ty) = match arg {
250            FnArg::Receiver(receiver) => {
251                if receiver.reference.is_some() {
252                    return syn::Error::new_spanned(
253                        receiver,
254                        "Handlers must only take owned values",
255                    )
256                    .into_compile_error();
257                }
258
259                let span = receiver.span();
260                (span, syn::parse_quote!(Self))
261            }
262            FnArg::Typed(typed) => {
263                let ty = &typed.ty;
264                let span = ty.span();
265
266                if is_self_pat_type(typed) {
267                    (span, syn::parse_quote!(Self))
268                } else {
269                    (span, ty.clone())
270                }
271            }
272        };
273
274        let consumes_request = request_consuming_type_name(&ty).is_some();
275
276        let check_fn = format_ident!(
277            "__axum_macros_check_{}_{}_from_request_check",
278            item_fn.sig.ident,
279            idx,
280            span = span,
281        );
282
283        let call_check_fn = format_ident!(
284            "__axum_macros_check_{}_{}_from_request_call_check",
285            item_fn.sig.ident,
286            idx,
287            span = span,
288        );
289
290        let call_check_fn_body = if takes_self {
291            quote_spanned! {span=>
292                Self::#check_fn();
293            }
294        } else {
295            quote_spanned! {span=>
296                #check_fn();
297            }
298        };
299
300        let check_fn_generics = if must_impl_from_request_parts || consumes_request {
301            quote! {}
302        } else {
303            quote! { <M> }
304        };
305
306        let from_request_bound = if must_impl_from_request_parts {
307            quote_spanned! {span=>
308                #ty: ::axum::extract::FromRequestParts<#state_ty> + Send
309            }
310        } else if consumes_request {
311            quote_spanned! {span=>
312                #ty: ::axum::extract::FromRequest<#state_ty> + Send
313            }
314        } else {
315            quote_spanned! {span=>
316                #ty: ::axum::extract::FromRequest<#state_ty, M> + Send
317            }
318        };
319
320        quote_spanned! {span=>
321            #[allow(warnings)]
322            #[doc(hidden)]
323            fn #check_fn #check_fn_generics()
324            where
325                #from_request_bound,
326            {}
327
328            // we have to call the function to actually trigger a compile error
329            // since the function is generic, just defining it is not enough
330            #[allow(warnings)]
331            #[doc(hidden)]
332            fn #call_check_fn()
333            {
334                #call_check_fn_body
335            }
336        }
337    })
338    .collect::<TokenStream>()
339}
340
341fn check_output_tuples(item_fn: &ItemFn) -> TokenStream {
342    let elems = match &item_fn.sig.output {
343        ReturnType::Type(_, ty) => match &**ty {
344            Type::Tuple(tuple) => &tuple.elems,
345            _ => return quote! {},
346        },
347        ReturnType::Default => return quote! {},
348    };
349
350    let handler_ident = &item_fn.sig.ident;
351
352    match elems.len() {
353        0 => quote! {},
354        n if n > 17 => syn::Error::new_spanned(
355            &item_fn.sig.output,
356            "Cannot return tuples with more than 17 elements",
357        )
358        .to_compile_error(),
359        _ => WithPosition::new(elems)
360            .enumerate()
361            .map(|(idx, arg)| match arg {
362                Position::First(ty) => match extract_clean_typename(ty).as_deref() {
363                    Some("StatusCode" | "Response") => quote! {},
364                    Some("Parts") => check_is_response_parts(ty, handler_ident, idx),
365                    Some(_) | None => {
366                        if let Some(tn) = well_known_last_response_type(ty) {
367                            syn::Error::new_spanned(
368                                ty,
369                                format!(
370                                    "`{tn}` must be the last element \
371                                    in a response tuple"
372                                ),
373                            )
374                            .to_compile_error()
375                        } else {
376                            check_into_response_parts(ty, handler_ident, idx)
377                        }
378                    }
379                },
380                Position::Middle(ty) => {
381                    if let Some(tn) = well_known_last_response_type(ty) {
382                        syn::Error::new_spanned(
383                            ty,
384                            format!("`{tn}` must be the last element in a response tuple"),
385                        )
386                        .to_compile_error()
387                    } else {
388                        check_into_response_parts(ty, handler_ident, idx)
389                    }
390                }
391                Position::Last(ty) | Position::Only(ty) => check_into_response(handler_ident, ty),
392            })
393            .collect::<TokenStream>(),
394    }
395}
396
397fn check_into_response(handler: &Ident, ty: &Type) -> TokenStream {
398    let (span, ty) = (ty.span(), ty.clone());
399
400    let check_fn = format_ident!(
401        "__axum_macros_check_{handler}_into_response_check",
402        span = span,
403    );
404
405    let call_check_fn = format_ident!(
406        "__axum_macros_check_{handler}_into_response_call_check",
407        span = span,
408    );
409
410    let call_check_fn_body = quote_spanned! {span=>
411        #check_fn();
412    };
413
414    let from_request_bound = quote_spanned! {span=>
415        #ty: ::axum::response::IntoResponse
416    };
417    quote_spanned! {span=>
418        #[allow(warnings)]
419        #[allow(unreachable_code)]
420        #[doc(hidden)]
421        fn #check_fn()
422        where
423            #from_request_bound,
424        {}
425
426        // we have to call the function to actually trigger a compile error
427        // since the function is generic, just defining it is not enough
428        #[allow(warnings)]
429        #[allow(unreachable_code)]
430        #[doc(hidden)]
431        fn #call_check_fn() {
432            #call_check_fn_body
433        }
434    }
435}
436
437fn check_is_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStream {
438    let (span, ty) = (ty.span(), ty.clone());
439
440    let check_fn = format_ident!(
441        "__axum_macros_check_{}_is_response_parts_{index}_check",
442        ident,
443        span = span,
444    );
445
446    quote_spanned! {span=>
447        #[allow(warnings)]
448        #[allow(unreachable_code)]
449        #[doc(hidden)]
450        fn #check_fn(parts: #ty) -> ::axum::http::response::Parts {
451            parts
452        }
453    }
454}
455
456fn check_into_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStream {
457    let (span, ty) = (ty.span(), ty.clone());
458
459    let check_fn = format_ident!(
460        "__axum_macros_check_{}_into_response_parts_{index}_check",
461        ident,
462        span = span,
463    );
464
465    let call_check_fn = format_ident!(
466        "__axum_macros_check_{}_into_response_parts_{index}_call_check",
467        ident,
468        span = span,
469    );
470
471    let call_check_fn_body = quote_spanned! {span=>
472        #check_fn();
473    };
474
475    let from_request_bound = quote_spanned! {span=>
476        #ty: ::axum::response::IntoResponseParts
477    };
478    quote_spanned! {span=>
479        #[allow(warnings)]
480        #[allow(unreachable_code)]
481        #[doc(hidden)]
482        fn #check_fn()
483        where
484            #from_request_bound,
485        {}
486
487        // we have to call the function to actually trigger a compile error
488        // since the function is generic, just defining it is not enough
489        #[allow(warnings)]
490        #[allow(unreachable_code)]
491        #[doc(hidden)]
492        fn #call_check_fn() {
493            #call_check_fn_body
494        }
495    }
496}
497
498fn check_input_order(item_fn: &ItemFn, kind: FunctionKind) -> Option<TokenStream> {
499    let number_of_inputs = item_fn
500        .sig
501        .inputs
502        .iter()
503        .filter(|arg| skip_next_arg(arg, kind))
504        .count();
505
506    let types_that_consume_the_request = item_fn
507        .sig
508        .inputs
509        .iter()
510        .filter(|arg| skip_next_arg(arg, kind))
511        .enumerate()
512        .filter_map(|(idx, arg)| {
513            let ty = match arg {
514                FnArg::Typed(pat_type) => &*pat_type.ty,
515                FnArg::Receiver(_) => return None,
516            };
517            let type_name = request_consuming_type_name(ty)?;
518
519            Some((idx, type_name, ty.span()))
520        })
521        .collect::<Vec<_>>();
522
523    if types_that_consume_the_request.is_empty() {
524        return None;
525    };
526
527    // exactly one type that consumes the request
528    if types_that_consume_the_request.len() == 1 {
529        // and that is not the last
530        if types_that_consume_the_request[0].0 != number_of_inputs - 1 {
531            let (_idx, type_name, span) = &types_that_consume_the_request[0];
532            let error = format!(
533                "`{type_name}` consumes the request body and thus must be \
534                the last argument to the handler function"
535            );
536            return Some(quote_spanned! {*span=>
537                compile_error!(#error);
538            });
539        } else {
540            return None;
541        }
542    }
543
544    if types_that_consume_the_request.len() == 2 {
545        let (_, first, _) = &types_that_consume_the_request[0];
546        let (_, second, _) = &types_that_consume_the_request[1];
547        let error = format!(
548            "Can't have two extractors that consume the request body. \
549            `{first}` and `{second}` both do that.",
550        );
551        let span = item_fn.sig.inputs.span();
552        Some(quote_spanned! {span=>
553            compile_error!(#error);
554        })
555    } else {
556        let types = WithPosition::new(types_that_consume_the_request)
557            .map(|pos| match pos {
558                Position::First((_, type_name, _)) | Position::Middle((_, type_name, _)) => {
559                    format!("`{type_name}`, ")
560                }
561                Position::Last((_, type_name, _)) => format!("and `{type_name}`"),
562                Position::Only(_) => unreachable!(),
563            })
564            .collect::<String>();
565
566        let error = format!(
567            "Can't have more than one extractor that consume the request body. \
568            {types} all do that.",
569        );
570        let span = item_fn.sig.inputs.span();
571        Some(quote_spanned! {span=>
572            compile_error!(#error);
573        })
574    }
575}
576
577fn extract_clean_typename(ty: &Type) -> Option<String> {
578    let path = match ty {
579        Type::Path(type_path) => &type_path.path,
580        _ => return None,
581    };
582    path.segments.last().map(|p| p.ident.to_string())
583}
584
585fn request_consuming_type_name(ty: &Type) -> Option<&'static str> {
586    let typename = extract_clean_typename(ty)?;
587
588    let type_name = match &*typename {
589        "Json" => "Json<_>",
590        "RawBody" => "RawBody<_>",
591        "RawForm" => "RawForm",
592        "Multipart" => "Multipart",
593        "Protobuf" => "Protobuf",
594        "JsonLines" => "JsonLines<_>",
595        "Form" => "Form<_>",
596        "Request" => "Request<_>",
597        "Bytes" => "Bytes",
598        "String" => "String",
599        "Parts" => "Parts",
600        _ => return None,
601    };
602
603    Some(type_name)
604}
605
606fn well_known_last_response_type(ty: &Type) -> Option<&'static str> {
607    let typename = match extract_clean_typename(ty) {
608        Some(tn) => tn,
609        None => return None,
610    };
611
612    let type_name = match &*typename {
613        "Json" => "Json<_>",
614        "Protobuf" => "Protobuf",
615        "JsonLines" => "JsonLines<_>",
616        "Form" => "Form<_>",
617        "Bytes" => "Bytes",
618        "String" => "String",
619        _ => return None,
620    };
621
622    Some(type_name)
623}
624
625fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream {
626    let ty = match &item_fn.sig.output {
627        syn::ReturnType::Default => return quote! {},
628        syn::ReturnType::Type(_, ty) => ty,
629    };
630    let span = ty.span();
631
632    let declare_inputs = item_fn
633        .sig
634        .inputs
635        .iter()
636        .filter_map(|arg| match arg {
637            FnArg::Receiver(_) => None,
638            FnArg::Typed(pat_ty) => {
639                let pat = &pat_ty.pat;
640                let ty = &pat_ty.ty;
641                Some(quote! {
642                    let #pat: #ty = panic!();
643                })
644            }
645        })
646        .collect::<TokenStream>();
647
648    let block = &item_fn.block;
649
650    let make_value_name = format_ident!(
651        "__axum_macros_check_{}_into_response_make_value",
652        item_fn.sig.ident
653    );
654
655    let make = if item_fn.sig.asyncness.is_some() {
656        quote_spanned! {span=>
657            #[allow(warnings)]
658            #[allow(unreachable_code)]
659            #[doc(hidden)]
660            async fn #make_value_name() -> #ty {
661                #declare_inputs
662                #block
663            }
664        }
665    } else {
666        quote_spanned! {span=>
667            #[allow(warnings)]
668            #[allow(unreachable_code)]
669            #[doc(hidden)]
670            fn #make_value_name() -> #ty {
671                #declare_inputs
672                #block
673            }
674        }
675    };
676
677    let name = format_ident!("__axum_macros_check_{}_into_response", item_fn.sig.ident);
678
679    if let Some(receiver) = self_receiver(item_fn) {
680        quote_spanned! {span=>
681            #make
682
683            #[allow(warnings)]
684            #[allow(unreachable_code)]
685            #[doc(hidden)]
686            async fn #name() {
687                let value = #receiver #make_value_name().await;
688                fn check<T>(_: T)
689                    where T: ::axum::response::IntoResponse
690                {}
691                check(value);
692            }
693        }
694    } else {
695        quote_spanned! {span=>
696            #[allow(warnings)]
697            #[allow(unreachable_code)]
698            #[doc(hidden)]
699            async fn #name() {
700                #make
701
702                let value = #make_value_name().await;
703
704                fn check<T>(_: T)
705                where T: ::axum::response::IntoResponse
706                {}
707
708                check(value);
709            }
710        }
711    }
712}
713
714fn check_future_send(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream {
715    if item_fn.sig.asyncness.is_none() {
716        match &item_fn.sig.output {
717            syn::ReturnType::Default => {
718                return syn::Error::new_spanned(
719                    item_fn.sig.fn_token,
720                    format!("{} must be `async fn`s", kind.name_uppercase_plural()),
721                )
722                .into_compile_error();
723            }
724            syn::ReturnType::Type(_, ty) => ty,
725        };
726    }
727
728    let span = item_fn.sig.ident.span();
729
730    let handler_name = &item_fn.sig.ident;
731
732    let args = item_fn.sig.inputs.iter().map(|_| {
733        quote_spanned! {span=> panic!() }
734    });
735
736    let name = format_ident!("__axum_macros_check_{}_future", item_fn.sig.ident);
737
738    let do_check = quote! {
739        fn check<T>(_: T)
740            where T: ::std::future::Future + Send
741        {}
742        check(future);
743    };
744
745    if let Some(receiver) = self_receiver(item_fn) {
746        quote! {
747            #[allow(warnings)]
748            #[allow(unreachable_code)]
749            #[doc(hidden)]
750            fn #name() {
751                let future = #receiver #handler_name(#(#args),*);
752                #do_check
753            }
754        }
755    } else {
756        quote! {
757            #[allow(warnings)]
758            #[allow(unreachable_code)]
759            #[doc(hidden)]
760            fn #name() {
761                #item_fn
762                let future = #handler_name(#(#args),*);
763                #do_check
764            }
765        }
766    }
767}
768
769fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
770    let takes_self = item_fn.sig.inputs.iter().any(|arg| match arg {
771        FnArg::Receiver(_) => true,
772        FnArg::Typed(typed) => is_self_pat_type(typed),
773    });
774
775    if takes_self {
776        return Some(quote! { Self:: });
777    }
778
779    if let syn::ReturnType::Type(_, ty) = &item_fn.sig.output {
780        if let syn::Type::Path(path) = &**ty {
781            let segments = &path.path.segments;
782            if segments.len() == 1 {
783                if let Some(last) = segments.last() {
784                    match &last.arguments {
785                        syn::PathArguments::None if last.ident == "Self" => {
786                            return Some(quote! { Self:: });
787                        }
788                        _ => {}
789                    }
790                }
791            }
792        }
793    }
794
795    None
796}
797
798/// Given a signature like
799///
800/// ```skip
801/// #[debug_handler]
802/// async fn handler(
803///     _: axum::extract::State<AppState>,
804///     _: State<AppState>,
805/// ) {}
806/// ```
807///
808/// This will extract `AppState`.
809///
810/// Returns `None` if there are no `State` args or multiple of different types.
811fn state_types_from_args(item_fn: &ItemFn) -> HashSet<Type> {
812    let types = item_fn
813        .sig
814        .inputs
815        .iter()
816        .filter_map(|input| match input {
817            FnArg::Receiver(_) => None,
818            FnArg::Typed(pat_type) => Some(pat_type),
819        })
820        .map(|pat_type| &*pat_type.ty);
821    crate::infer_state_types(types).collect()
822}
823
824fn next_is_last_input(item_fn: &ItemFn) -> TokenStream {
825    let next_args = item_fn
826        .sig
827        .inputs
828        .iter()
829        .enumerate()
830        .filter(|(_, arg)| !skip_next_arg(arg, FunctionKind::Middleware))
831        .collect::<Vec<_>>();
832
833    if next_args.is_empty() {
834        return quote! {
835            compile_error!(
836                "Middleware functions must take `axum::middleware::Next` as the last argument",
837            );
838        };
839    }
840
841    if next_args.len() == 1 {
842        let (idx, arg) = &next_args[0];
843        if *idx != item_fn.sig.inputs.len() - 1 {
844            return quote_spanned! {arg.span()=>
845                compile_error!("`axum::middleware::Next` must the last argument");
846            };
847        }
848    }
849
850    if next_args.len() >= 2 {
851        return quote! {
852            compile_error!(
853                "Middleware functions can only take one argument of type `axum::middleware::Next`",
854            );
855        };
856    }
857
858    quote! {}
859}
860
861fn skip_next_arg(arg: &FnArg, kind: FunctionKind) -> bool {
862    match kind {
863        FunctionKind::Handler => true,
864        FunctionKind::Middleware => match arg {
865            FnArg::Receiver(_) => true,
866            FnArg::Typed(pat_type) => {
867                if let Type::Path(type_path) = &*pat_type.ty {
868                    type_path
869                        .path
870                        .segments
871                        .last()
872                        .map_or(true, |path_segment| path_segment.ident != "Next")
873                } else {
874                    true
875                }
876            }
877        },
878    }
879}
880
/// Runs the `debug_handler` UI test suite through `crate::run_ui_tests`.
#[test]
fn ui_debug_handler() {
    let suite = "debug_handler";
    crate::run_ui_tests(suite);
}

/// Runs the `debug_middleware` UI test suite through `crate::run_ui_tests`.
#[test]
fn ui_debug_middleware() {
    let suite = "debug_middleware";
    crate::run_ui_tests(suite);
}