snix_serde/
de.rs

1//! Deserialisation from Nix to Rust values.
2
3use bstr::ByteSlice;
4use serde::de::value::{MapDeserializer, SeqDeserializer};
5use serde::de::{self, EnumAccess, VariantAccess};
6use snix_eval::{EvalIO, EvalMode, EvaluationBuilder, Value};
7
8use crate::error::Error;
9
10struct NixDeserializer {
11    value: snix_eval::Value,
12}
13
14impl NixDeserializer {
15    fn new(value: Value) -> Self {
16        if let Value::Thunk(thunk) = value {
17            Self::new(thunk.value().clone())
18        } else {
19            Self { value }
20        }
21    }
22}
23
24impl de::IntoDeserializer<'_, Error> for NixDeserializer {
25    type Deserializer = Self;
26
27    fn into_deserializer(self) -> Self::Deserializer {
28        self
29    }
30}
31
32/// Evaluate the Nix code in `src` and attempt to deserialise the
33/// value it returns to `T`.
34pub fn from_str<'code, T>(src: &'code str) -> Result<T, Error>
35where
36    T: serde::Deserialize<'code>,
37{
38    from_str_with_config(src, |b| /* no extra config */ b)
39}
40
41/// Evaluate the Nix code in `src`, with extra configuration for the
42/// `snix_eval::Evaluation` provided by the given closure.
43pub fn from_str_with_config<'code, T, F>(src: &'code str, config: F) -> Result<T, Error>
44where
45    T: serde::Deserialize<'code>,
46    F: for<'co, 'ro, 'env> FnOnce(
47        EvaluationBuilder<'co, 'ro, 'env, Box<dyn EvalIO>>,
48    ) -> EvaluationBuilder<'co, 'ro, 'env, Box<dyn EvalIO>>,
49{
50    // First step is to evaluate the Nix code ...
51    let eval = config(EvaluationBuilder::new_pure().mode(EvalMode::Strict)).build();
52    let result = eval.evaluate(src, None);
53
54    if !result.errors.is_empty() {
55        return Err(Error::NixErrors {
56            errors: result.errors,
57        });
58    }
59
60    let de = NixDeserializer::new(result.value.expect("value should be present on success"));
61
62    T::deserialize(de)
63}
64
65fn unexpected(expected: &'static str, got: &Value) -> Error {
66    Error::UnexpectedType {
67        expected,
68        got: got.type_of(),
69    }
70}
71
72fn visit_integer<I: TryFrom<i64>>(v: &Value) -> Result<I, Error> {
73    match v {
74        Value::Integer(i) => I::try_from(*i).map_err(|_| Error::IntegerConversion {
75            got: *i,
76            need: std::any::type_name::<I>(),
77        }),
78
79        _ => Err(unexpected("integer", v)),
80    }
81}
82
83impl<'de> de::Deserializer<'de> for NixDeserializer {
84    type Error = Error;
85
86    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: de::Visitor<'de>,
89    {
90        match self.value {
91            Value::Null => visitor.visit_unit(),
92            Value::Bool(b) => visitor.visit_bool(b),
93            Value::Integer(i) => visitor.visit_i64(i),
94            Value::Float(f) => visitor.visit_f64(f),
95            Value::String(s) => visitor.visit_string(s.to_string()),
96            Value::Path(p) => visitor.visit_string(p.to_string_lossy().into()), // TODO: hmm
97            Value::Attrs(_) => self.deserialize_map(visitor),
98            Value::List(_) => self.deserialize_seq(visitor),
99
100            // snix-eval types that can not be deserialized through serde.
101            Value::Closure(_)
102            | Value::Builtin(_)
103            | Value::Thunk(_)
104            | Value::AttrNotFound
105            | Value::Blueprint(_)
106            | Value::DeferredUpvalue(_)
107            | Value::UnresolvedPath(_)
108            | Value::Catchable(_)
109            | Value::FinaliseRequest(_) => Err(Error::Unserializable {
110                value_type: self.value.type_of(),
111            }),
112        }
113    }
114
115    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
116    where
117        V: de::Visitor<'de>,
118    {
119        match self.value {
120            Value::Bool(b) => visitor.visit_bool(b),
121            _ => Err(unexpected("bool", &self.value)),
122        }
123    }
124
125    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126    where
127        V: de::Visitor<'de>,
128    {
129        visitor.visit_i8(visit_integer(&self.value)?)
130    }
131
132    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
133    where
134        V: de::Visitor<'de>,
135    {
136        visitor.visit_i16(visit_integer(&self.value)?)
137    }
138
139    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
140    where
141        V: de::Visitor<'de>,
142    {
143        visitor.visit_i32(visit_integer(&self.value)?)
144    }
145
146    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
147    where
148        V: de::Visitor<'de>,
149    {
150        visitor.visit_i64(visit_integer(&self.value)?)
151    }
152
153    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
154    where
155        V: de::Visitor<'de>,
156    {
157        visitor.visit_u8(visit_integer(&self.value)?)
158    }
159
160    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
161    where
162        V: de::Visitor<'de>,
163    {
164        visitor.visit_u16(visit_integer(&self.value)?)
165    }
166
167    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
168    where
169        V: de::Visitor<'de>,
170    {
171        visitor.visit_u32(visit_integer(&self.value)?)
172    }
173
174    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
175    where
176        V: de::Visitor<'de>,
177    {
178        visitor.visit_u64(visit_integer(&self.value)?)
179    }
180
181    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
182    where
183        V: de::Visitor<'de>,
184    {
185        if let Value::Float(f) = self.value {
186            return visitor.visit_f32(f as f32);
187        }
188
189        Err(unexpected("float", &self.value))
190    }
191
192    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
193    where
194        V: de::Visitor<'de>,
195    {
196        if let Value::Float(f) = self.value {
197            return visitor.visit_f64(f);
198        }
199
200        Err(unexpected("float", &self.value))
201    }
202
203    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
204    where
205        V: de::Visitor<'de>,
206    {
207        if let Value::String(s) = &self.value {
208            let chars = s.chars().collect::<Vec<_>>();
209            if chars.len() == 1 {
210                return visitor.visit_char(chars[0]);
211            }
212        }
213
214        Err(unexpected("char", &self.value))
215    }
216
217    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
218    where
219        V: de::Visitor<'de>,
220    {
221        if let Value::String(s) = &self.value {
222            if let Ok(s) = s.to_str() {
223                return visitor.visit_str(s);
224            }
225        }
226
227        Err(unexpected("string", &self.value))
228    }
229
230    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: de::Visitor<'de>,
233    {
234        if let Value::String(s) = &self.value {
235            if let Ok(s) = s.to_str() {
236                return visitor.visit_str(s);
237            }
238        }
239
240        Err(unexpected("string", &self.value))
241    }
242
243    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
244    where
245        V: de::Visitor<'de>,
246    {
247        unimplemented!()
248    }
249
250    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
251    where
252        V: de::Visitor<'de>,
253    {
254        unimplemented!()
255    }
256
257    // Note that this can not distinguish between a serialisation of
258    // `Some(())` and `None`.
259    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
260    where
261        V: de::Visitor<'de>,
262    {
263        if let Value::Null = self.value {
264            visitor.visit_none()
265        } else {
266            visitor.visit_some(self)
267        }
268    }
269
270    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
271    where
272        V: de::Visitor<'de>,
273    {
274        if let Value::Null = self.value {
275            return visitor.visit_unit();
276        }
277
278        Err(unexpected("null", &self.value))
279    }
280
281    fn deserialize_unit_struct<V>(
282        self,
283        _name: &'static str,
284        visitor: V,
285    ) -> Result<V::Value, Self::Error>
286    where
287        V: de::Visitor<'de>,
288    {
289        self.deserialize_unit(visitor)
290    }
291
292    fn deserialize_newtype_struct<V>(
293        self,
294        _name: &'static str,
295        visitor: V,
296    ) -> Result<V::Value, Self::Error>
297    where
298        V: de::Visitor<'de>,
299    {
300        visitor.visit_newtype_struct(self)
301    }
302
303    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
304    where
305        V: de::Visitor<'de>,
306    {
307        if let Value::List(list) = self.value {
308            let mut seq = SeqDeserializer::new(list.into_iter().map(NixDeserializer::new));
309            let result = visitor.visit_seq(&mut seq)?;
310            seq.end()?;
311            return Ok(result);
312        }
313
314        Err(unexpected("list", &self.value))
315    }
316
317    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
318    where
319        V: de::Visitor<'de>,
320    {
321        // just represent tuples as lists ...
322        self.deserialize_seq(visitor)
323    }
324
325    fn deserialize_tuple_struct<V>(
326        self,
327        _name: &'static str,
328        _len: usize,
329        visitor: V,
330    ) -> Result<V::Value, Self::Error>
331    where
332        V: de::Visitor<'de>,
333    {
334        // same as above
335        self.deserialize_seq(visitor)
336    }
337
338    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
339    where
340        V: de::Visitor<'de>,
341    {
342        if let Value::Attrs(attrs) = self.value {
343            let mut map = MapDeserializer::new(attrs.into_iter().map(|(k, v)| {
344                (
345                    NixDeserializer::new(Value::from(k)),
346                    NixDeserializer::new(v),
347                )
348            }));
349            let result = visitor.visit_map(&mut map)?;
350            map.end()?;
351            return Ok(result);
352        }
353
354        Err(unexpected("map", &self.value))
355    }
356
357    fn deserialize_struct<V>(
358        self,
359        _name: &'static str,
360        _fields: &'static [&'static str],
361        visitor: V,
362    ) -> Result<V::Value, Self::Error>
363    where
364        V: de::Visitor<'de>,
365    {
366        self.deserialize_map(visitor)
367    }
368
369    // This method is responsible for deserializing the externally
370    // tagged enum variant serialisation.
371    fn deserialize_enum<V>(
372        self,
373        name: &'static str,
374        _variants: &'static [&'static str],
375        visitor: V,
376    ) -> Result<V::Value, Self::Error>
377    where
378        V: de::Visitor<'de>,
379    {
380        match self.value {
381            // a string represents a unit variant
382            Value::String(ref s) => {
383                if let Ok(s) = s.to_str() {
384                    visitor.visit_enum(de::value::StrDeserializer::new(s))
385                } else {
386                    Err(unexpected(name, &self.value))
387                }
388            }
389
390            // an attribute set however represents an externally
391            // tagged enum with content
392            Value::Attrs(attrs) => visitor.visit_enum(Enum(*attrs)),
393
394            _ => Err(unexpected(name, &self.value)),
395        }
396    }
397
398    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
399    where
400        V: de::Visitor<'de>,
401    {
402        self.deserialize_str(visitor)
403    }
404
405    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
406    where
407        V: de::Visitor<'de>,
408    {
409        visitor.visit_unit()
410    }
411}
412
413struct Enum(snix_eval::NixAttrs);
414
415impl<'de> EnumAccess<'de> for Enum {
416    type Error = Error;
417    type Variant = NixDeserializer;
418
419    // TODO: pass the known variants down here and check against them
420    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
421    where
422        V: de::DeserializeSeed<'de>,
423    {
424        if self.0.len() != 1 {
425            return Err(Error::AmbiguousEnum);
426        }
427
428        let (key, value) = self.0.into_iter().next().expect("length asserted above");
429        if let Ok(k) = key.to_str() {
430            let val = seed.deserialize(de::value::StrDeserializer::<Error>::new(k))?;
431            Ok((val, NixDeserializer::new(value)))
432        } else {
433            Err(unexpected("string", &key.clone().into()))
434        }
435    }
436}
437
438impl<'de> VariantAccess<'de> for NixDeserializer {
439    type Error = Error;
440
441    fn unit_variant(self) -> Result<(), Self::Error> {
442        // If this case is hit, a user specified the name of a unit
443        // enum variant but gave it content. Unit enum deserialisation
444        // is handled in `deserialize_enum` above.
445        Err(Error::UnitEnumContent)
446    }
447
448    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
449    where
450        T: de::DeserializeSeed<'de>,
451    {
452        seed.deserialize(self)
453    }
454
455    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
456    where
457        V: de::Visitor<'de>,
458    {
459        de::Deserializer::deserialize_seq(self, visitor)
460    }
461
462    fn struct_variant<V>(
463        self,
464        _fields: &'static [&'static str],
465        visitor: V,
466    ) -> Result<V::Value, Self::Error>
467    where
468        V: de::Visitor<'de>,
469    {
470        de::Deserializer::deserialize_map(self, visitor)
471    }
472}