1use proc_macro2::{Group, Spacing, Span, TokenStream, TokenTree};
2use quote::{quote, ToTokens};
3use std::{
4 collections::{HashMap, HashSet},
5 iter::once,
6 ops::Deref,
7};
8use structmeta::{Parse, ToTokens};
9use syn::{
10 ext::IdentExt,
11 parenthesized,
12 parse::{Parse, ParseStream},
13 parse2, parse_str,
14 punctuated::Punctuated,
15 spanned::Spanned,
16 token::{Comma, Paren},
17 visit::{visit_path, visit_type, Visit},
18 Attribute, DeriveInput, Expr, Field, GenericParam, Generics, Ident, Lit, Path, Result, Token,
19 Type, WherePredicate,
20};
21
/// Returns early from the enclosing function with `Err(syn::Error)` located at `$span`.
///
/// The expansion is a `return` expression, so it may only be used inside a function whose
/// return type accepts `Result::Err(syn::Error)` (e.g. `syn::Result<_>`).
///
/// Accepted forms:
/// - `bail!(span, "message")` — a string literal used verbatim as the message (it is NOT
///   passed through `format!`, so any `{}` in it is kept literally);
/// - `bail!(span, message_expr)` — any single expression, passed straight to `syn::Error::new`;
/// - `bail!(span, "format {}", args...)` — the message is built with `std::format!`.
macro_rules! bail {
    // Plain literal message; a trailing comma is tolerated.
    ($span:expr, $message:literal $(,)?) => {
        return std::result::Result::Err(syn::Error::new($span, $message))
    };
    // Any single message expression; only reached when the literal arm did not match.
    ($span:expr, $err:expr $(,)?) => {
        return std::result::Result::Err(syn::Error::new($span, $err))
    };
    // Format string plus arguments, forwarded unchanged to `std::format!`.
    ($span:expr, $fmt:expr, $($arg:tt)*) => {
        return std::result::Result::Err(syn::Error::new($span, std::format!($fmt, $($arg)*)))
    };
}
33
34pub fn into_macro_output(input: Result<TokenStream>) -> proc_macro::TokenStream {
35 match input {
36 Ok(s) => s,
37 Err(e) => e.to_compile_error(),
38 }
39 .into()
40}
41
42pub struct Parenthesized<T> {
43 pub paren_token: Option<Paren>,
44 pub content: T,
45}
46impl<T: Parse> Parse for Parenthesized<T> {
47 fn parse(input: ParseStream) -> Result<Self> {
48 let content;
49 let paren_token = Some(parenthesized!(content in input));
50 let content = content.parse()?;
51 Ok(Self {
52 paren_token,
53 content,
54 })
55 }
56}
57impl<T> Deref for Parenthesized<T> {
58 type Target = T;
59 fn deref(&self) -> &Self::Target {
60 &self.content
61 }
62}
63
64pub fn parse_parenthesized_args(input: TokenStream) -> Result<Args> {
65 if input.is_empty() {
66 Ok(Args::new())
67 } else {
68 Ok(parse2::<Parenthesized<Args>>(input)?.content)
69 }
70}
71
72#[derive(Parse)]
73pub struct Args(#[parse(terminated)] Punctuated<Arg, Comma>);
74
75impl Args {
76 fn new() -> Self {
77 Self(Punctuated::new())
78 }
79 pub fn expect_single_value(&self, span: Span) -> Result<&Expr> {
80 if self.len() != 1 {
81 bail!(
82 span,
83 "expect 1 arguments, but supplied {} arguments.",
84 self.len()
85 );
86 }
87 match &self[0] {
88 Arg::Value(expr) => Ok(expr),
89 Arg::NameValue { .. } => bail!(span, "expected unnamed argument."),
90 }
91 }
92}
93impl Deref for Args {
94 type Target = Punctuated<Arg, Comma>;
95
96 fn deref(&self) -> &Self::Target {
97 &self.0
98 }
99}
100impl IntoIterator for Args {
101 type Item = Arg;
102 type IntoIter = <Punctuated<Arg, Comma> as IntoIterator>::IntoIter;
103
104 fn into_iter(self) -> Self::IntoIter {
105 self.0.into_iter()
106 }
107}
108
/// One argument of an [`Args`] list: either `name = value` or a bare expression.
///
/// Parsing and token output are derived with structmeta. The `#[parse(peek)]` fields are
/// presumably what structmeta uses to decide whether the `NameValue` variant applies
/// (an identifier followed by `=`); otherwise `Value` is parsed.
/// NOTE(review): confirm the exact peek semantics against the structmeta `Parse` docs.
#[derive(ToTokens, Parse)]
pub enum Arg {
    // `name = value` form.
    NameValue {
        // `any` presumably parses via `Ident::parse_any`, so keywords are accepted as
        // names — TODO confirm against structmeta docs.
        #[parse(peek, any)]
        name: Ident,
        #[parse(peek)]
        eq_token: Token![=],
        value: Expr,
    },
    // Positional (unnamed) expression.
    Value(Expr),
}
120
121pub struct SharpVals {
122 allow_vals: bool,
123 allow_self: bool,
124 pub vals: HashMap<FieldKey, Span>,
125 pub self_span: Option<Span>,
126}
127impl SharpVals {
128 pub fn new(allow_vals: bool, allow_self: bool) -> Self {
129 Self {
130 allow_vals,
131 allow_self,
132 vals: HashMap::new(),
133 self_span: None,
134 }
135 }
136 pub fn expand(&mut self, input: TokenStream) -> Result<TokenStream> {
137 let mut tokens = Vec::new();
138 let mut iter = input.into_iter().peekable();
139 while let Some(t) = iter.next() {
140 match &t {
141 TokenTree::Group(g) => {
142 tokens.push(TokenTree::Group(Group::new(
143 g.delimiter(),
144 self.expand(g.stream())?,
145 )));
146 continue;
147 }
148 TokenTree::Punct(p) => {
149 if p.as_char() == '#' && p.spacing() == Spacing::Alone {
150 if let Some(token) = iter.peek() {
151 if let Some(key) = FieldKey::try_from_token(token) {
152 let span = token.span();
153 let allow = if &key == "self" {
154 self.self_span.get_or_insert(span);
155 self.allow_self
156 } else {
157 self.vals.entry(key.clone()).or_insert(span);
158 self.allow_vals
159 };
160 if !allow {
161 bail!(span, "cannot use `#{}` in this position.", key);
162 }
163 if self.self_span.is_some() {
164 if let Some(key) = self.vals.keys().next() {
165 bail!(span, "cannot use both `#self` and `#{}`", key);
166 }
167 }
168 let mut ident = key.to_dummy_ident();
169 ident.set_span(span);
170 tokens.extend(ident.to_token_stream());
171 iter.next();
172 continue;
173 }
174 }
175 }
176 }
177 _ => {}
178 }
179 tokens.extend(once(t));
180 }
181 Ok(tokens.into_iter().collect())
182 }
183}
184#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
185pub enum FieldKey {
186 Named(String),
187 Unnamed(usize),
188}
189
190impl FieldKey {
191 pub fn from_ident(ident: &Ident) -> Self {
192 Self::Named(ident.unraw().to_string())
193 }
194 pub fn from_field(idx: usize, field: &Field) -> Self {
195 if let Some(ident) = &field.ident {
196 Self::from_ident(ident)
197 } else {
198 Self::Unnamed(idx)
199 }
200 }
201 pub fn try_from_token(token: &TokenTree) -> Option<Self> {
202 match token {
203 TokenTree::Ident(ident) => Some(Self::from_ident(ident)),
204 TokenTree::Literal(token) => {
205 if let Lit::Int(lit) = Lit::new(token.clone()) {
206 if lit.suffix().is_empty() {
207 if let Ok(idx) = lit.base10_parse() {
208 return Some(Self::Unnamed(idx));
209 }
210 }
211 }
212 None
213 }
214 _ => None,
215 }
216 }
217
218 pub fn to_dummy_ident(&self) -> Ident {
219 Ident::new(&format!("_{}", self), Span::call_site())
220 }
221 pub fn to_valid_ident(&self) -> Option<Ident> {
222 match self {
223 Self::Named(name) => to_valid_ident(name).ok(),
224 Self::Unnamed(..) => None,
225 }
226 }
227}
228impl PartialEq<str> for FieldKey {
229 fn eq(&self, other: &str) -> bool {
230 match self {
231 FieldKey::Named(name) => name == other,
232 _ => false,
233 }
234 }
235}
236impl std::fmt::Display for FieldKey {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 match self {
239 Self::Named(name) => name.fmt(f),
240 Self::Unnamed(idx) => idx.fmt(f),
241 }
242 }
243}
244
245pub struct GenericParamSet {
246 idents: HashSet<Ident>,
247}
248
249impl GenericParamSet {
250 pub fn new(generics: &Generics) -> Self {
251 let mut idents = HashSet::new();
252 for p in &generics.params {
253 match p {
254 GenericParam::Type(t) => {
255 idents.insert(t.ident.unraw());
256 }
257 GenericParam::Const(t) => {
258 idents.insert(t.ident.unraw());
259 }
260 _ => {}
261 }
262 }
263 Self { idents }
264 }
265 fn contains(&self, ident: &Ident) -> bool {
266 self.idents.contains(&ident.unraw())
267 }
268
269 pub fn contains_in_type(&self, ty: &Type) -> bool {
270 struct Visitor<'a> {
271 generics: &'a GenericParamSet,
272 result: bool,
273 }
274 impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
275 fn visit_path(&mut self, i: &'ast syn::Path) {
276 if i.leading_colon.is_none() {
277 if let Some(s) = i.segments.iter().next() {
278 if self.generics.contains(&s.ident) {
279 self.result = true;
280 }
281 }
282 }
283 visit_path(self, i);
284 }
285 }
286 let mut visitor = Visitor {
287 generics: self,
288 result: false,
289 };
290 visit_type(&mut visitor, ty);
291 visitor.result
292 }
293}
294
295pub fn impl_trait(
296 input: &DeriveInput,
297 trait_path: &Path,
298 wheres: &[WherePredicate],
299 contents: TokenStream,
300) -> TokenStream {
301 let ty = &input.ident;
302 let (impl_g, ty_g, where_clause) = input.generics.split_for_impl();
303 let mut wheres = wheres.to_vec();
304 if let Some(where_clause) = where_clause {
305 wheres.extend(where_clause.predicates.iter().cloned());
306 }
307 let where_clause = if wheres.is_empty() {
308 quote! {}
309 } else {
310 quote! { where #(#wheres,)*}
311 };
312 quote! {
313 #[automatically_derived]
314 impl #impl_g #trait_path for #ty #ty_g #where_clause {
315 #contents
316 }
317 }
318}
319pub fn impl_trait_result(
320 input: &DeriveInput,
321 trait_path: &Path,
322 wheres: &[WherePredicate],
323 contents: TokenStream,
324 dump: bool,
325) -> Result<TokenStream> {
326 let ts = impl_trait(input, trait_path, wheres, contents);
327 if dump {
328 panic!("macro result: \n{}", ts);
329 }
330 Ok(ts)
331}
332
333pub fn to_valid_ident(s: &str) -> Result<Ident> {
334 if let Ok(ident) = parse_str(s) {
335 Ok(ident)
336 } else {
337 parse_str(&format!("r#{}", s))
338 }
339}
340
341pub fn parse_from_attrs<T: Parse + Default>(attrs: &[Attribute], name: &str) -> Result<T> {
342 let mut a = None;
343 for attr in attrs {
344 if attr.path.is_ident(name) {
345 if a.is_some() {
346 bail!(attr.span(), "attribute `{}` can specified only once", name);
347 }
348 a = Some(attr);
349 }
350 }
351 if let Some(a) = a {
352 a.parse_args()
353 } else {
354 Ok(T::default())
355 }
356}