axum_macros/lib.rs
1//! Macros for [`axum`].
2//!
3//! [`axum`]: https://crates.io/crates/axum
4
5#![warn(
6 clippy::all,
7 clippy::dbg_macro,
8 clippy::todo,
9 clippy::empty_enum,
10 clippy::enum_glob_use,
11 clippy::mem_forget,
12 clippy::unused_self,
13 clippy::filter_map_next,
14 clippy::needless_continue,
15 clippy::needless_borrow,
16 clippy::match_wildcard_for_single_variants,
17 clippy::if_let_mutex,
18 clippy::await_holding_lock,
19 clippy::match_on_vec_items,
20 clippy::imprecise_flops,
21 clippy::suboptimal_flops,
22 clippy::lossy_float_literal,
23 clippy::rest_pat_in_fully_bound_structs,
24 clippy::fn_params_excessive_bools,
25 clippy::exit,
26 clippy::inefficient_to_string,
27 clippy::linkedlist,
28 clippy::macro_use_imports,
29 clippy::option_option,
30 clippy::verbose_file_reads,
31 clippy::unnested_or_patterns,
32 clippy::str_to_string,
33 rust_2018_idioms,
34 future_incompatible,
35 nonstandard_style,
36 missing_debug_implementations,
37 missing_docs
38)]
39#![deny(unreachable_pub)]
40#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
41#![forbid(unsafe_code)]
42#![cfg_attr(docsrs, feature(doc_cfg))]
43#![cfg_attr(test, allow(clippy::float_cmp))]
44#![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))]
45
46use debug_handler::FunctionKind;
47use proc_macro::TokenStream;
48use quote::{quote, ToTokens};
49use syn::{parse::Parse, Type};
50
51mod attr_parsing;
52#[cfg(feature = "__private")]
53mod axum_test;
54mod debug_handler;
55mod from_ref;
56mod from_request;
57mod typed_path;
58mod with_position;
59
60use from_request::Trait::{FromRequest, FromRequestParts};
61
62/// Derive an implementation of [`FromRequest`].
63///
64/// Supports generating two kinds of implementations:
65/// 1. One that extracts each field individually.
66/// 2. Another that extracts the whole type at once via another extractor.
67///
68/// # Each field individually
69///
70/// By default `#[derive(FromRequest)]` will call `FromRequest::from_request` for each field:
71///
72/// ```
73/// use axum_macros::FromRequest;
74/// use axum::{
75/// extract::Extension,
76/// body::Bytes,
77/// };
78/// use axum_extra::{
79/// TypedHeader,
80/// headers::ContentType,
81/// };
82///
83/// #[derive(FromRequest)]
84/// struct MyExtractor {
85/// state: Extension<State>,
86/// content_type: TypedHeader<ContentType>,
87/// request_body: Bytes,
88/// }
89///
90/// #[derive(Clone)]
91/// struct State {
92/// // ...
93/// }
94///
95/// async fn handler(extractor: MyExtractor) {}
96/// ```
97///
98/// This requires that each field is an extractor (i.e. implements [`FromRequest`]).
99///
100/// Note that only the last field can consume the request body. Therefore this doesn't compile:
101///
102/// ```compile_fail
103/// use axum_macros::FromRequest;
104/// use axum::body::Bytes;
105///
106/// #[derive(FromRequest)]
107/// struct MyExtractor {
108/// // only the last field can implement `FromRequest`
109/// // other fields must only implement `FromRequestParts`
110/// bytes: Bytes,
111/// string: String,
112/// }
113/// ```
114///
115/// ## Extracting via another extractor
116///
117/// You can use `#[from_request(via(...))]` to extract a field via another extractor, meaning the
118/// field itself doesn't need to implement `FromRequest`:
119///
120/// ```
121/// use axum_macros::FromRequest;
122/// use axum::{
123/// extract::Extension,
124/// body::Bytes,
125/// };
126/// use axum_extra::{
127/// TypedHeader,
128/// headers::ContentType,
129/// };
130///
131/// #[derive(FromRequest)]
132/// struct MyExtractor {
///     // This will be extracted via `Extension::<State>::from_request`
134/// #[from_request(via(Extension))]
135/// state: State,
136/// // and this via `TypedHeader::<ContentType>::from_request`
137/// #[from_request(via(TypedHeader))]
138/// content_type: ContentType,
139/// // Can still be combined with other extractors
140/// request_body: Bytes,
141/// }
142///
143/// #[derive(Clone)]
144/// struct State {
145/// // ...
146/// }
147///
148/// async fn handler(extractor: MyExtractor) {}
149/// ```
150///
151/// Note this requires the via extractor to be a generic newtype struct (a tuple struct with
152/// exactly one public field) that implements `FromRequest`:
153///
154/// ```
155/// pub struct ViaExtractor<T>(pub T);
156///
157/// // impl<T, S> FromRequest<S> for ViaExtractor<T> { ... }
158/// ```
159///
160/// More complex via extractors are not supported and require writing a manual implementation.
161///
162/// ## Optional fields
163///
164/// `#[from_request(via(...))]` supports `Option<_>` and `Result<_, _>` to make fields optional:
165///
166/// ```
167/// use axum_macros::FromRequest;
168/// use axum_extra::{
169/// TypedHeader,
170/// headers::{ContentType, UserAgent},
171/// typed_header::TypedHeaderRejection,
172/// };
173///
174/// #[derive(FromRequest)]
175/// struct MyExtractor {
///     // This will be extracted via `Option::<TypedHeader<ContentType>>::from_request`
177/// #[from_request(via(TypedHeader))]
178/// content_type: Option<ContentType>,
///     // This will be extracted via
180/// // `Result::<TypedHeader<UserAgent>, TypedHeaderRejection>::from_request`
181/// #[from_request(via(TypedHeader))]
182/// user_agent: Result<UserAgent, TypedHeaderRejection>,
183/// }
184///
185/// async fn handler(extractor: MyExtractor) {}
186/// ```
187///
188/// ## The rejection
189///
190/// By default [`axum::response::Response`] will be used as the rejection. You can also use your own
191/// rejection type with `#[from_request(rejection(YourType))]`:
192///
193/// ```
194/// use axum::{
195/// extract::{
196/// rejection::{ExtensionRejection, StringRejection},
197/// FromRequest,
198/// },
199/// Extension,
200/// response::{Response, IntoResponse},
201/// };
202///
203/// #[derive(FromRequest)]
204/// #[from_request(rejection(MyRejection))]
205/// struct MyExtractor {
206/// state: Extension<String>,
207/// body: String,
208/// }
209///
210/// struct MyRejection(Response);
211///
212/// // This tells axum how to convert `Extension`'s rejections into `MyRejection`
213/// impl From<ExtensionRejection> for MyRejection {
214/// fn from(rejection: ExtensionRejection) -> Self {
215/// // ...
216/// # todo!()
217/// }
218/// }
219///
220/// // This tells axum how to convert `String`'s rejections into `MyRejection`
221/// impl From<StringRejection> for MyRejection {
222/// fn from(rejection: StringRejection) -> Self {
223/// // ...
224/// # todo!()
225/// }
226/// }
227///
228/// // All rejections must implement `IntoResponse`
229/// impl IntoResponse for MyRejection {
230/// fn into_response(self) -> Response {
231/// self.0
232/// }
233/// }
234/// ```
235///
236/// ## Concrete state
237///
238/// If the extraction can be done only for a concrete state, that type can be specified with
239/// `#[from_request(state(YourState))]`:
240///
241/// ```
242/// use axum::extract::{FromRequest, FromRequestParts};
243///
244/// #[derive(Clone)]
245/// struct CustomState;
246///
247/// struct MyInnerType;
248///
249/// #[axum::async_trait]
250/// impl FromRequestParts<CustomState> for MyInnerType {
251/// // ...
252/// # type Rejection = ();
253///
254/// # async fn from_request_parts(
255/// # _parts: &mut axum::http::request::Parts,
256/// # _state: &CustomState
257/// # ) -> Result<Self, Self::Rejection> {
258/// # todo!()
259/// # }
260/// }
261///
262/// #[derive(FromRequest)]
263/// #[from_request(state(CustomState))]
264/// struct MyExtractor {
265/// custom: MyInnerType,
266/// body: String,
267/// }
268/// ```
269///
270/// This is not needed for a `State<T>` as the type is inferred in that case.
271///
272/// ```
273/// use axum::extract::{FromRequest, FromRequestParts, State};
274///
275/// #[derive(Clone)]
276/// struct CustomState;
277///
278/// #[derive(FromRequest)]
279/// struct MyExtractor {
280/// custom: State<CustomState>,
281/// body: String,
282/// }
283/// ```
284///
285/// # The whole type at once
286///
287/// By using `#[from_request(via(...))]` on the container you can extract the whole type at once,
288/// instead of each field individually:
289///
290/// ```
291/// use axum_macros::FromRequest;
292/// use axum::extract::Extension;
293///
/// // This will be extracted via `Extension::<State>::from_request`
295/// #[derive(Clone, FromRequest)]
296/// #[from_request(via(Extension))]
297/// struct State {
298/// // ...
299/// }
300///
301/// async fn handler(state: State) {}
302/// ```
303///
/// The rejection will be the via extractor's rejection. For the previous example that would be
305/// [`axum::extract::rejection::ExtensionRejection`].
306///
307/// You can use a different rejection type with `#[from_request(rejection(YourType))]`:
308///
309/// ```
310/// use axum_macros::FromRequest;
311/// use axum::{
312/// extract::{Extension, rejection::ExtensionRejection},
313/// response::{IntoResponse, Response},
314/// Json,
315/// http::StatusCode,
316/// };
317/// use serde_json::json;
318///
/// // This will be extracted via `Extension::<State>::from_request`
320/// #[derive(Clone, FromRequest)]
321/// #[from_request(
322/// via(Extension),
323/// // Use your own rejection type
324/// rejection(MyRejection),
325/// )]
326/// struct State {
327/// // ...
328/// }
329///
330/// struct MyRejection(Response);
331///
332/// // This tells axum how to convert `Extension`'s rejections into `MyRejection`
333/// impl From<ExtensionRejection> for MyRejection {
334/// fn from(rejection: ExtensionRejection) -> Self {
335/// let response = (
336/// StatusCode::INTERNAL_SERVER_ERROR,
337/// Json(json!({ "error": "Something went wrong..." })),
338/// ).into_response();
339///
340/// MyRejection(response)
341/// }
342/// }
343///
344/// // All rejections must implement `IntoResponse`
345/// impl IntoResponse for MyRejection {
346/// fn into_response(self) -> Response {
347/// self.0
348/// }
349/// }
350///
351/// async fn handler(state: State) {}
352/// ```
353///
354/// This allows you to wrap other extractors and easily customize the rejection:
355///
356/// ```
357/// use axum_macros::FromRequest;
358/// use axum::{
359/// extract::{Extension, rejection::JsonRejection},
360/// response::{IntoResponse, Response},
361/// http::StatusCode,
362/// };
363/// use serde_json::json;
364/// use serde::Deserialize;
365///
366/// // create an extractor that internally uses `axum::Json` but has a custom rejection
367/// #[derive(FromRequest)]
368/// #[from_request(via(axum::Json), rejection(MyRejection))]
369/// struct MyJson<T>(T);
370///
371/// struct MyRejection(Response);
372///
373/// impl From<JsonRejection> for MyRejection {
374/// fn from(rejection: JsonRejection) -> Self {
375/// let response = (
376/// StatusCode::INTERNAL_SERVER_ERROR,
377/// axum::Json(json!({ "error": rejection.to_string() })),
378/// ).into_response();
379///
380/// MyRejection(response)
381/// }
382/// }
383///
384/// impl IntoResponse for MyRejection {
385/// fn into_response(self) -> Response {
386/// self.0
387/// }
388/// }
389///
390/// #[derive(Deserialize)]
391/// struct Payload {}
392///
393/// async fn handler(
394/// // make sure to use `MyJson` and not `axum::Json`
395/// MyJson(payload): MyJson<Payload>,
396/// ) {}
397/// ```
398///
399/// # Known limitations
400///
/// Generics are only supported on tuple structs with exactly one field. Thus this doesn't work:
402///
403/// ```compile_fail
404/// #[derive(axum_macros::FromRequest)]
405/// struct MyExtractor<T> {
406/// thing: Option<T>,
407/// }
408/// ```
409///
410/// [`FromRequest`]: https://docs.rs/axum/0.7/axum/extract/trait.FromRequest.html
411/// [`axum::response::Response`]: https://docs.rs/axum/0.7/axum/response/type.Response.html
412/// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/0.7/axum/extract/rejection/enum.ExtensionRejection.html
413#[proc_macro_derive(FromRequest, attributes(from_request))]
414pub fn derive_from_request(item: TokenStream) -> TokenStream {
415 expand_with(item, |item| from_request::expand(item, FromRequest))
416}
417
418/// Derive an implementation of [`FromRequestParts`].
419///
420/// This works similarly to `#[derive(FromRequest)]` except it uses [`FromRequestParts`]. All the
421/// same options are supported.
422///
423/// # Example
424///
425/// ```
426/// use axum_macros::FromRequestParts;
427/// use axum::{
428/// extract::Query,
429/// };
430/// use axum_extra::{
431/// TypedHeader,
432/// headers::ContentType,
433/// };
434/// use std::collections::HashMap;
435///
436/// #[derive(FromRequestParts)]
437/// struct MyExtractor {
438/// #[from_request(via(Query))]
439/// query_params: HashMap<String, String>,
440/// content_type: TypedHeader<ContentType>,
441/// }
442///
443/// async fn handler(extractor: MyExtractor) {}
444/// ```
445///
446/// # Cannot extract the body
447///
448/// [`FromRequestParts`] cannot extract the request body:
449///
450/// ```compile_fail
451/// use axum_macros::FromRequestParts;
452///
453/// #[derive(FromRequestParts)]
454/// struct MyExtractor {
455/// body: String,
456/// }
457/// ```
458///
459/// Use `#[derive(FromRequest)]` for that.
460///
461/// [`FromRequestParts`]: https://docs.rs/axum/0.7/axum/extract/trait.FromRequestParts.html
462#[proc_macro_derive(FromRequestParts, attributes(from_request))]
463pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
464 expand_with(item, |item| from_request::expand(item, FromRequestParts))
465}
466
467/// Generates better error messages when applied to handler functions.
468///
469/// While using [`axum`], you can get long error messages for simple mistakes. For example:
470///
471/// ```compile_fail
472/// use axum::{routing::get, Router};
473///
474/// #[tokio::main]
475/// async fn main() {
476/// let app = Router::new().route("/", get(handler));
477///
478/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
479/// axum::serve(listener, app).await.unwrap();
480/// }
481///
482/// fn handler() -> &'static str {
483/// "Hello, world"
484/// }
485/// ```
486///
/// You will get a long error message about the function not implementing the [`Handler`] trait.
/// But why does this function not implement it? To figure it out, the [`debug_handler`] macro can
/// be used.
489///
490/// ```compile_fail
491/// # use axum::{routing::get, Router};
492/// # use axum_macros::debug_handler;
493/// #
494/// # #[tokio::main]
495/// # async fn main() {
496/// # let app = Router::new().route("/", get(handler));
497/// #
498/// # let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
499/// # axum::serve(listener, app).await.unwrap();
500/// # }
501/// #
502/// #[debug_handler]
503/// fn handler() -> &'static str {
504/// "Hello, world"
505/// }
506/// ```
507///
508/// ```text
509/// error: handlers must be async functions
510/// --> main.rs:xx:1
511/// |
512/// xx | fn handler() -> &'static str {
513/// | ^^
514/// ```
515///
/// As the error message says, the handler function needs to be async.
517///
518/// ```no_run
519/// use axum::{routing::get, Router, debug_handler};
520///
521/// #[tokio::main]
522/// async fn main() {
523/// let app = Router::new().route("/", get(handler));
524///
525/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
526/// axum::serve(listener, app).await.unwrap();
527/// }
528///
529/// #[debug_handler]
530/// async fn handler() -> &'static str {
531/// "Hello, world"
532/// }
533/// ```
534///
535/// # Changing state type
536///
537/// By default `#[debug_handler]` assumes your state type is `()` unless your handler has a
538/// [`axum::extract::State`] argument:
539///
540/// ```
541/// use axum::{debug_handler, extract::State};
542///
543/// #[debug_handler]
544/// async fn handler(
545/// // this makes `#[debug_handler]` use `AppState`
546/// State(state): State<AppState>,
547/// ) {}
548///
549/// #[derive(Clone)]
550/// struct AppState {}
551/// ```
552///
553/// If your handler takes multiple [`axum::extract::State`] arguments or you need to otherwise
554/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
555///
556/// ```
557/// use axum::{debug_handler, extract::{State, FromRef}};
558///
559/// #[debug_handler(state = AppState)]
560/// async fn handler(
561/// State(app_state): State<AppState>,
562/// State(inner_state): State<InnerState>,
563/// ) {}
564///
565/// #[derive(Clone)]
566/// struct AppState {
567/// inner: InnerState,
568/// }
569///
570/// #[derive(Clone)]
571/// struct InnerState {}
572///
573/// impl FromRef<AppState> for InnerState {
574/// fn from_ref(state: &AppState) -> Self {
575/// state.inner.clone()
576/// }
577/// }
578/// ```
579///
580/// # Limitations
581///
582/// This macro does not work for functions in an `impl` block that don't have a `self` parameter:
583///
584/// ```compile_fail
585/// use axum::{debug_handler, extract::Path};
586///
587/// struct App {}
588///
589/// impl App {
590/// #[debug_handler]
591/// async fn handler(Path(_): Path<String>) {}
592/// }
593/// ```
594///
595/// This will yield an error similar to this:
596///
597/// ```text
598/// error[E0425]: cannot find function `__axum_macros_check_handler_0_from_request_check` in this scope
///   --> src/main.rs:xx:xx
///    |
/// xx |     async fn handler(Path(_): Path<String>) {}
///    |                      ^^^^ not found in this scope
603/// ```
604///
605/// # Performance
606///
/// This macro has no effect when `debug_assertions` are disabled, which is the default for the
/// release profile (e.g. `cargo build --release`).
608///
609/// [`axum`]: https://docs.rs/axum/0.7
610/// [`Handler`]: https://docs.rs/axum/0.7/axum/handler/trait.Handler.html
611/// [`axum::extract::State`]: https://docs.rs/axum/0.7/axum/extract/struct.State.html
612/// [`debug_handler`]: macro@debug_handler
613#[proc_macro_attribute]
614pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
615 #[cfg(not(debug_assertions))]
616 return input;
617
618 #[cfg(debug_assertions)]
619 return expand_attr_with(_attr, input, |attrs, item_fn| {
620 debug_handler::expand(attrs, item_fn, FunctionKind::Handler)
621 });
622}
623
624/// Generates better error messages when applied to middleware functions.
625///
626/// This works similarly to [`#[debug_handler]`](macro@debug_handler) except for middleware using
627/// [`axum::middleware::from_fn`].
628///
629/// # Example
630///
631/// ```no_run
632/// use axum::{
633/// routing::get,
634/// extract::Request,
635/// response::Response,
636/// Router,
637/// middleware::{self, Next},
638/// debug_middleware,
639/// };
640///
641/// #[tokio::main]
642/// async fn main() {
643/// let app = Router::new()
644/// .route("/", get(|| async {}))
645/// .layer(middleware::from_fn(my_middleware));
646///
647/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
648/// axum::serve(listener, app).await.unwrap();
649/// }
650///
651/// // if this wasn't a valid middleware function #[debug_middleware] would
652/// // improve compile error
653/// #[debug_middleware]
654/// async fn my_middleware(
655/// request: Request,
656/// next: Next,
657/// ) -> Response {
658/// next.run(request).await
659/// }
660/// ```
661///
662/// # Performance
663///
/// This macro has no effect when `debug_assertions` are disabled, which is the default for the
/// release profile (e.g. `cargo build --release`).
665///
/// [`axum`]: https://docs.rs/axum/0.7
667/// [`axum::middleware::from_fn`]: https://docs.rs/axum/0.7/axum/middleware/fn.from_fn.html
668/// [`debug_middleware`]: macro@debug_middleware
669#[proc_macro_attribute]
670pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream {
671 #[cfg(not(debug_assertions))]
672 return input;
673
674 #[cfg(debug_assertions)]
675 return expand_attr_with(_attr, input, |attrs, item_fn| {
676 debug_handler::expand(attrs, item_fn, FunctionKind::Middleware)
677 });
678}
679
/// Private API: Do not use this!
681///
682/// Attribute macro to be placed on test functions that'll generate two functions:
683///
684/// 1. One identical to the function it was placed on.
/// 2. One where calls to `Router::nest` have been replaced with `Router::nest_service`.
686///
/// This makes it easy to test that `nest` and `nest_service` behave the same way, without
/// having to manually write identical tests for both methods.
#[cfg(feature = "__private")]
#[proc_macro_attribute]
#[doc(hidden)]
pub fn __private_axum_test(_attr: TokenStream, input: TokenStream) -> TokenStream {
    // All of the work happens in the crate-private `axum_test` module.
    expand_attr_with(_attr, input, crate::axum_test::expand)
}
695
696/// Derive an implementation of [`axum_extra::routing::TypedPath`].
697///
698/// See that trait for more details.
699///
700/// [`axum_extra::routing::TypedPath`]: https://docs.rs/axum-extra/latest/axum_extra/routing/trait.TypedPath.html
701#[proc_macro_derive(TypedPath, attributes(typed_path))]
702pub fn derive_typed_path(input: TokenStream) -> TokenStream {
703 expand_with(input, typed_path::expand)
704}
705
706/// Derive an implementation of [`FromRef`] for each field in a struct.
707///
708/// # Example
709///
710/// ```
711/// use axum::{
712/// Router,
713/// routing::get,
714/// extract::{State, FromRef},
715/// };
716///
717/// #
718/// # type AuthToken = String;
719/// # type DatabasePool = ();
720/// #
721/// // This will implement `FromRef` for each field in the struct.
722/// #[derive(FromRef, Clone)]
723/// struct AppState {
724/// auth_token: AuthToken,
725/// database_pool: DatabasePool,
726/// // fields can also be skipped
727/// #[from_ref(skip)]
728/// api_token: String,
729/// }
730///
731/// // So those types can be extracted via `State`
732/// async fn handler(State(auth_token): State<AuthToken>) {}
733///
734/// async fn other_handler(State(database_pool): State<DatabasePool>) {}
735///
736/// # let auth_token = Default::default();
737/// # let database_pool = Default::default();
738/// let state = AppState {
739/// auth_token,
740/// database_pool,
741/// api_token: "secret".to_owned(),
742/// };
743///
744/// let app = Router::new()
745/// .route("/", get(handler).post(other_handler))
746/// .with_state(state);
747/// # let _: axum::Router = app;
748/// ```
749///
750/// [`FromRef`]: https://docs.rs/axum/0.7/axum/extract/trait.FromRef.html
751#[proc_macro_derive(FromRef, attributes(from_ref))]
752pub fn derive_from_ref(item: TokenStream) -> TokenStream {
753 expand_with(item, from_ref::expand)
754}
755
756fn expand_with<F, I, K>(input: TokenStream, f: F) -> TokenStream
757where
758 F: FnOnce(I) -> syn::Result<K>,
759 I: Parse,
760 K: ToTokens,
761{
762 expand(syn::parse(input).and_then(f))
763}
764
765fn expand_attr_with<F, A, I, K>(attr: TokenStream, input: TokenStream, f: F) -> TokenStream
766where
767 F: FnOnce(A, I) -> K,
768 A: Parse,
769 I: Parse,
770 K: ToTokens,
771{
772 let expand_result = (|| {
773 let attr = syn::parse(attr)?;
774 let input = syn::parse(input)?;
775 Ok(f(attr, input))
776 })();
777 expand(expand_result)
778}
779
780fn expand<T>(result: syn::Result<T>) -> TokenStream
781where
782 T: ToTokens,
783{
784 match result {
785 Ok(tokens) => {
786 let tokens = (quote! { #tokens }).into();
787 if std::env::var_os("AXUM_MACROS_DEBUG").is_some() {
788 eprintln!("{tokens}");
789 }
790 tokens
791 }
792 Err(err) => err.into_compile_error().into(),
793 }
794}
795
796fn infer_state_types<'a, I>(types: I) -> impl Iterator<Item = Type> + 'a
797where
798 I: Iterator<Item = &'a Type> + 'a,
799{
800 types
801 .filter_map(|ty| {
802 if let Type::Path(path) = ty {
803 Some(&path.path)
804 } else {
805 None
806 }
807 })
808 .filter_map(|path| {
809 if let Some(last_segment) = path.segments.last() {
810 if last_segment.ident != "State" {
811 return None;
812 }
813
814 match &last_segment.arguments {
815 syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
816 Some(args.args.first().unwrap())
817 }
818 _ => None,
819 }
820 } else {
821 None
822 }
823 })
824 .filter_map(|generic_arg| {
825 if let syn::GenericArgument::Type(ty) = generic_arg {
826 Some(ty)
827 } else {
828 None
829 }
830 })
831 .cloned()
832}
833
/// Runs the trybuild UI tests under `tests/{directory}/pass` and `tests/{directory}/fail`.
///
/// The tests only run on a nightly toolchain; on other toolchains this is a no-op (see the
/// `rustversion` attributes on the two `go` variants below). Setting `AXUM_TEST_ONLY` to the
/// path of a single test file restricts the run to that file.
///
/// # Panics
///
/// Panics if `AXUM_TEST_ONLY` names a file in `directory` that is not inside a `pass` or `fail`
/// subdirectory.
#[cfg(test)]
fn run_ui_tests(directory: &str) {
    #[rustversion::nightly]
    fn go(directory: &str) {
        let t = trybuild::TestCases::new();

        if let Ok(mut path) = std::env::var("AXUM_TEST_ONLY") {
            // Accept paths that start with `axum-macros/` (presumably given relative to the
            // workspace root) as well as paths relative to this crate.
            if let Some(path_without_prefix) = path.strip_prefix("axum-macros/") {
                path = path_without_prefix.to_owned();
            }

            // The selected file belongs to a different UI test directory; nothing to run here.
            if !path.contains(&format!("/{directory}/")) {
                return;
            }

            if path.contains("/fail/") {
                t.compile_fail(path);
            } else if path.contains("/pass/") {
                t.pass(path);
            } else {
                panic!(
                    "AXUM_TEST_ONLY must point into a `pass` or `fail` test directory, got `{path}`"
                );
            }
        } else {
            t.compile_fail(format!("tests/{directory}/fail/*.rs"));
            t.pass(format!("tests/{directory}/pass/*.rs"));
        }
    }

    #[rustversion::not(nightly)]
    fn go(_directory: &str) {}

    go(directory);
}