nix_compat/wire/de/
reader.rs

1use std::future::poll_fn;
2use std::io::{self, Cursor};
3use std::ops::RangeInclusive;
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use pin_project_lite::pin_project;
9use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf};
10
11use crate::wire::{ProtocolVersion, EMPTY_BYTES};
12
13use super::{Error, NixRead};
14
/// Builder that collects the configuration for a [`NixReader`] before it is
/// attached to an underlying reader with [`NixReaderBuilder::build`].
pub struct NixReaderBuilder {
    /// Optional caller-supplied buffer; when `None`, `build` creates an empty
    /// buffer with zero capacity.
    buf: Option<BytesMut>,
    /// Minimum spare capacity the reader ensures in its buffer before each
    /// read from the underlying reader.
    reserved_buf_size: usize,
    /// Largest upper bound that `try_read_bytes_limited` accepts for a byte
    /// string; also the default limit used by `try_read_bytes`/`read_bytes`.
    max_buf_size: usize,
    /// Protocol version the built reader reports from `version()`.
    version: ProtocolVersion,
}
21
22impl Default for NixReaderBuilder {
23    fn default() -> Self {
24        Self {
25            buf: Default::default(),
26            reserved_buf_size: 8192,
27            max_buf_size: 8192,
28            version: Default::default(),
29        }
30    }
31}
32
33impl NixReaderBuilder {
34    pub fn set_buffer(mut self, buf: BytesMut) -> Self {
35        self.buf = Some(buf);
36        self
37    }
38
39    pub fn set_reserved_buf_size(mut self, size: usize) -> Self {
40        self.reserved_buf_size = size;
41        self
42    }
43
44    pub fn set_max_buf_size(mut self, size: usize) -> Self {
45        self.max_buf_size = size;
46        self
47    }
48
49    pub fn set_version(mut self, version: ProtocolVersion) -> Self {
50        self.version = version;
51        self
52    }
53
54    pub fn build<R>(self, reader: R) -> NixReader<R> {
55        let buf = self.buf.unwrap_or_else(|| BytesMut::with_capacity(0));
56        NixReader {
57            buf,
58            inner: reader,
59            reserved_buf_size: self.reserved_buf_size,
60            max_buf_size: self.max_buf_size,
61            version: self.version,
62        }
63    }
64}
65
pin_project! {
    /// Buffering reader that wraps an inner reader `R` and implements
    /// [`NixRead`], [`AsyncRead`] and [`AsyncBufRead`] on top of it.
    pub struct NixReader<R> {
        // The wrapped reader; structurally pinned so it can be polled.
        #[pin]
        inner: R,
        // Bytes already read from `inner` but not yet consumed.
        buf: BytesMut,
        // Minimum spare capacity ensured in `buf` before each read from `inner`.
        reserved_buf_size: usize,
        // Largest length limit accepted when reading byte strings.
        max_buf_size: usize,
        // Protocol version reported by `NixRead::version`.
        version: ProtocolVersion,
    }
}
76
77impl NixReader<Cursor<Vec<u8>>> {
78    pub fn builder() -> NixReaderBuilder {
79        NixReaderBuilder::default()
80    }
81}
82
83impl<R> NixReader<R>
84where
85    R: AsyncReadExt,
86{
87    pub fn new(reader: R) -> NixReader<R> {
88        NixReader::builder().build(reader)
89    }
90
91    pub fn buffer(&self) -> &[u8] {
92        &self.buf[..]
93    }
94
95    #[cfg(test)]
96    pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut {
97        &mut self.buf
98    }
99
100    /// Remaining capacity in internal buffer
101    pub fn remaining_mut(&self) -> usize {
102        self.buf.capacity() - self.buf.len()
103    }
104
105    fn poll_force_fill_buf(
106        mut self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108    ) -> Poll<io::Result<usize>> {
109        // Ensure that buffer has space for at least reserved_buf_size bytes
110        if self.remaining_mut() < self.reserved_buf_size {
111            let me = self.as_mut().project();
112            me.buf.reserve(*me.reserved_buf_size);
113        }
114        let me = self.project();
115        let n = {
116            let dst = me.buf.spare_capacity_mut();
117            let mut buf = ReadBuf::uninit(dst);
118            let ptr = buf.filled().as_ptr();
119            ready!(me.inner.poll_read(cx, &mut buf)?);
120
121            // Ensure the pointer does not change from under us
122            assert_eq!(ptr, buf.filled().as_ptr());
123            buf.filled().len()
124        };
125
126        // SAFETY: This is guaranteed to be the number of initialized (and read)
127        // bytes due to the invariants provided by `ReadBuf::filled`.
128        unsafe {
129            me.buf.advance_mut(n);
130        }
131        Poll::Ready(Ok(n))
132    }
133}
134
135impl<R> NixReader<R>
136where
137    R: AsyncReadExt + Unpin,
138{
139    async fn force_fill(&mut self) -> io::Result<usize> {
140        let mut p = Pin::new(self);
141        let read = poll_fn(|cx| p.as_mut().poll_force_fill_buf(cx)).await?;
142        Ok(read)
143    }
144}
145
146impl<R> NixRead for NixReader<R>
147where
148    R: AsyncReadExt + Send + Unpin,
149{
150    type Error = io::Error;
151
152    fn version(&self) -> ProtocolVersion {
153        self.version
154    }
155
156    async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> {
157        let mut buf = [0u8; 8];
158        let read = self.read_buf(&mut &mut buf[..]).await?;
159        if read == 0 {
160            return Ok(None);
161        }
162        if read < 8 {
163            self.read_exact(&mut buf[read..]).await?;
164        }
165        let num = Buf::get_u64_le(&mut &buf[..]);
166        Ok(Some(num))
167    }
168
169    async fn try_read_bytes_limited(
170        &mut self,
171        limit: RangeInclusive<usize>,
172    ) -> Result<Option<Bytes>, Self::Error> {
173        assert!(
174            *limit.end() <= self.max_buf_size,
175            "The limit must be smaller than {}",
176            self.max_buf_size
177        );
178        match self.try_read_number().await? {
179            Some(raw_len) => {
180                // Check that length is in range and convert to usize
181                let len = raw_len
182                    .try_into()
183                    .ok()
184                    .filter(|v| limit.contains(v))
185                    .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?;
186
187                // Calculate 64bit aligned length and convert to usize
188                let aligned: usize = raw_len
189                    .checked_add(7)
190                    .map(|v| v & !7)
191                    .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?
192                    .try_into()
193                    .map_err(Self::Error::invalid_data)?;
194
195                // Ensure that there is enough space in buffer for contents
196                if self.buf.len() + self.remaining_mut() < aligned {
197                    self.buf.reserve(aligned - self.buf.len());
198                }
199                while self.buf.len() < aligned {
200                    if self.force_fill().await? == 0 {
201                        return Err(Self::Error::missing_data(
202                            "unexpected end-of-file reading bytes",
203                        ));
204                    }
205                }
206                let mut contents = self.buf.split_to(aligned);
207
208                let padding = aligned - len;
209                // Ensure padding is all zeros
210                if contents[len..] != EMPTY_BYTES[..padding] {
211                    return Err(Self::Error::invalid_data("non-zero padding"));
212                }
213
214                contents.truncate(len);
215                Ok(Some(contents.freeze()))
216            }
217            None => Ok(None),
218        }
219    }
220
221    fn try_read_bytes(
222        &mut self,
223    ) -> impl std::future::Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ {
224        self.try_read_bytes_limited(0..=self.max_buf_size)
225    }
226
227    fn read_bytes(
228        &mut self,
229    ) -> impl std::future::Future<Output = Result<Bytes, Self::Error>> + Send + '_ {
230        self.read_bytes_limited(0..=self.max_buf_size)
231    }
232}
233
234impl<R: AsyncRead> AsyncRead for NixReader<R> {
235    fn poll_read(
236        mut self: Pin<&mut Self>,
237        cx: &mut Context<'_>,
238        buf: &mut ReadBuf<'_>,
239    ) -> Poll<io::Result<()>> {
240        let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
241        let amt = std::cmp::min(rem.len(), buf.remaining());
242        buf.put_slice(&rem[0..amt]);
243        self.consume(amt);
244        Poll::Ready(Ok(()))
245    }
246}
247
248impl<R: AsyncRead> AsyncBufRead for NixReader<R> {
249    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
250        if self.as_ref().project_ref().buf.is_empty() {
251            ready!(self.as_mut().poll_force_fill_buf(cx))?;
252        }
253        let me = self.project();
254        Poll::Ready(Ok(&me.buf[..]))
255    }
256
257    fn consume(self: Pin<&mut Self>, amt: usize) {
258        let me = self.project();
259        me.buf.advance(amt)
260    }
261}
262
// Unit tests for `NixReader`, driven by scripted `tokio_test` mock readers.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use hex_literal::hex;
    use rstest::rstest;
    use tokio_test::io::Builder;

    use super::*;
    use crate::wire::de::NixRead;

    /// A single little-endian u64 is decoded and nothing is left over,
    /// neither in the internal buffer nor in the stream.
    #[tokio::test]
    async fn test_read_u64() {
        let mock = Builder::new().read(&hex!("0100 0000 0000 0000")).build();
        let mut reader = NixReader::new(mock);

        assert_eq!(1, reader.read_number().await.unwrap());
        assert_eq!(hex!(""), reader.buffer());

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(hex!(""), &buf[..]);
    }

    /// Bytes that arrive in the same read as the number stay in the internal
    /// buffer and are handed out by a later `read_to_end`.
    #[tokio::test]
    async fn test_read_u64_rest() {
        let mock = Builder::new()
            .read(&hex!("0100 0000 0000 0000 0123 4567 89AB CDEF"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(1, reader.read_number().await.unwrap());
        assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer());

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(hex!("0123 4567 89AB CDEF"), &buf[..]);
    }

    /// A number split across two reads is assembled correctly; the surplus of
    /// the second read is buffered and the third read follows it in order.
    #[tokio::test]
    async fn test_read_u64_partial() {
        let mock = Builder::new()
            .read(&hex!("0100 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000 0123 4567 89AB CDEF"))
            .wait(Duration::ZERO)
            .read(&hex!("0100 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(1, reader.read_number().await.unwrap());
        assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer());

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(hex!("0123 4567 89AB CDEF 0100 0000"), &buf[..]);
    }

    /// `read_number` on an empty stream fails with `UnexpectedEof`.
    #[tokio::test]
    async fn test_read_u64_eof() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.read_number().await.unwrap_err().kind()
        );
    }

    /// `try_read_number` on an empty stream yields `None` instead of an error.
    #[tokio::test]
    async fn test_try_read_u64_none() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(None, reader.try_read_number().await.unwrap());
    }

    /// A stream that ends after 6 of the 8 bytes is an `UnexpectedEof` error,
    /// not `None`.
    #[tokio::test]
    async fn test_try_read_u64_eof() {
        let mock = Builder::new().read(&hex!("0100 0000 0000")).build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_number().await.unwrap_err().kind()
        );
    }

    /// Same as above, but the truncated number arrives in two separate reads.
    #[tokio::test]
    async fn test_try_read_u64_eof2() {
        let mock = Builder::new()
            .read(&hex!("0100"))
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_number().await.unwrap_err().kind()
        );
    }

    /// Byte strings of every length from 0 to 9 are decoded from their wire
    /// form: a u64 length prefix followed by the payload padded to 8 bytes.
    #[rstest]
    #[case::empty(b"", &hex!("0000 0000 0000 0000"))]
    #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
    #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
    #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
    #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
    #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
    #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
    #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
    #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
    #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
    #[tokio::test]
    async fn test_read_bytes(#[case] expected: &[u8], #[case] data: &[u8]) {
        let mock = Builder::new().read(data).build();
        let mut reader = NixReader::new(mock);
        let actual = reader.read_bytes().await.unwrap();
        assert_eq!(&actual[..], expected);
    }

    /// `read_bytes` on an empty stream fails with `UnexpectedEof`.
    #[tokio::test]
    async fn test_read_bytes_empty() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.read_bytes().await.unwrap_err().kind()
        );
    }

    /// `try_read_bytes` on an empty stream yields `None`.
    #[tokio::test]
    async fn test_try_read_bytes_none() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(None, reader.try_read_bytes().await.unwrap());
    }

    /// A stream that ends in the middle of the length prefix is an
    /// `UnexpectedEof` error.
    #[tokio::test]
    async fn test_try_read_bytes_missing_data() {
        let mock = Builder::new()
            .read(&hex!("0500"))
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_bytes().await.unwrap_err().kind()
        );
    }

    /// The payload is present but the stream ends before the padding,
    /// which is reported as `UnexpectedEof`.
    #[tokio::test]
    async fn test_try_read_bytes_missing_padding() {
        let mock = Builder::new()
            .read(&hex!("0200 0000 0000 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("1234"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_bytes().await.unwrap_err().kind()
        );
    }

    /// A non-zero byte inside the padding is rejected as `InvalidData`.
    #[tokio::test]
    async fn test_read_bytes_bad_padding() {
        let mock = Builder::new()
            .read(&hex!("0200 0000 0000 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("1234 0100 0000 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::InvalidData,
            reader.read_bytes().await.unwrap_err().kind()
        );
    }

    /// A length prefix (0xFFFF) above the requested limit of 50 is rejected
    /// as `InvalidData`.
    #[tokio::test]
    async fn test_read_bytes_limited_out_of_range() {
        let mock = Builder::new().read(&hex!("FFFF 0000 0000 0000")).build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::InvalidData,
            reader.read_bytes_limited(0..=50).await.unwrap_err().kind()
        );
    }

    /// A length prefix so close to `u64::MAX` that adding 7 for alignment
    /// overflows is rejected as `InvalidData`, even with the limit lifted.
    #[tokio::test]
    async fn test_read_bytes_length_overflow() {
        let mock = Builder::new().read(&hex!("F9FF FFFF FFFF FFFF")).build();
        let mut reader = NixReader::builder()
            .set_max_buf_size(usize::MAX)
            .build(mock);

        assert_eq!(
            io::ErrorKind::InvalidData,
            reader
                .read_bytes_limited(0..=usize::MAX)
                .await
                .unwrap_err()
                .kind()
        );
    }

    // Only compiled for 16/32-bit targets: a value one above `usize::MAX`
    // must be reported as `InvalidData` when read as `usize`.
    // FUTUREWORK: Test this on supported hardware
    #[tokio::test]
    #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
    async fn test_bytes_length_conversion_overflow() {
        let len = (usize::MAX as u64) + 1;
        let mock = Builder::new().read(&len.to_le_bytes()).build();
        let mut reader = NixReader::new(mock);
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            reader.read_value::<usize>().await.unwrap_err().kind()
        );
    }

    // Only compiled for 16/32-bit targets: feeds the value `usize::MAX - 6`
    // and expects `InvalidData` when it is read as `usize`.
    // NOTE(review): the test name refers to the aligned-length conversion of
    // byte strings, but the body reads a `usize` value - confirm the intent.
    // FUTUREWORK: Test this on supported hardware
    #[tokio::test]
    #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
    async fn test_bytes_aligned_length_conversion_overflow() {
        let len = (usize::MAX - 6) as u64;
        let mock = Builder::new().read(&len.to_le_bytes()).build();
        let mut reader = NixReader::new(mock);
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            reader.read_value::<usize>().await.unwrap_err().kind()
        );
    }

    /// The internal buffer starts without capacity and grows whenever fewer
    /// than `reserved_buf_size` spare bytes are available before a read.
    #[tokio::test]
    async fn test_buffer_resize() {
        let mock = Builder::new()
            .read(&hex!("0100"))
            .read(&hex!("0000 0000 0000"))
            .build();
        let mut reader = NixReader::builder().set_reserved_buf_size(8).build(mock);
        // buffer has no capacity initially
        assert_eq!(0, reader.buffer_mut().capacity());

        assert_eq!(2, reader.force_fill().await.unwrap());

        // After first read buffer should have capacity we chose
        assert_eq!(8, reader.buffer_mut().capacity());

        // Because there were only 6 bytes remaining in the buffer,
        // which is enough to read the last 6 bytes, but we require
        // capacity for 8 bytes, it doubled the capacity
        assert_eq!(6, reader.force_fill().await.unwrap());
        assert_eq!(16, reader.buffer_mut().capacity());

        assert_eq!(1, reader.read_number().await.unwrap());
    }
}
526}