nix_compat/nix_daemon/framing/
framed_read.rs

1use std::{
2    num::NonZeroU64,
3    pin::Pin,
4    task::{self, Poll, ready},
5};
6
7use pin_project_lite::pin_project;
8use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
9
/// Position of [`NixFramedReader`] within the framing protocol.
///
/// The stream is a sequence of length-prefixed chunks, ended by a zero-sized chunk that signals
/// EOF. Apart from that terminating chunk, chunk boundaries carry no meaning.
/// Lengths are encoded on the wire as 64-bit little-endian integers.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Accumulating the 8-byte length header; the first `filled` bytes of `buf` are valid.
    Length {
        buf: [u8; 8],
        filled: u8,
    },
    /// Forwarding payload; `remaining` bytes of the current chunk have not been read yet.
    Chunk {
        remaining: NonZeroU64,
    },
    /// The zero-sized terminating chunk has been seen.
    Eof,
}
21
pin_project! {
    /// Implements Nix's [Framed] reader protocol for protocol versions >= 1.23.
    ///
    /// Only frame payload bytes are handed to the caller: the 64-bit little-endian
    /// length headers are consumed internally, and a zero-sized frame ends the stream.
    ///
    /// Unexpected EOF on the underlying reader is returned as [UnexpectedEof][`std::io::ErrorKind::UnexpectedEof`].
    /// True EOF (end-of-stream) is fused.
    ///
    /// [Framed]: https://snix.dev/docs/reference/nix-daemon-protocol/types/#framed
    pub struct NixFramedReader<R> {
        // The underlying byte stream. It is structurally pinned, so the
        // `AsyncRead` impl does not need `R: Unpin`.
        #[pin]
        reader: R,
        // Current position in the framing state machine (see `State`).
        state: State,
    }
}
35
36impl<R> NixFramedReader<R> {
37    pub fn new(reader: R) -> Self {
38        Self {
39            reader,
40            state: State::Length {
41                buf: [0; 8],
42                filled: 0,
43            },
44        }
45    }
46}
47
48impl<R: AsyncRead + Unpin> NixFramedReader<R> {
49    /// Returns `true` if the Nix Framed reader has reached EOF.
50    pub async fn is_eof_unpin(&mut self) -> io::Result<bool> {
51        Pin::new(self).is_eof().await
52    }
53}
54
55impl<R: AsyncRead> NixFramedReader<R> {
56    /// Returns `true` if the Nix Framed reader has reached EOF.
57    pub async fn is_eof(self: Pin<&mut Self>) -> io::Result<bool> {
58        let mut this = self.project();
59        // we have have to ensure that we aren't just in [`State::Length`]
60        // with a pending terminating frame, since the NAR reader will not
61        // ever observe the EOF itself
62        loop {
63            match this.state {
64                State::Length { buf, filled: 8 } => {
65                    *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
66                        None => State::Eof,
67                        Some(remaining) => State::Chunk { remaining },
68                    };
69                }
70                State::Length { buf, filled } => {
71                    let bytes_read = this.reader.read(&mut buf[*filled as usize..]).await? as u8;
72
73                    if bytes_read == 0 {
74                        return Err(io::ErrorKind::UnexpectedEof.into());
75                    }
76
77                    *filled += bytes_read;
78                }
79                State::Chunk { .. } => {
80                    return Ok(false);
81                }
82                State::Eof => {
83                    return Ok(true);
84                }
85            }
86        }
87    }
88}
89
90impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
91    fn poll_read(
92        mut self: Pin<&mut Self>,
93        cx: &mut task::Context<'_>,
94        buf: &mut ReadBuf<'_>,
95    ) -> Poll<io::Result<()>> {
96        let mut this = self.as_mut().project();
97
98        // reading nothing always succeeds
99        if buf.remaining() == 0 {
100            return Ok(()).into();
101        }
102
103        loop {
104            let reader = this.reader.as_mut();
105            match this.state {
106                State::Eof => {
107                    return Ok(()).into();
108                }
109                State::Length { buf, filled: 8 } => {
110                    *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
111                        None => State::Eof,
112                        Some(remaining) => State::Chunk { remaining },
113                    };
114                }
115                State::Length { buf, filled } => {
116                    let bytes_read = {
117                        let mut b = ReadBuf::new(&mut buf[*filled as usize..]);
118                        ready!(reader.poll_read(cx, &mut b))?;
119                        b.filled().len() as u8
120                    };
121
122                    if bytes_read == 0 {
123                        return Err(io::ErrorKind::UnexpectedEof.into()).into();
124                    }
125
126                    *filled += bytes_read;
127                }
128                State::Chunk { remaining } => {
129                    let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| {
130                        reader.poll_read(cx, buf).map_ok(|()| buf.filled().len())
131                    }))?;
132
133                    *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) {
134                        None => State::Length {
135                            buf: [0; 8],
136                            filled: 0,
137                        },
138                        Some(remaining) => State::Chunk { remaining },
139                    };
140
141                    return if bytes_read == 0 {
142                        Err(io::ErrorKind::UnexpectedEof.into())
143                    } else {
144                        Ok(())
145                    }
146                    .into();
147                }
148            }
149        }
150    }
151}
152
153/// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it.
154/// After `f` returns, we propagate the filled cursor advancement back to `buf`.
155// TODO(edef): duplicate of src/wire/bytes/reader/mod.rs:with_limited
156fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
157    let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
158    let ptr = nbuf.initialized().as_ptr();
159    let ret = f(&mut nbuf);
160
161    // SAFETY: `ReadBuf::take` only returns the *unfilled* section of `buf`,
162    // so anything filled is new, initialized data.
163    //
164    // We verify that `nbuf` still points to the same buffer,
165    // so we're sure it hasn't been swapped out.
166    unsafe {
167        // ensure our buffer hasn't been swapped out
168        assert_eq!(nbuf.initialized().as_ptr(), ptr);
169
170        let n = nbuf.filled().len();
171        buf.assume_init(n);
172        buf.advance(n);
173    }
174
175    ret
176}
177
#[cfg(test)]
mod nix_framed_tests {
    use std::{
        cmp::min,
        pin::Pin,
        task::{self, Poll},
        time::Duration,
    };

    use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    #[tokio::test]
    async fn read_unexpected_eof_after_frame() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            // NOTE: no terminating zero
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let err = reader.is_eof_unpin().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_unexpected_eof_in_frame() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" worl".as_bytes())
            // NOTE: we only sent five bytes of data before EOF
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let is_eof = reader.is_eof_unpin().await.map_err(|e| e.kind());
        assert!(matches!(
            is_eof,
            Ok(false) | Err(io::ErrorKind::UnexpectedEof)
        ));
    }

    #[tokio::test]
    async fn read_unexpected_eof_in_length() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send a truncated length header
            .read(&[0; 7])
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let err = reader.is_eof_unpin().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
        assert!(reader.is_eof_unpin().await.unwrap());
    }

    #[tokio::test]
    async fn read_only_terminating_frame() {
        let mut mock = Builder::new()
            // The client immediately sends the terminating zero-length frame
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        // `is_eof` must consume the pending terminating header on its own.
        assert!(reader.is_eof_unpin().await.unwrap());
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("", result);
    }

    struct SplitMock<'a> {
        data: &'a [u8],
        pending: bool,
    }

    impl<'a> SplitMock<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                pending: false,
            }
        }
    }

    impl AsyncRead for SplitMock<'_> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            if self.data.is_empty() {
                self.pending = true;
                Poll::Pending
            } else {
                let n = min(buf.remaining(), self.data.len());
                buf.put_slice(&self.data[..n]);
                self.data = &self.data[n..];

                Poll::Ready(Ok(()))
            }
        }
    }

    /// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input,
    /// independent of how it is spread across read calls and poll cycles.
    #[test]
    fn split_verif() {
        let mut cx = task::Context::from_waker(task::Waker::noop());
        let mut input = make_framed(&[b"hello", b"world", b"!", b""]);
        let framed_end = input.len();
        input.extend_from_slice(b"trailing data");

        for end_point in 0..input.len() {
            let input = &input[..end_point];

            let unsplit_res = {
                let mut dut = NixFramedReader::new(SplitMock::new(input));
                let mut data_buf = vec![0; input.len()];
                let mut read_buf = ReadBuf::new(&mut data_buf);

                for _ in 0..256 {
                    match Pin::new(&mut dut).poll_read(&mut cx, &mut read_buf) {
                        Poll::Ready(res) => res.unwrap(),
                        Poll::Pending => {
                            assert!(dut.reader.pending);
                            break;
                        }
                    }
                }

                let len = read_buf.filled().len();
                data_buf.truncate(len);

                assert_eq!(
                    end_point >= framed_end,
                    matches!(dut.state, super::State::Eof),
                    "end_point = {end_point}, state = {:?}",
                    dut.state
                );
                (dut.state, data_buf, dut.reader.data)
            };

            // Every interior split point, so that both halves are non-empty.
            // (`1..0` is simply empty, so `end_point == 0` needs no special casing.)
            for split_point in 1..end_point {
                let split_res = {
                    let mut dut = NixFramedReader::new(SplitMock::new(&[]));
                    let mut data_buf = vec![0; input.len()];
                    let mut read_buf = ReadBuf::new(&mut data_buf);

                    dut.reader = SplitMock::new(&input[..split_point]);
                    for _ in 0..256 {
                        match Pin::new(&mut dut).poll_read(&mut cx, &mut read_buf) {
                            Poll::Ready(res) => res.unwrap(),
                            Poll::Pending => {
                                assert!(dut.reader.pending);
                                break;
                            }
                        }
                    }

                    // Resume from the first byte the first mock did not hand out.
                    dut.reader = SplitMock::new(&input[split_point - dut.reader.data.len()..]);
                    for _ in 0..256 {
                        match Pin::new(&mut dut).poll_read(&mut cx, &mut read_buf) {
                            Poll::Ready(res) => res.unwrap(),
                            Poll::Pending => {
                                assert!(dut.reader.pending);
                                break;
                            }
                        }
                    }

                    let len = read_buf.filled().len();
                    data_buf.truncate(len);

                    (dut.state, data_buf, dut.reader.data)
                };

                assert_eq!(split_res, unsplit_res);
            }
        }
    }

    /// Make framed data, given frame contents. Terminating frame is *not* implicitly included.
    /// Include an empty slice explicitly.
    fn make_framed(frames: &[&[u8]]) -> Vec<u8> {
        let mut buf = vec![];

        for &data in frames {
            buf.extend_from_slice(&(data.len() as u64).to_le_bytes());
            buf.extend_from_slice(data);
        }

        buf
    }
}