axum_macros/
from_ref.rs

1use proc_macro2::{Ident, TokenStream};
2use quote::quote_spanned;
3use syn::{
4    parse::{Parse, ParseStream},
5    spanned::Spanned,
6    Field, ItemStruct, Token, Type,
7};
8
9use crate::attr_parsing::{combine_unary_attribute, parse_attrs, Combine};
10
11pub(crate) fn expand(item: ItemStruct) -> syn::Result<TokenStream> {
12    if !item.generics.params.is_empty() {
13        return Err(syn::Error::new_spanned(
14            item.generics,
15            "`#[derive(FromRef)]` doesn't support generics",
16        ));
17    }
18
19    let tokens = item
20        .fields
21        .iter()
22        .enumerate()
23        .map(|(idx, field)| expand_field(&item.ident, idx, field))
24        .collect();
25
26    Ok(tokens)
27}
28
29fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream {
30    let FieldAttrs { skip } = match parse_attrs("from_ref", &field.attrs) {
31        Ok(attrs) => attrs,
32        Err(err) => return err.into_compile_error(),
33    };
34
35    if skip.is_some() {
36        return TokenStream::default();
37    }
38
39    let field_ty = &field.ty;
40    let span = field.ty.span();
41
42    let body = if let Some(field_ident) = &field.ident {
43        if matches!(field_ty, Type::Reference(_)) {
44            quote_spanned! {span=> state.#field_ident }
45        } else {
46            quote_spanned! {span=> state.#field_ident.clone() }
47        }
48    } else {
49        let idx = syn::Index {
50            index: idx as _,
51            span: field.span(),
52        };
53        quote_spanned! {span=> state.#idx.clone() }
54    };
55
56    quote_spanned! {span=>
57        #[allow(clippy::clone_on_copy, clippy::clone_on_ref_ptr)]
58        impl ::axum::extract::FromRef<#state> for #field_ty {
59            fn from_ref(state: &#state) -> Self {
60                #body
61            }
62        }
63    }
64}
65
66mod kw {
67    syn::custom_keyword!(skip);
68}
69
70#[derive(Default)]
71pub(super) struct FieldAttrs {
72    pub(super) skip: Option<kw::skip>,
73}
74
75impl Parse for FieldAttrs {
76    fn parse(input: ParseStream) -> syn::Result<Self> {
77        let mut skip = None;
78
79        while !input.is_empty() {
80            let lh = input.lookahead1();
81            if lh.peek(kw::skip) {
82                skip = Some(input.parse()?);
83            } else {
84                return Err(lh.error());
85            }
86
87            let _ = input.parse::<Token![,]>();
88        }
89
90        Ok(Self { skip })
91    }
92}
93
94impl Combine for FieldAttrs {
95    fn combine(mut self, other: Self) -> syn::Result<Self> {
96        let Self { skip } = other;
97        combine_unary_attribute(&mut self.skip, skip)?;
98        Ok(self)
99    }
100}
101
/// Runs the crate's UI test suite for the `from_ref` derive.
#[test]
fn ui() {
    let suite = "from_ref";
    crate::run_ui_tests(suite);
}