test_strategy/
proptest_fn.rs

1use crate::syn_utils::{Arg, Args};
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{
5    parse2, parse_quote, parse_str, spanned::Spanned, token, Field, FnArg, Ident, ItemFn, Pat,
6    Result, Visibility,
7};
8
9pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStream> {
10    let mut attr_args = None;
11    if !attr.is_empty() {
12        attr_args = Some(parse2::<Args>(attr)?);
13    }
14    let mut dump = false;
15    item_fn.attrs.retain(|attr| {
16        if attr.path.is_ident("proptest_dump") {
17            dump = true;
18            false
19        } else {
20            true
21        }
22    });
23    let args_type_str = format!("_{}Args", to_camel_case(&item_fn.sig.ident.to_string()));
24    let args_type_ident: Ident = parse_str(&args_type_str).unwrap();
25    let args = item_fn
26        .sig
27        .inputs
28        .iter()
29        .map(TestFnArg::from)
30        .collect::<Result<Vec<_>>>()?;
31    let args_pats = args.iter().map(|arg| arg.pat());
32    let block = &item_fn.block;
33    let block = quote! {
34        {
35            let #args_type_ident { #(#args_pats,)* } = input;
36            #block
37        }
38    };
39    item_fn.sig.inputs = parse_quote! { input: #args_type_ident };
40    item_fn.block = Box::new(parse2(block)?);
41    let args_fields = args.iter().map(|arg| &arg.field);
42    let config = to_proptest_config(attr_args);
43    let ts = quote! {
44        #[derive(test_strategy::Arbitrary, Debug)]
45        struct #args_type_ident {
46            #(#args_fields,)*
47        }
48        proptest::proptest! {
49            #config
50            #[test]
51            #item_fn
52        }
53    };
54    if dump {
55        panic!("{}", ts);
56    }
57    Ok(ts)
58}
59
60fn to_proptest_config(args: Option<Args>) -> TokenStream {
61    if let Some(args) = args {
62        let mut base_expr = None;
63        let mut inits = Vec::new();
64        for arg in args {
65            match arg {
66                Arg::Value(value) => base_expr = Some(value),
67                Arg::NameValue { name, value, .. } => inits.push(quote!(#name : #value)),
68            }
69        }
70        let base_expr = base_expr.unwrap_or_else(|| {
71            parse_quote!(<proptest::test_runner::Config as std::default::Default>::default())
72        });
73        quote! {
74            #![proptest_config(proptest::test_runner::Config {
75                #(#inits,)*
76                .. #base_expr
77              })]
78        }
79    } else {
80        quote! {}
81    }
82}
83struct TestFnArg {
84    field: Field,
85    mutability: Option<token::Mut>,
86}
87impl TestFnArg {
88    fn from(arg: &FnArg) -> Result<Self> {
89        if let FnArg::Typed(arg) = arg {
90            if let Pat::Ident(ident) = arg.pat.as_ref() {
91                if ident.attrs.is_empty() && ident.by_ref.is_none() && ident.subpat.is_none() {
92                    return Ok(Self {
93                        field: Field {
94                            attrs: arg.attrs.clone(),
95                            vis: Visibility::Inherited,
96                            ident: Some(ident.ident.clone()),
97                            colon_token: Some(arg.colon_token),
98                            ty: arg.ty.as_ref().clone(),
99                        },
100                        mutability: ident.mutability,
101                    });
102                }
103            } else {
104                bail!(arg.pat.span(), "argument pattern not supported.");
105            }
106        }
107        bail!(
108            arg.span(),
109            "argument {} is not supported.",
110            arg.to_token_stream()
111        );
112    }
113    fn pat(&self) -> TokenStream {
114        let mutability = &self.mutability;
115        let ident = &self.field.ident;
116        quote!(#mutability #ident)
117    }
118}
119
/// Converts a `snake_case` name into `UpperCamelCase`.
///
/// Underscores are dropped, and the first character after each run of underscores is
/// upper-cased (as is the first character of the string). All other characters are copied
/// unchanged, so existing capitals are kept. Upper-casing is Unicode-aware and may expand a
/// character (e.g. `ß` becomes `SS`).
fn to_camel_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut at_word_start = true;
    for ch in s.chars() {
        match ch {
            '_' => at_word_start = true,
            _ if at_word_start => {
                out.extend(ch.to_uppercase());
                at_word_start = false;
            }
            _ => out.push(ch),
        }
    }
    out
}