nix_compat_derive/
de.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::spanned::Spanned;
4use syn::{DeriveInput, Generics, Path, Type};
5
6use crate::internal::attrs::Default;
7use crate::internal::inputs::RemoteInput;
8use crate::internal::{attrs, Container, Context, Data, Field, Remote, Style, Variant};
9
10pub fn expand_nix_deserialize(
11    crate_path: Path,
12    input: &mut DeriveInput,
13) -> syn::Result<TokenStream> {
14    let cx = Context::new();
15    let cont = Container::from_ast(&cx, crate_path, input);
16    cx.check()?;
17    let cont = cont.unwrap();
18
19    let ty = cont.ident_type();
20    let body = nix_deserialize_body(&cont);
21    let crate_path = cont.crate_path();
22
23    Ok(nix_deserialize_impl(
24        crate_path,
25        &ty,
26        &cont.original.generics,
27        body,
28    ))
29}
30
31pub fn expand_nix_deserialize_remote(
32    crate_path: Path,
33    input: &RemoteInput,
34) -> syn::Result<TokenStream> {
35    let cx = Context::new();
36    let remote = Remote::from_ast(&cx, crate_path, input);
37    if let Some(attrs) = remote.as_ref().map(|r| &r.attrs) {
38        if attrs.from_str.is_none() && attrs.type_from.is_none() && attrs.type_try_from.is_none() {
39            cx.error_spanned(input, "Missing from_str, from or try_from attribute");
40        }
41    }
42    cx.check()?;
43    let remote = remote.unwrap();
44
45    let crate_path = remote.crate_path();
46    let body = nix_deserialize_body_from(crate_path, &remote.attrs).expect("From tokenstream");
47    let generics = Generics::default();
48    Ok(nix_deserialize_impl(crate_path, remote.ty, &generics, body))
49}
50
51fn nix_deserialize_impl(
52    crate_path: &Path,
53    ty: &Type,
54    generics: &Generics,
55    body: TokenStream,
56) -> TokenStream {
57    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
58
59    quote! {
60        #[automatically_derived]
61        impl #impl_generics #crate_path::wire::de::NixDeserialize for #ty #ty_generics
62            #where_clause
63        {
64            #[allow(clippy::manual_async_fn)]
65            fn try_deserialize<R>(reader: &mut R) -> impl ::std::future::Future<Output=Result<Option<Self>, R::Error>> + Send + '_
66                where R: ?Sized + #crate_path::wire::de::NixRead + Send,
67            {
68                #body
69            }
70        }
71    }
72}
73
74fn nix_deserialize_body_from(
75    crate_path: &syn::Path,
76    attrs: &attrs::Container,
77) -> Option<TokenStream> {
78    if let Some(span) = attrs.from_str.as_ref() {
79        Some(nix_deserialize_from_str(crate_path, span.span()))
80    } else if let Some(type_from) = attrs.type_from.as_ref() {
81        Some(nix_deserialize_from(type_from))
82    } else {
83        attrs
84            .type_try_from
85            .as_ref()
86            .map(|type_try_from| nix_deserialize_try_from(crate_path, type_try_from))
87    }
88}
89
90fn nix_deserialize_body(cont: &Container) -> TokenStream {
91    if let Some(tokens) = nix_deserialize_body_from(cont.crate_path(), &cont.attrs) {
92        tokens
93    } else {
94        match &cont.data {
95            Data::Struct(style, fields) => nix_deserialize_struct(*style, fields),
96            Data::Enum(variants) => nix_deserialize_enum(variants),
97        }
98    }
99}
100
101fn nix_deserialize_struct(style: Style, fields: &[Field<'_>]) -> TokenStream {
102    let read_fields = fields.iter().map(|f| {
103        let field = f.var_ident();
104        let ty = f.ty;
105        let read_value = quote_spanned! {
106            ty.span()=> if first__ {
107                first__ = false;
108                if let Some(v) = reader.try_read_value::<#ty>().await? {
109                    v
110                } else {
111                    return Ok(None);
112                }
113            } else {
114                reader.read_value::<#ty>().await?
115            }
116        };
117        if let Some(version) = f.attrs.version.as_ref() {
118            let default = match &f.attrs.default {
119                Default::Default(span) => {
120                    quote_spanned!(span.span()=>::std::default::Default::default)
121                }
122                Default::Path(path) => path.to_token_stream(),
123                _ => panic!("No default for versioned field"),
124            };
125            quote! {
126                let #field : #ty = if (#version).contains(&reader.version().minor()) {
127                    #read_value
128                } else {
129                    #default()
130                };
131            }
132        } else {
133            quote! {
134                let #field : #ty = #read_value;
135            }
136        }
137    });
138
139    let field_names = fields.iter().map(|f| f.var_ident());
140    let construct = match style {
141        Style::Struct => {
142            quote! {
143                Self { #(#field_names),* }
144            }
145        }
146        Style::Tuple => {
147            quote! {
148                Self(#(#field_names),*)
149            }
150        }
151        Style::Unit => quote!(Self),
152    };
153    quote! {
154        #[allow(unused_assignments)]
155        async move {
156            let mut first__ = true;
157            #(#read_fields)*
158            Ok(Some(#construct))
159        }
160    }
161}
162
163fn nix_deserialize_variant(variant: &Variant<'_>) -> TokenStream {
164    let ident = variant.ident;
165    let read_fields = variant.fields.iter().map(|f| {
166        let field = f.var_ident();
167        let ty = f.ty;
168        let read_value = quote_spanned! {
169            ty.span()=> if first__ {
170                first__ = false;
171                if let Some(v) = reader.try_read_value::<#ty>().await? {
172                    v
173                } else {
174                    return Ok(None);
175                }
176            } else {
177                reader.read_value::<#ty>().await?
178            }
179        };
180        if let Some(version) = f.attrs.version.as_ref() {
181            let default = match &f.attrs.default {
182                Default::Default(span) => {
183                    quote_spanned!(span.span()=>::std::default::Default::default)
184                }
185                Default::Path(path) => path.to_token_stream(),
186                _ => panic!("No default for versioned field"),
187            };
188            quote! {
189                let #field : #ty = if (#version).contains(&reader.version().minor()) {
190                    #read_value
191                } else {
192                    #default()
193                };
194            }
195        } else {
196            quote! {
197                let #field : #ty = #read_value;
198            }
199        }
200    });
201    let field_names = variant.fields.iter().map(|f| f.var_ident());
202    let construct = match variant.style {
203        Style::Struct => {
204            quote! {
205                Self::#ident { #(#field_names),* }
206            }
207        }
208        Style::Tuple => {
209            quote! {
210                Self::#ident(#(#field_names),*)
211            }
212        }
213        Style::Unit => quote!(Self::#ident),
214    };
215    let version = &variant.attrs.version;
216    quote! {
217        #version => {
218            #(#read_fields)*
219            Ok(Some(#construct))
220        }
221    }
222}
223
224fn nix_deserialize_enum(variants: &[Variant<'_>]) -> TokenStream {
225    let match_variant = variants
226        .iter()
227        .map(|variant| nix_deserialize_variant(variant));
228    quote! {
229        #[allow(unused_assignments)]
230        async move {
231            let mut first__ = true;
232            match reader.version().minor() {
233                #(#match_variant)*
234            }
235        }
236    }
237}
238
239fn nix_deserialize_from(ty: &Type) -> TokenStream {
240    quote_spanned! {
241        ty.span() =>
242        async move {
243            if let Some(value) = reader.try_read_value::<#ty>().await? {
244                Ok(Some(<Self as ::std::convert::From<#ty>>::from(value)))
245            } else {
246                Ok(None)
247            }
248        }
249    }
250}
251
252fn nix_deserialize_try_from(crate_path: &Path, ty: &Type) -> TokenStream {
253    quote_spanned! {
254        ty.span() =>
255        async move {
256            use #crate_path::wire::de::Error;
257            if let Some(item) = reader.try_read_value::<#ty>().await? {
258                <Self as ::std::convert::TryFrom<#ty>>::try_from(item)
259                    .map_err(Error::invalid_data)
260                    .map(Some)
261            } else {
262                Ok(None)
263            }
264        }
265    }
266}
267
268fn nix_deserialize_from_str(crate_path: &Path, span: Span) -> TokenStream {
269    quote_spanned! {
270        span =>
271        async move {
272            use #crate_path::wire::de::Error;
273            if let Some(buf) = reader.try_read_bytes().await? {
274                let s = ::std::str::from_utf8(&buf)
275                    .map_err(Error::invalid_data)?;
276                <Self as ::std::str::FromStr>::from_str(s)
277                    .map_err(Error::invalid_data)
278                    .map(Some)
279            } else {
280                Ok(None)
281            }
282        }
283    }
284}