nix_compat/nix_daemon/framing/framed_read.rs

1use std::{
2    num::NonZeroU64,
3    pin::Pin,
4    task::{self, ready, Poll},
5};
6
7use pin_project_lite::pin_project;
8use tokio::io::{self, AsyncRead, ReadBuf};
9
/// Decoding state of a [`NixFramedReader`].
///
/// On the wire, the stream is a sequence of chunks, each consisting of a 64-bit little-endian
/// length followed by that many payload bytes. A chunk of length zero terminates the stream.
/// Apart from that terminator, chunk boundaries carry no meaning: the payloads are simply
/// concatenated.
#[derive(Debug, Eq, PartialEq)]
enum State {
    /// Collecting the 8-byte length header of the next chunk.
    Length {
        /// Header bytes received so far; only the first `filled` entries are meaningful.
        buf: [u8; 8],
        /// How many header bytes have been received (at most 8).
        filled: u8,
    },
    /// Inside the payload of a chunk.
    Chunk {
        /// Payload bytes of the current chunk that have not been handed out yet (never zero).
        remaining: NonZeroU64,
    },
    /// The zero-length terminator has been read; nothing further is consumed.
    Eof,
}
21
22pin_project! {
23    /// Implements Nix's [Framed] reader protocol for protocol versions >= 1.23.
24    ///
25    /// Unexpected EOF on the underlying reader is returned as [UnexpectedEof][`std::io::ErrorKind::UnexpectedEof`].
26    /// True EOF (end-of-stream) is fused.
27    ///
28    /// [Framed]: https://snix.dev/docs/reference/nix-daemon-protocol/types/#framed
29    pub struct NixFramedReader<R> {
30        #[pin]
31        reader: R,
32        state: State,
33    }
34}
35
36impl<R> NixFramedReader<R> {
37    pub fn new(reader: R) -> Self {
38        Self {
39            reader,
40            state: State::Length {
41                buf: [0; 8],
42                filled: 0,
43            },
44        }
45    }
46
47    /// Returns `true` if the Nix Framed reader has reached EOF.
48    #[must_use]
49    pub fn is_eof(&self) -> bool {
50        matches!(self.state, State::Eof)
51    }
52}
53
54impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
55    fn poll_read(
56        mut self: Pin<&mut Self>,
57        cx: &mut task::Context<'_>,
58        buf: &mut ReadBuf<'_>,
59    ) -> Poll<io::Result<()>> {
60        let mut this = self.as_mut().project();
61
62        // reading nothing always succeeds
63        if buf.remaining() == 0 {
64            return Ok(()).into();
65        }
66
67        loop {
68            let reader = this.reader.as_mut();
69            match this.state {
70                State::Eof => {
71                    return Ok(()).into();
72                }
73                State::Length { buf, filled: 8 } => {
74                    *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
75                        None => State::Eof,
76                        Some(remaining) => State::Chunk { remaining },
77                    };
78                }
79                State::Length { buf, filled } => {
80                    let bytes_read = {
81                        let mut b = ReadBuf::new(&mut buf[*filled as usize..]);
82                        ready!(reader.poll_read(cx, &mut b))?;
83                        b.filled().len() as u8
84                    };
85
86                    if bytes_read == 0 {
87                        return Err(io::ErrorKind::UnexpectedEof.into()).into();
88                    }
89
90                    *filled += bytes_read;
91                }
92                State::Chunk { remaining } => {
93                    let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| {
94                        reader.poll_read(cx, buf).map_ok(|()| buf.filled().len())
95                    }))?;
96
97                    *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) {
98                        None => State::Length {
99                            buf: [0; 8],
100                            filled: 0,
101                        },
102                        Some(remaining) => State::Chunk { remaining },
103                    };
104
105                    return if bytes_read == 0 {
106                        Err(io::ErrorKind::UnexpectedEof.into())
107                    } else {
108                        Ok(())
109                    }
110                    .into();
111                }
112            }
113        }
114    }
115}
116
117/// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it.
118/// After `f` returns, we propagate the filled cursor advancement back to `buf`.
119// TODO(edef): duplicate of src/wire/bytes/reader/mod.rs:with_limited
120fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
121    let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
122    let ptr = nbuf.initialized().as_ptr();
123    let ret = f(&mut nbuf);
124
125    // SAFETY: `ReadBuf::take` only returns the *unfilled* section of `buf`,
126    // so anything filled is new, initialized data.
127    //
128    // We verify that `nbuf` still points to the same buffer,
129    // so we're sure it hasn't been swapped out.
130    unsafe {
131        // ensure our buffer hasn't been swapped out
132        assert_eq!(nbuf.initialized().as_ptr(), ptr);
133
134        let n = nbuf.filled().len();
135        buf.assume_init(n);
136        buf.advance(n);
137    }
138
139    ret
140}
141
#[cfg(test)]
mod nix_framed_tests {
    use std::{
        cmp::min,
        pin::Pin,
        task::{self, Poll},
        time::Duration,
    };

    use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    #[tokio::test]
    async fn read_unexpected_eof_after_frame() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            // NOTE: no terminating zero
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert!(!reader.is_eof());
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_unexpected_eof_in_frame() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" worl".as_bytes())
            // NOTE: we only sent five bytes of data before EOF
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert!(!reader.is_eof());
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_unexpected_eof_in_length() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send a truncated length header
            .read(&[0; 7])
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert!(!reader.is_eof());
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
        assert!(reader.is_eof());
    }

    /// Reader that serves `data` synchronously and, once it is exhausted, returns
    /// `Poll::Pending` (recording that it did so in `pending`) instead of signalling EOF.
    struct SplitMock<'a> {
        data: &'a [u8],
        pending: bool,
    }

    impl<'a> SplitMock<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                pending: false,
            }
        }
    }

    impl AsyncRead for SplitMock<'_> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            if self.data.is_empty() {
                self.pending = true;
                Poll::Pending
            } else {
                let n = min(buf.remaining(), self.data.len());
                buf.put_slice(&self.data[..n]);
                self.data = &self.data[n..];

                Poll::Ready(Ok(()))
            }
        }
    }

    /// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input,
    /// independent of how it is spread across read calls and poll cycles.
    #[test]
    fn split_verif() {
        let mut cx = task::Context::from_waker(task::Waker::noop());
        let mut input = make_framed(&[b"hello", b"world", b"!", b""]);
        let framed_end = input.len();
        input.extend_from_slice(b"trailing data");

        for end_point in 0..input.len() {
            let input = &input[..end_point];

            let unsplit_res = {
                let mut dut = NixFramedReader::new(SplitMock::new(input));
                let mut data_buf = vec![0; input.len()];
                let mut read_buf = ReadBuf::new(&mut data_buf);

                for _ in 0..256 {
                    match Pin::new(&mut dut).poll_read(&mut cx, &mut read_buf) {
                        Poll::Ready(res) => res.unwrap(),
                        Poll::Pending => {
                            assert!(dut.reader.pending);
                            break;
                        }
                    }
                }

                let len = read_buf.filled().len();
                data_buf.truncate(len);

                assert_eq!(
                    end_point >= framed_end,
                    dut.is_eof(),
                    "end_point = {end_point}, state = {:?}",
                    dut.state
                );
                (dut.state, data_buf, dut.reader.data)
            };

            // Try every way of splitting the input into two non-empty parts, including the
            // split where the second part is a single byte (`split_point == end_point - 1`).
            for split_point in 1..end_point {
                let split_res = {
                    let mut dut = NixFramedReader::new(SplitMock::new(&[]));
                    let mut data_buf = vec![0; input.len()];
                    let mut read_buf = ReadBuf::new(&mut data_buf);

                    dut.reader = SplitMock::new(&input[..split_point]);
                    for _ in 0..256 {
                        match Pin::new(&mut dut).poll_read(&mut cx, &mut read_buf) {
                            Poll::Ready(res) => res.unwrap(),
                            Poll::Pending => {
                                assert!(dut.reader.pending);
                                break;
                            }
                        }
                    }

                    // Hand over whatever the first mock still holds plus the rest of the input.
                    dut.reader = SplitMock::new(&input[split_point - dut.reader.data.len()..]);
                    for _ in 0..256 {
                        match Pin::new(&mut dut).poll_read(&mut cx, &mut read_buf) {
                            Poll::Ready(res) => res.unwrap(),
                            Poll::Pending => {
                                assert!(dut.reader.pending);
                                break;
                            }
                        }
                    }

                    let len = read_buf.filled().len();
                    data_buf.truncate(len);

                    (dut.state, data_buf, dut.reader.data)
                };

                assert_eq!(split_res, unsplit_res);
            }
        }
    }

    /// Make framed data, given frame contents. Terminating frame is *not* implicitly included.
    /// Include an empty slice explicitly.
    fn make_framed(frames: &[&[u8]]) -> Vec<u8> {
        let mut buf = vec![];

        for &data in frames {
            buf.extend_from_slice(&(data.len() as u64).to_le_bytes());
            buf.extend_from_slice(data);
        }

        buf
    }
}