test_strategy/
syn_utils.rs

1use proc_macro2::{Group, Spacing, Span, TokenStream, TokenTree};
2use quote::{quote, ToTokens};
3use std::{
4    collections::{HashMap, HashSet},
5    iter::once,
6    ops::Deref,
7};
8use structmeta::{Parse, ToTokens};
9use syn::{
10    ext::IdentExt,
11    parenthesized,
12    parse::{Parse, ParseStream},
13    parse2, parse_str,
14    punctuated::Punctuated,
15    spanned::Spanned,
16    token::{Comma, Paren},
17    visit::{visit_path, visit_type, Visit},
18    Attribute, DeriveInput, Expr, Field, GenericParam, Generics, Ident, Lit, Path, Result, Token,
19    Type, WherePredicate,
20};
21
/// Returns early from the enclosing function with `Err(syn::Error::new(span, ..))`.
///
/// Supported call shapes:
/// - `bail!(span, "literal message")`
/// - `bail!(span, message_expr)`
/// - `bail!(span, "format string {}", args...)`
macro_rules! bail {
    // A literal message, passed to `syn::Error::new` as-is.
    ($at:expr, $text:literal $(,)?) => {
        return std::result::Result::Err(syn::Error::new($at, $text))
    };
    // Any single expression used as the message.
    ($at:expr, $msg:expr $(,)?) => {
        return std::result::Result::Err(syn::Error::new($at, $msg))
    };
    // A format string plus arguments, rendered with `std::format!`.
    ($at:expr, $template:expr, $($rest:tt)*) => {
        return std::result::Result::Err(syn::Error::new($at, std::format!($template, $($rest)*)))
    };
}
33
34pub fn into_macro_output(input: Result<TokenStream>) -> proc_macro::TokenStream {
35    match input {
36        Ok(s) => s,
37        Err(e) => e.to_compile_error(),
38    }
39    .into()
40}
41
42pub struct Parenthesized<T> {
43    pub paren_token: Option<Paren>,
44    pub content: T,
45}
46impl<T: Parse> Parse for Parenthesized<T> {
47    fn parse(input: ParseStream) -> Result<Self> {
48        let content;
49        let paren_token = Some(parenthesized!(content in input));
50        let content = content.parse()?;
51        Ok(Self {
52            paren_token,
53            content,
54        })
55    }
56}
57impl<T> Deref for Parenthesized<T> {
58    type Target = T;
59    fn deref(&self) -> &Self::Target {
60        &self.content
61    }
62}
63
64pub fn parse_parenthesized_args(input: TokenStream) -> Result<Args> {
65    if input.is_empty() {
66        Ok(Args::new())
67    } else {
68        Ok(parse2::<Parenthesized<Args>>(input)?.content)
69    }
70}
71
72#[derive(Parse)]
73pub struct Args(#[parse(terminated)] Punctuated<Arg, Comma>);
74
75impl Args {
76    fn new() -> Self {
77        Self(Punctuated::new())
78    }
79    pub fn expect_single_value(&self, span: Span) -> Result<&Expr> {
80        if self.len() != 1 {
81            bail!(
82                span,
83                "expect 1 arguments, but supplied {} arguments.",
84                self.len()
85            );
86        }
87        match &self[0] {
88            Arg::Value(expr) => Ok(expr),
89            Arg::NameValue { .. } => bail!(span, "expected unnamed argument."),
90        }
91    }
92}
93impl Deref for Args {
94    type Target = Punctuated<Arg, Comma>;
95
96    fn deref(&self) -> &Self::Target {
97        &self.0
98    }
99}
100impl IntoIterator for Args {
101    type Item = Arg;
102    type IntoIter = <Punctuated<Arg, Comma> as IntoIterator>::IntoIter;
103
104    fn into_iter(self) -> Self::IntoIter {
105        self.0.into_iter()
106    }
107}
108
/// A single macro argument: either `name = value` or a bare expression.
///
/// Parsing and token output are generated by the `structmeta` derives.
#[derive(ToTokens, Parse)]
pub enum Arg {
    /// An argument written as `name = value`.
    // NOTE(review): the `peek` markers presumably make the derive choose this
    // variant only when an identifier followed by `=` comes next, and `any`
    // presumably lets keywords be accepted as the name — confirm against the
    // structmeta documentation.
    NameValue {
        #[parse(peek, any)]
        name: Ident,
        #[parse(peek)]
        eq_token: Token![=],
        value: Expr,
    },
    /// An argument that is just an expression.
    Value(Expr),
}
120
121pub struct SharpVals {
122    allow_vals: bool,
123    allow_self: bool,
124    pub vals: HashMap<FieldKey, Span>,
125    pub self_span: Option<Span>,
126}
127impl SharpVals {
128    pub fn new(allow_vals: bool, allow_self: bool) -> Self {
129        Self {
130            allow_vals,
131            allow_self,
132            vals: HashMap::new(),
133            self_span: None,
134        }
135    }
136    pub fn expand(&mut self, input: TokenStream) -> Result<TokenStream> {
137        let mut tokens = Vec::new();
138        let mut iter = input.into_iter().peekable();
139        while let Some(t) = iter.next() {
140            match &t {
141                TokenTree::Group(g) => {
142                    tokens.push(TokenTree::Group(Group::new(
143                        g.delimiter(),
144                        self.expand(g.stream())?,
145                    )));
146                    continue;
147                }
148                TokenTree::Punct(p) => {
149                    if p.as_char() == '#' && p.spacing() == Spacing::Alone {
150                        if let Some(token) = iter.peek() {
151                            if let Some(key) = FieldKey::try_from_token(token) {
152                                let span = token.span();
153                                let allow = if &key == "self" {
154                                    self.self_span.get_or_insert(span);
155                                    self.allow_self
156                                } else {
157                                    self.vals.entry(key.clone()).or_insert(span);
158                                    self.allow_vals
159                                };
160                                if !allow {
161                                    bail!(span, "cannot use `#{}` in this position.", key);
162                                }
163                                if self.self_span.is_some() {
164                                    if let Some(key) = self.vals.keys().next() {
165                                        bail!(span, "cannot use both `#self` and `#{}`", key);
166                                    }
167                                }
168                                let mut ident = key.to_dummy_ident();
169                                ident.set_span(span);
170                                tokens.extend(ident.to_token_stream());
171                                iter.next();
172                                continue;
173                            }
174                        }
175                    }
176                }
177                _ => {}
178            }
179            tokens.extend(once(t));
180        }
181        Ok(tokens.into_iter().collect())
182    }
183}
184#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
185pub enum FieldKey {
186    Named(String),
187    Unnamed(usize),
188}
189
190impl FieldKey {
191    pub fn from_ident(ident: &Ident) -> Self {
192        Self::Named(ident.unraw().to_string())
193    }
194    pub fn from_field(idx: usize, field: &Field) -> Self {
195        if let Some(ident) = &field.ident {
196            Self::from_ident(ident)
197        } else {
198            Self::Unnamed(idx)
199        }
200    }
201    pub fn try_from_token(token: &TokenTree) -> Option<Self> {
202        match token {
203            TokenTree::Ident(ident) => Some(Self::from_ident(ident)),
204            TokenTree::Literal(token) => {
205                if let Lit::Int(lit) = Lit::new(token.clone()) {
206                    if lit.suffix().is_empty() {
207                        if let Ok(idx) = lit.base10_parse() {
208                            return Some(Self::Unnamed(idx));
209                        }
210                    }
211                }
212                None
213            }
214            _ => None,
215        }
216    }
217
218    pub fn to_dummy_ident(&self) -> Ident {
219        Ident::new(&format!("_{}", self), Span::call_site())
220    }
221    pub fn to_valid_ident(&self) -> Option<Ident> {
222        match self {
223            Self::Named(name) => to_valid_ident(name).ok(),
224            Self::Unnamed(..) => None,
225        }
226    }
227}
228impl PartialEq<str> for FieldKey {
229    fn eq(&self, other: &str) -> bool {
230        match self {
231            FieldKey::Named(name) => name == other,
232            _ => false,
233        }
234    }
235}
236impl std::fmt::Display for FieldKey {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        match self {
239            Self::Named(name) => name.fmt(f),
240            Self::Unnamed(idx) => idx.fmt(f),
241        }
242    }
243}
244
245pub struct GenericParamSet {
246    idents: HashSet<Ident>,
247}
248
249impl GenericParamSet {
250    pub fn new(generics: &Generics) -> Self {
251        let mut idents = HashSet::new();
252        for p in &generics.params {
253            match p {
254                GenericParam::Type(t) => {
255                    idents.insert(t.ident.unraw());
256                }
257                GenericParam::Const(t) => {
258                    idents.insert(t.ident.unraw());
259                }
260                _ => {}
261            }
262        }
263        Self { idents }
264    }
265    fn contains(&self, ident: &Ident) -> bool {
266        self.idents.contains(&ident.unraw())
267    }
268
269    pub fn contains_in_type(&self, ty: &Type) -> bool {
270        struct Visitor<'a> {
271            generics: &'a GenericParamSet,
272            result: bool,
273        }
274        impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
275            fn visit_path(&mut self, i: &'ast syn::Path) {
276                if i.leading_colon.is_none() {
277                    if let Some(s) = i.segments.iter().next() {
278                        if self.generics.contains(&s.ident) {
279                            self.result = true;
280                        }
281                    }
282                }
283                visit_path(self, i);
284            }
285        }
286        let mut visitor = Visitor {
287            generics: self,
288            result: false,
289        };
290        visit_type(&mut visitor, ty);
291        visitor.result
292    }
293}
294
295pub fn impl_trait(
296    input: &DeriveInput,
297    trait_path: &Path,
298    wheres: &[WherePredicate],
299    contents: TokenStream,
300) -> TokenStream {
301    let ty = &input.ident;
302    let (impl_g, ty_g, where_clause) = input.generics.split_for_impl();
303    let mut wheres = wheres.to_vec();
304    if let Some(where_clause) = where_clause {
305        wheres.extend(where_clause.predicates.iter().cloned());
306    }
307    let where_clause = if wheres.is_empty() {
308        quote! {}
309    } else {
310        quote! { where #(#wheres,)*}
311    };
312    quote! {
313        #[automatically_derived]
314        impl #impl_g #trait_path for #ty #ty_g #where_clause {
315            #contents
316        }
317    }
318}
319pub fn impl_trait_result(
320    input: &DeriveInput,
321    trait_path: &Path,
322    wheres: &[WherePredicate],
323    contents: TokenStream,
324    dump: bool,
325) -> Result<TokenStream> {
326    let ts = impl_trait(input, trait_path, wheres, contents);
327    if dump {
328        panic!("macro result: \n{}", ts);
329    }
330    Ok(ts)
331}
332
333pub fn to_valid_ident(s: &str) -> Result<Ident> {
334    if let Ok(ident) = parse_str(s) {
335        Ok(ident)
336    } else {
337        parse_str(&format!("r#{}", s))
338    }
339}
340
341pub fn parse_from_attrs<T: Parse + Default>(attrs: &[Attribute], name: &str) -> Result<T> {
342    let mut a = None;
343    for attr in attrs {
344        if attr.path.is_ident(name) {
345            if a.is_some() {
346                bail!(attr.span(), "attribute `{}` can specified only once", name);
347            }
348            a = Some(attr);
349        }
350    }
351    if let Some(a) = a {
352        a.parse_args()
353    } else {
354        Ok(T::default())
355    }
356}