1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, quote_spanned};
3use syn::{parse::Parse, ItemStruct, LitStr, Token};
4
5use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, second, Combine};
6
7pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result<TokenStream> {
8 let ItemStruct {
9 attrs,
10 ident,
11 generics,
12 fields,
13 ..
14 } = &item_struct;
15
16 if !generics.params.is_empty() || generics.where_clause.is_some() {
17 return Err(syn::Error::new_spanned(
18 generics,
19 "`#[derive(TypedPath)]` doesn't support generics",
20 ));
21 }
22
23 let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", attrs)?;
24
25 let path = path.ok_or_else(|| {
26 syn::Error::new(
27 Span::call_site(),
28 "Missing path: `#[typed_path(\"/foo/bar\")]`",
29 )
30 })?;
31
32 let rejection = rejection.map(second);
33
34 match fields {
35 syn::Fields::Named(_) => {
36 let segments = parse_path(&path)?;
37 Ok(expand_named_fields(ident, path, &segments, rejection))
38 }
39 syn::Fields::Unnamed(fields) => {
40 let segments = parse_path(&path)?;
41 expand_unnamed_fields(fields, ident, path, &segments, rejection)
42 }
43 syn::Fields::Unit => expand_unit_fields(ident, path, rejection),
44 }
45}
46
/// Custom keywords recognized inside `#[typed_path(...)]`.
mod kw {
    // Declares `rejection` as a keyword token so `Attrs::parse` can peek for it
    // and store it alongside the parsed rejection type.
    syn::custom_keyword!(rejection);
}
50
/// Options collected from one or more `#[typed_path(...)]` attributes.
#[derive(Default)]
struct Attrs {
    /// The route path string literal, e.g. `"/users/:id"`. Required; `expand`
    /// reports an error if it is missing.
    path: Option<LitStr>,
    /// `rejection(SomeType)`: the keyword token and the custom rejection type.
    /// `expand` keeps only the type (the second element).
    rejection: Option<(kw::rejection, syn::Path)>,
}
56
57impl Parse for Attrs {
58 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
59 let mut path = None;
60 let mut rejection = None;
61
62 while !input.is_empty() {
63 let lh = input.lookahead1();
64 if lh.peek(LitStr) {
65 path = Some(input.parse()?);
66 } else if lh.peek(kw::rejection) {
67 parse_parenthesized_attribute(input, &mut rejection)?;
68 } else {
69 return Err(lh.error());
70 }
71
72 let _ = input.parse::<Token![,]>();
73 }
74
75 Ok(Self { path, rejection })
76 }
77}
78
79impl Combine for Attrs {
80 fn combine(mut self, other: Self) -> syn::Result<Self> {
81 let Self { path, rejection } = other;
82 if let Some(path) = path {
83 if self.path.is_some() {
84 return Err(syn::Error::new_spanned(
85 path,
86 "path specified more than once",
87 ));
88 }
89 self.path = Some(path);
90 }
91 combine_attribute(&mut self.rejection, rejection)?;
92 Ok(self)
93 }
94}
95
96fn expand_named_fields(
97 ident: &syn::Ident,
98 path: LitStr,
99 segments: &[Segment],
100 rejection: Option<syn::Path>,
101) -> TokenStream {
102 let format_str = format_str_from_path(segments);
103 let captures = captures_from_path(segments);
104
105 let typed_path_impl = quote_spanned! {path.span()=>
106 #[automatically_derived]
107 impl ::axum_extra::routing::TypedPath for #ident {
108 const PATH: &'static str = #path;
109 }
110 };
111
112 let display_impl = quote_spanned! {path.span()=>
113 #[automatically_derived]
114 impl ::std::fmt::Display for #ident {
115 #[allow(clippy::unnecessary_to_owned)]
116 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
117 let Self { #(#captures,)* } = self;
118 write!(
119 f,
120 #format_str,
121 #(
122 #captures = ::axum_extra::__private::utf8_percent_encode(
123 &#captures.to_string(),
124 ::axum_extra::__private::PATH_SEGMENT,
125 )
126 ),*
127 )
128 }
129 }
130 };
131
132 let rejection_assoc_type = rejection_assoc_type(&rejection);
133 let map_err_rejection = map_err_rejection(&rejection);
134
135 let from_request_impl = quote! {
136 #[::axum::async_trait]
137 #[automatically_derived]
138 impl<S> ::axum::extract::FromRequestParts<S> for #ident
139 where
140 S: Send + Sync,
141 {
142 type Rejection = #rejection_assoc_type;
143
144 async fn from_request_parts(
145 parts: &mut ::axum::http::request::Parts,
146 state: &S,
147 ) -> ::std::result::Result<Self, Self::Rejection> {
148 ::axum::extract::Path::from_request_parts(parts, state)
149 .await
150 .map(|path| path.0)
151 #map_err_rejection
152 }
153 }
154 };
155
156 quote! {
157 #typed_path_impl
158 #display_impl
159 #from_request_impl
160 }
161}
162
163fn expand_unnamed_fields(
164 fields: &syn::FieldsUnnamed,
165 ident: &syn::Ident,
166 path: LitStr,
167 segments: &[Segment],
168 rejection: Option<syn::Path>,
169) -> syn::Result<TokenStream> {
170 let num_captures = segments
171 .iter()
172 .filter(|segment| match segment {
173 Segment::Capture(_, _) => true,
174 Segment::Static(_) => false,
175 })
176 .count();
177 let num_fields = fields.unnamed.len();
178 if num_fields != num_captures {
179 return Err(syn::Error::new_spanned(
180 fields,
181 format!(
182 "Mismatch in number of captures and fields. Path has {} but struct has {}",
183 simple_pluralize(num_captures, "capture"),
184 simple_pluralize(num_fields, "field"),
185 ),
186 ));
187 }
188
189 let destructure_self = segments
190 .iter()
191 .filter_map(|segment| match segment {
192 Segment::Capture(capture, _) => Some(capture),
193 Segment::Static(_) => None,
194 })
195 .enumerate()
196 .map(|(idx, capture)| {
197 let idx = syn::Index {
198 index: idx as _,
199 span: Span::call_site(),
200 };
201 let capture = format_ident!("{}", capture, span = path.span());
202 quote_spanned! {path.span()=>
203 #idx: #capture,
204 }
205 });
206
207 let format_str = format_str_from_path(segments);
208 let captures = captures_from_path(segments);
209
210 let typed_path_impl = quote_spanned! {path.span()=>
211 #[automatically_derived]
212 impl ::axum_extra::routing::TypedPath for #ident {
213 const PATH: &'static str = #path;
214 }
215 };
216
217 let display_impl = quote_spanned! {path.span()=>
218 #[automatically_derived]
219 impl ::std::fmt::Display for #ident {
220 #[allow(clippy::unnecessary_to_owned)]
221 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
222 let Self { #(#destructure_self)* } = self;
223 write!(
224 f,
225 #format_str,
226 #(
227 #captures = ::axum_extra::__private::utf8_percent_encode(
228 &#captures.to_string(),
229 ::axum_extra::__private::PATH_SEGMENT,
230 )
231 ),*
232 )
233 }
234 }
235 };
236
237 let rejection_assoc_type = rejection_assoc_type(&rejection);
238 let map_err_rejection = map_err_rejection(&rejection);
239
240 let from_request_impl = quote! {
241 #[::axum::async_trait]
242 #[automatically_derived]
243 impl<S> ::axum::extract::FromRequestParts<S> for #ident
244 where
245 S: Send + Sync,
246 {
247 type Rejection = #rejection_assoc_type;
248
249 async fn from_request_parts(
250 parts: &mut ::axum::http::request::Parts,
251 state: &S,
252 ) -> ::std::result::Result<Self, Self::Rejection> {
253 ::axum::extract::Path::from_request_parts(parts, state)
254 .await
255 .map(|path| path.0)
256 #map_err_rejection
257 }
258 }
259 };
260
261 Ok(quote! {
262 #typed_path_impl
263 #display_impl
264 #from_request_impl
265 })
266}
267
/// Formats `count` followed by `word`, adding an `s` unless `count` is exactly 1.
///
/// Only handles regular English plurals, e.g. `1 field`, `0 fields`, `3 captures`.
fn simple_pluralize(count: usize, word: &str) -> String {
    let suffix = if count == 1 { "" } else { "s" };
    format!("{count} {word}{suffix}")
}
275
276fn expand_unit_fields(
277 ident: &syn::Ident,
278 path: LitStr,
279 rejection: Option<syn::Path>,
280) -> syn::Result<TokenStream> {
281 for segment in parse_path(&path)? {
282 match segment {
283 Segment::Capture(_, span) => {
284 return Err(syn::Error::new(
285 span,
286 "Typed paths for unit structs cannot contain captures",
287 ));
288 }
289 Segment::Static(_) => {}
290 }
291 }
292
293 let typed_path_impl = quote_spanned! {path.span()=>
294 #[automatically_derived]
295 impl ::axum_extra::routing::TypedPath for #ident {
296 const PATH: &'static str = #path;
297 }
298 };
299
300 let display_impl = quote_spanned! {path.span()=>
301 #[automatically_derived]
302 impl ::std::fmt::Display for #ident {
303 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
304 write!(f, #path)
305 }
306 }
307 };
308
309 let rejection_assoc_type = if let Some(rejection) = &rejection {
310 quote! { #rejection }
311 } else {
312 quote! { ::axum::http::StatusCode }
313 };
314 let create_rejection = if let Some(rejection) = &rejection {
315 quote! {
316 Err(<#rejection as ::std::default::Default>::default())
317 }
318 } else {
319 quote! {
320 Err(::axum::http::StatusCode::NOT_FOUND)
321 }
322 };
323
324 let from_request_impl = quote! {
325 #[::axum::async_trait]
326 #[automatically_derived]
327 impl<S> ::axum::extract::FromRequestParts<S> for #ident
328 where
329 S: Send + Sync,
330 {
331 type Rejection = #rejection_assoc_type;
332
333 async fn from_request_parts(
334 parts: &mut ::axum::http::request::Parts,
335 _state: &S,
336 ) -> ::std::result::Result<Self, Self::Rejection> {
337 if parts.uri.path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
338 Ok(Self)
339 } else {
340 #create_rejection
341 }
342 }
343 }
344 };
345
346 Ok(quote! {
347 #typed_path_impl
348 #display_impl
349 #from_request_impl
350 })
351}
352
353fn format_str_from_path(segments: &[Segment]) -> String {
354 segments
355 .iter()
356 .map(|segment| match segment {
357 Segment::Capture(capture, _) => format!("{{{capture}}}"),
358 Segment::Static(segment) => segment.to_owned(),
359 })
360 .collect::<Vec<_>>()
361 .join("/")
362}
363
364fn captures_from_path(segments: &[Segment]) -> Vec<syn::Ident> {
365 segments
366 .iter()
367 .filter_map(|segment| match segment {
368 Segment::Capture(capture, span) => Some(format_ident!("{}", capture, span = *span)),
369 Segment::Static(_) => None,
370 })
371 .collect::<Vec<_>>()
372}
373
374fn parse_path(path: &LitStr) -> syn::Result<Vec<Segment>> {
375 let value = path.value();
376 if value.is_empty() {
377 return Err(syn::Error::new_spanned(
378 path,
379 "paths must start with a `/`. Use \"/\" for root routes",
380 ));
381 } else if !path.value().starts_with('/') {
382 return Err(syn::Error::new_spanned(path, "paths must start with a `/`"));
383 }
384
385 path.value()
386 .split('/')
387 .map(|segment| {
388 if let Some(capture) = segment
389 .strip_prefix(':')
390 .or_else(|| segment.strip_prefix('*'))
391 {
392 Ok(Segment::Capture(capture.to_owned(), path.span()))
393 } else {
394 Ok(Segment::Static(segment.to_owned()))
395 }
396 })
397 .collect()
398}
399
/// One `/`-separated piece of a route path, as produced by `parse_path`.
enum Segment {
    /// A `:name` or `*name` capture: the name without its sigil, plus the span
    /// used for diagnostics and for identifiers generated from it.
    Capture(String, Span),
    /// Literal path text. Empty for the piece before the leading `/`.
    Static(String),
}
404
405fn path_rejection() -> TokenStream {
406 quote! {
407 <::axum::extract::Path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection
408 }
409}
410
411fn rejection_assoc_type(rejection: &Option<syn::Path>) -> TokenStream {
412 match rejection {
413 Some(rejection) => quote! { #rejection },
414 None => path_rejection(),
415 }
416}
417
418fn map_err_rejection(rejection: &Option<syn::Path>) -> TokenStream {
419 rejection
420 .as_ref()
421 .map(|rejection| {
422 let path_rejection = path_rejection();
423 quote! {
424 .map_err(|rejection| {
425 <#rejection as ::std::convert::From<#path_rejection>>::from(rejection)
426 })
427 }
428 })
429 .unwrap_or_default()
430}
431
/// Runs the UI test suite for `#[derive(TypedPath)]` through the crate's shared
/// `run_ui_tests` helper, using the `typed_path` test directory name.
#[test]
fn ui() {
    crate::run_ui_tests("typed_path");
}