axum_macros/
typed_path.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, quote_spanned};
3use syn::{parse::Parse, ItemStruct, LitStr, Token};
4
5use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, second, Combine};
6
7pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
8    let ItemStruct {
9        attrs,
10        ident,
11        generics,
12        fields,
13        ..
14    } = &item_struct;
15
16    if !generics.params.is_empty() || generics.where_clause.is_some() {
17        return Err(syn::Error::new_spanned(
18            generics,
19            "`#[derive(TypedPath)]` doesn't support generics",
20        ));
21    }
22
23    let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", attrs)?;
24
25    let path = path.ok_or_else(|| {
26        syn::Error::new(
27            Span::call_site(),
28            "Missing path: `#[typed_path(\"/foo/bar\")]`",
29        )
30    })?;
31
32    let rejection = rejection.map(second);
33
34    match fields {
35        syn::Fields::Named(_) => {
36            let segments = parse_path(&path)?;
37            Ok(expand_named_fields(ident, path, &segments, rejection))
38        }
39        syn::Fields::Unnamed(fields) => {
40            let segments = parse_path(&path)?;
41            expand_unnamed_fields(fields, ident, path, &segments, rejection)
42        }
43        syn::Fields::Unit => expand_unit_fields(ident, path, rejection),
44    }
45}
46
/// Custom keywords accepted inside the `typed_path` attribute.
mod kw {
    // `rejection` introduces the parenthesized rejection argument parsed in
    // `Attrs::parse`.
    syn::custom_keyword!(rejection);
}
50
/// Arguments gathered from the `typed_path` attribute(s) on a struct.
#[derive(Default)]
struct Attrs {
    /// The route path string literal, if one was given.
    path: Option<LitStr>,
    /// The `rejection` keyword token together with the path parsed for it, if given.
    rejection: Option<(kw::rejection, syn::Path)>,
}
56
57impl Parse for Attrs {
58    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
59        let mut path = None;
60        let mut rejection = None;
61
62        while !input.is_empty() {
63            let lh = input.lookahead1();
64            if lh.peek(LitStr) {
65                path = Some(input.parse()?);
66            } else if lh.peek(kw::rejection) {
67                parse_parenthesized_attribute(input, &mut rejection)?;
68            } else {
69                return Err(lh.error());
70            }
71
72            let _ = input.parse::<Token![,]>();
73        }
74
75        Ok(Self { path, rejection })
76    }
77}
78
79impl Combine for Attrs {
80    fn combine(mut self, other: Self) -> syn::Result<Self> {
81        let Self { path, rejection } = other;
82        if let Some(path) = path {
83            if self.path.is_some() {
84                return Err(syn::Error::new_spanned(
85                    path,
86                    "path specified more than once",
87                ));
88            }
89            self.path = Some(path);
90        }
91        combine_attribute(&mut self.rejection, rejection)?;
92        Ok(self)
93    }
94}
95
96fn expand_named_fields(
97    ident: &syn::Ident,
98    path: LitStr,
99    segments: &[Segment],
100    rejection: Option<syn::Path>,
101) -> TokenStream {
102    let format_str = format_str_from_path(segments);
103    let captures = captures_from_path(segments);
104
105    let typed_path_impl = quote_spanned! {path.span()=>
106        #[automatically_derived]
107        impl ::axum_extra::routing::TypedPath for #ident {
108            const PATH: &'static str = #path;
109        }
110    };
111
112    let display_impl = quote_spanned! {path.span()=>
113        #[automatically_derived]
114        impl ::std::fmt::Display for #ident {
115            #[allow(clippy::unnecessary_to_owned)]
116            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
117                let Self { #(#captures,)* } = self;
118                write!(
119                    f,
120                    #format_str,
121                    #(
122                        #captures = ::axum_extra::__private::utf8_percent_encode(
123                            &#captures.to_string(),
124                            ::axum_extra::__private::PATH_SEGMENT,
125                        )
126                    ),*
127                )
128            }
129        }
130    };
131
132    let rejection_assoc_type = rejection_assoc_type(&rejection);
133    let map_err_rejection = map_err_rejection(&rejection);
134
135    let from_request_impl = quote! {
136        #[::axum::async_trait]
137        #[automatically_derived]
138        impl<S> ::axum::extract::FromRequestParts<S> for #ident
139        where
140            S: Send + Sync,
141        {
142            type Rejection = #rejection_assoc_type;
143
144            async fn from_request_parts(
145                parts: &mut ::axum::http::request::Parts,
146                state: &S,
147            ) -> ::std::result::Result<Self, Self::Rejection> {
148                ::axum::extract::Path::from_request_parts(parts, state)
149                    .await
150                    .map(|path| path.0)
151                    #map_err_rejection
152            }
153        }
154    };
155
156    quote! {
157        #typed_path_impl
158        #display_impl
159        #from_request_impl
160    }
161}
162
163fn expand_unnamed_fields(
164    fields: &syn::FieldsUnnamed,
165    ident: &syn::Ident,
166    path: LitStr,
167    segments: &[Segment],
168    rejection: Option<syn::Path>,
169) -> syn::Result<TokenStream> {
170    let num_captures = segments
171        .iter()
172        .filter(|segment| match segment {
173            Segment::Capture(_, _) => true,
174            Segment::Static(_) => false,
175        })
176        .count();
177    let num_fields = fields.unnamed.len();
178    if num_fields != num_captures {
179        return Err(syn::Error::new_spanned(
180            fields,
181            format!(
182                "Mismatch in number of captures and fields. Path has {} but struct has {}",
183                simple_pluralize(num_captures, "capture"),
184                simple_pluralize(num_fields, "field"),
185            ),
186        ));
187    }
188
189    let destructure_self = segments
190        .iter()
191        .filter_map(|segment| match segment {
192            Segment::Capture(capture, _) => Some(capture),
193            Segment::Static(_) => None,
194        })
195        .enumerate()
196        .map(|(idx, capture)| {
197            let idx = syn::Index {
198                index: idx as _,
199                span: Span::call_site(),
200            };
201            let capture = format_ident!("{}", capture, span = path.span());
202            quote_spanned! {path.span()=>
203                #idx: #capture,
204            }
205        });
206
207    let format_str = format_str_from_path(segments);
208    let captures = captures_from_path(segments);
209
210    let typed_path_impl = quote_spanned! {path.span()=>
211        #[automatically_derived]
212        impl ::axum_extra::routing::TypedPath for #ident {
213            const PATH: &'static str = #path;
214        }
215    };
216
217    let display_impl = quote_spanned! {path.span()=>
218        #[automatically_derived]
219        impl ::std::fmt::Display for #ident {
220            #[allow(clippy::unnecessary_to_owned)]
221            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
222                let Self { #(#destructure_self)* } = self;
223                write!(
224                    f,
225                    #format_str,
226                    #(
227                        #captures = ::axum_extra::__private::utf8_percent_encode(
228                            &#captures.to_string(),
229                            ::axum_extra::__private::PATH_SEGMENT,
230                        )
231                    ),*
232                )
233            }
234        }
235    };
236
237    let rejection_assoc_type = rejection_assoc_type(&rejection);
238    let map_err_rejection = map_err_rejection(&rejection);
239
240    let from_request_impl = quote! {
241        #[::axum::async_trait]
242        #[automatically_derived]
243        impl<S> ::axum::extract::FromRequestParts<S> for #ident
244        where
245            S: Send + Sync,
246        {
247            type Rejection = #rejection_assoc_type;
248
249            async fn from_request_parts(
250                parts: &mut ::axum::http::request::Parts,
251                state: &S,
252            ) -> ::std::result::Result<Self, Self::Rejection> {
253                ::axum::extract::Path::from_request_parts(parts, state)
254                    .await
255                    .map(|path| path.0)
256                    #map_err_rejection
257            }
258        }
259    };
260
261    Ok(quote! {
262        #typed_path_impl
263        #display_impl
264        #from_request_impl
265    })
266}
267
/// Formats `count` followed by `word`, appending an `s` to `word` unless
/// `count` is exactly one (so zero is plural: `0 fields`).
fn simple_pluralize(count: usize, word: &str) -> String {
    let suffix = if count == 1 { "" } else { "s" };
    format!("{count} {word}{suffix}")
}
275
276fn expand_unit_fields(
277    ident: &syn::Ident,
278    path: LitStr,
279    rejection: Option<syn::Path>,
280) -> syn::Result<TokenStream> {
281    for segment in parse_path(&path)? {
282        match segment {
283            Segment::Capture(_, span) => {
284                return Err(syn::Error::new(
285                    span,
286                    "Typed paths for unit structs cannot contain captures",
287                ));
288            }
289            Segment::Static(_) => {}
290        }
291    }
292
293    let typed_path_impl = quote_spanned! {path.span()=>
294        #[automatically_derived]
295        impl ::axum_extra::routing::TypedPath for #ident {
296            const PATH: &'static str = #path;
297        }
298    };
299
300    let display_impl = quote_spanned! {path.span()=>
301        #[automatically_derived]
302        impl ::std::fmt::Display for #ident {
303            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
304                write!(f, #path)
305            }
306        }
307    };
308
309    let rejection_assoc_type = if let Some(rejection) = &rejection {
310        quote! { #rejection }
311    } else {
312        quote! { ::axum::http::StatusCode }
313    };
314    let create_rejection = if let Some(rejection) = &rejection {
315        quote! {
316            Err(<#rejection as ::std::default::Default>::default())
317        }
318    } else {
319        quote! {
320            Err(::axum::http::StatusCode::NOT_FOUND)
321        }
322    };
323
324    let from_request_impl = quote! {
325        #[::axum::async_trait]
326        #[automatically_derived]
327        impl<S> ::axum::extract::FromRequestParts<S> for #ident
328        where
329            S: Send + Sync,
330        {
331            type Rejection = #rejection_assoc_type;
332
333            async fn from_request_parts(
334                parts: &mut ::axum::http::request::Parts,
335                _state: &S,
336            ) -> ::std::result::Result<Self, Self::Rejection> {
337                if parts.uri.path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
338                    Ok(Self)
339                } else {
340                    #create_rejection
341                }
342            }
343        }
344    };
345
346    Ok(quote! {
347        #typed_path_impl
348        #display_impl
349        #from_request_impl
350    })
351}
352
353fn format_str_from_path(segments: &[Segment]) -> String {
354    segments
355        .iter()
356        .map(|segment| match segment {
357            Segment::Capture(capture, _) => format!("{{{capture}}}"),
358            Segment::Static(segment) => segment.to_owned(),
359        })
360        .collect::<Vec<_>>()
361        .join("/")
362}
363
364fn captures_from_path(segments: &[Segment]) -> Vec<syn::Ident> {
365    segments
366        .iter()
367        .filter_map(|segment| match segment {
368            Segment::Capture(capture, span) => Some(format_ident!("{}", capture, span = *span)),
369            Segment::Static(_) => None,
370        })
371        .collect::<Vec<_>>()
372}
373
374fn parse_path(path: &LitStr) -> syn::Result<Vec<Segment>> {
375    let value = path.value();
376    if value.is_empty() {
377        return Err(syn::Error::new_spanned(
378            path,
379            "paths must start with a `/`. Use \"/\" for root routes",
380        ));
381    } else if !path.value().starts_with('/') {
382        return Err(syn::Error::new_spanned(path, "paths must start with a `/`"));
383    }
384
385    path.value()
386        .split('/')
387        .map(|segment| {
388            if let Some(capture) = segment
389                .strip_prefix(':')
390                .or_else(|| segment.strip_prefix('*'))
391            {
392                Ok(Segment::Capture(capture.to_owned(), path.span()))
393            } else {
394                Ok(Segment::Static(segment.to_owned()))
395            }
396        })
397        .collect()
398}
399
/// One `/`-separated piece of a typed path, as produced by `parse_path`.
enum Segment {
    /// A segment that began with `:` or `*`: the text after the prefix, plus
    /// the span used for errors and for identifiers generated from it.
    Capture(String, Span),
    /// Any other segment, stored as written.
    Static(String),
}
404
405fn path_rejection() -> TokenStream {
406    quote! {
407        <::axum::extract::Path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection
408    }
409}
410
411fn rejection_assoc_type(rejection: &Option<syn::Path>) -> TokenStream {
412    match rejection {
413        Some(rejection) => quote! { #rejection },
414        None => path_rejection(),
415    }
416}
417
418fn map_err_rejection(rejection: &Option<syn::Path>) -> TokenStream {
419    rejection
420        .as_ref()
421        .map(|rejection| {
422            let path_rejection = path_rejection();
423            quote! {
424                .map_err(|rejection| {
425                    <#rejection as ::std::convert::From<#path_rejection>>::from(rejection)
426                })
427            }
428        })
429        .unwrap_or_default()
430}
431
// Runs the `typed_path` UI tests through the crate-level `run_ui_tests` helper.
#[test]
fn ui() {
    crate::run_ui_tests("typed_path");
}