nix_compat_derive/
de.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{ToTokens, quote, quote_spanned};
3use syn::spanned::Spanned;
4use syn::{DeriveInput, Generics, Path, Type};
5
6use crate::internal::attrs::Default;
7use crate::internal::inputs::RemoteInput;
8use crate::internal::{Container, Context, Data, Field, Remote, Style, Variant, attrs};
9
10pub fn expand_nix_deserialize(
11    crate_path: Path,
12    input: &mut DeriveInput,
13) -> syn::Result<TokenStream> {
14    let cx = Context::new();
15    let cont = Container::from_ast(&cx, crate_path, input);
16    cx.check()?;
17    let cont = cont.unwrap();
18
19    let ty = cont.ident_type();
20    let body = nix_deserialize_body(&cont);
21    let crate_path = cont.crate_path();
22
23    Ok(nix_deserialize_impl(
24        crate_path,
25        &ty,
26        &cont.original.generics,
27        body,
28    ))
29}
30
31pub fn expand_nix_deserialize_remote(
32    crate_path: Path,
33    input: &RemoteInput,
34) -> syn::Result<TokenStream> {
35    let cx = Context::new();
36    let remote = Remote::from_ast(&cx, crate_path, input);
37    if let Some(attrs) = remote.as_ref().map(|r| &r.attrs)
38        && attrs.from_str.is_none()
39        && attrs.type_from.is_none()
40        && attrs.type_try_from.is_none()
41    {
42        cx.error_spanned(input, "Missing from_str, from or try_from attribute");
43    }
44    cx.check()?;
45    let remote = remote.unwrap();
46
47    let crate_path = remote.crate_path();
48    let body = nix_deserialize_body_from(crate_path, &remote.attrs).expect("From tokenstream");
49    let generics = Generics::default();
50    Ok(nix_deserialize_impl(crate_path, remote.ty, &generics, body))
51}
52
53fn nix_deserialize_impl(
54    crate_path: &Path,
55    ty: &Type,
56    generics: &Generics,
57    body: TokenStream,
58) -> TokenStream {
59    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
60
61    quote! {
62        #[automatically_derived]
63        impl #impl_generics #crate_path::wire::de::NixDeserialize for #ty #ty_generics
64            #where_clause
65        {
66            #[allow(clippy::manual_async_fn)]
67            fn try_deserialize<R>(reader: &mut R) -> impl ::std::future::Future<Output=Result<Option<Self>, R::Error>> + Send + '_
68                where R: ?Sized + #crate_path::wire::de::NixRead + Send,
69            {
70                #body
71            }
72        }
73    }
74}
75
76fn nix_deserialize_body_from(
77    crate_path: &syn::Path,
78    attrs: &attrs::Container,
79) -> Option<TokenStream> {
80    if let Some(span) = attrs.from_str.as_ref() {
81        Some(nix_deserialize_from_str(crate_path, span.span()))
82    } else if let Some(type_from) = attrs.type_from.as_ref() {
83        Some(nix_deserialize_from(type_from))
84    } else {
85        attrs
86            .type_try_from
87            .as_ref()
88            .map(|type_try_from| nix_deserialize_try_from(crate_path, type_try_from))
89    }
90}
91
92fn nix_deserialize_body(cont: &Container) -> TokenStream {
93    if let Some(tokens) = nix_deserialize_body_from(cont.crate_path(), &cont.attrs) {
94        tokens
95    } else {
96        match &cont.data {
97            Data::Struct(style, fields) => nix_deserialize_struct(*style, fields),
98            Data::Enum(variants) => nix_deserialize_enum(variants),
99        }
100    }
101}
102
103fn nix_deserialize_struct(style: Style, fields: &[Field<'_>]) -> TokenStream {
104    let read_fields = fields.iter().map(|f| {
105        let field = f.var_ident();
106        let ty = f.ty;
107        let read_value = quote_spanned! {
108            ty.span()=> if first__ {
109                first__ = false;
110                if let Some(v) = reader.try_read_value::<#ty>().await? {
111                    v
112                } else {
113                    return Ok(None);
114                }
115            } else {
116                reader.read_value::<#ty>().await?
117            }
118        };
119        if let Some(version) = f.attrs.version.as_ref() {
120            let default = match &f.attrs.default {
121                Default::Default(span) => {
122                    quote_spanned!(span.span()=>::std::default::Default::default)
123                }
124                Default::Path(path) => path.to_token_stream(),
125                _ => panic!("No default for versioned field"),
126            };
127            quote! {
128                let #field : #ty = if (#version).contains(&reader.version().minor()) {
129                    #read_value
130                } else {
131                    #default()
132                };
133            }
134        } else {
135            quote! {
136                let #field : #ty = #read_value;
137            }
138        }
139    });
140
141    let field_names = fields.iter().map(|f| f.var_ident());
142    let construct = match style {
143        Style::Struct => {
144            quote! {
145                Self { #(#field_names),* }
146            }
147        }
148        Style::Tuple => {
149            quote! {
150                Self(#(#field_names),*)
151            }
152        }
153        Style::Unit => quote!(Self),
154    };
155    quote! {
156        #[allow(unused_assignments)]
157        async move {
158            let mut first__ = true;
159            #(#read_fields)*
160            Ok(Some(#construct))
161        }
162    }
163}
164
165fn nix_deserialize_variant(variant: &Variant<'_>) -> TokenStream {
166    let ident = variant.ident;
167    let read_fields = variant.fields.iter().map(|f| {
168        let field = f.var_ident();
169        let ty = f.ty;
170        let read_value = quote_spanned! {
171            ty.span()=> if first__ {
172                first__ = false;
173                if let Some(v) = reader.try_read_value::<#ty>().await? {
174                    v
175                } else {
176                    return Ok(None);
177                }
178            } else {
179                reader.read_value::<#ty>().await?
180            }
181        };
182        if let Some(version) = f.attrs.version.as_ref() {
183            let default = match &f.attrs.default {
184                Default::Default(span) => {
185                    quote_spanned!(span.span()=>::std::default::Default::default)
186                }
187                Default::Path(path) => path.to_token_stream(),
188                _ => panic!("No default for versioned field"),
189            };
190            quote! {
191                let #field : #ty = if (#version).contains(&reader.version().minor()) {
192                    #read_value
193                } else {
194                    #default()
195                };
196            }
197        } else {
198            quote! {
199                let #field : #ty = #read_value;
200            }
201        }
202    });
203    let field_names = variant.fields.iter().map(|f| f.var_ident());
204    let construct = match variant.style {
205        Style::Struct => {
206            quote! {
207                Self::#ident { #(#field_names),* }
208            }
209        }
210        Style::Tuple => {
211            quote! {
212                Self::#ident(#(#field_names),*)
213            }
214        }
215        Style::Unit => quote!(Self::#ident),
216    };
217    let version = &variant.attrs.version;
218    quote! {
219        #version => {
220            #(#read_fields)*
221            Ok(Some(#construct))
222        }
223    }
224}
225
226fn nix_deserialize_enum(variants: &[Variant<'_>]) -> TokenStream {
227    let match_variant = variants
228        .iter()
229        .map(|variant| nix_deserialize_variant(variant));
230    quote! {
231        #[allow(unused_assignments)]
232        async move {
233            let mut first__ = true;
234            match reader.version().minor() {
235                #(#match_variant)*
236            }
237        }
238    }
239}
240
241fn nix_deserialize_from(ty: &Type) -> TokenStream {
242    quote_spanned! {
243        ty.span() =>
244        async move {
245            if let Some(value) = reader.try_read_value::<#ty>().await? {
246                Ok(Some(<Self as ::std::convert::From<#ty>>::from(value)))
247            } else {
248                Ok(None)
249            }
250        }
251    }
252}
253
254fn nix_deserialize_try_from(crate_path: &Path, ty: &Type) -> TokenStream {
255    quote_spanned! {
256        ty.span() =>
257        async move {
258            use #crate_path::wire::de::Error;
259            if let Some(item) = reader.try_read_value::<#ty>().await? {
260                <Self as ::std::convert::TryFrom<#ty>>::try_from(item)
261                    .map_err(Error::invalid_data)
262                    .map(Some)
263            } else {
264                Ok(None)
265            }
266        }
267    }
268}
269
270fn nix_deserialize_from_str(crate_path: &Path, span: Span) -> TokenStream {
271    quote_spanned! {
272        span =>
273        async move {
274            use #crate_path::wire::de::Error;
275            if let Some(buf) = reader.try_read_bytes().await? {
276                let s = ::std::str::from_utf8(&buf)
277                    .map_err(Error::invalid_data)?;
278                <Self as ::std::str::FromStr>::from_str(s)
279                    .map_err(Error::invalid_data)
280                    .map(Some)
281            } else {
282                Ok(None)
283            }
284        }
285    }
286}