nix_compat/wire/bytes/reader/trailer.rs

1use std::{
2    fmt::Debug,
3    future::Future,
4    marker::PhantomData,
5    ops::Deref,
6    pin::Pin,
7    task::{self, ready, Poll},
8};
9
10use tokio::io::{self, AsyncRead, ReadBuf};
11
/// Up to 8 bytes of payload data that were read as part of the trailer block(s).
///
/// Dereferences to the payload bytes only (the first `data_len` bytes of `buf`).
#[derive(Debug)]
pub(crate) struct Trailer {
    /// Count of leading bytes in `buf` that are payload.
    data_len: u8,
    /// The first 8-byte block of the trailer buffer; only the first
    /// `data_len` bytes are meaningful to callers.
    buf: [u8; 8],
}
18
19impl Deref for Trailer {
20    type Target = [u8];
21
22    fn deref(&self) -> &Self::Target {
23        &self.buf[..self.data_len as usize]
24    }
25}
26
/// A "trailer tag": a specific, fixed byte sequence that must follow wire data.
pub(crate) trait Tag {
    /// Byte sequence expected after the payload.
    ///
    /// Its length must be a multiple of 8; the first 8 bytes may be ignored.
    const PATTERN: &'static [u8];

    /// Buffer type large enough to read [Self::PATTERN] into.
    ///
    /// HACK: stands in for a fixed-size array, working around const generics
    /// limitations.
    type Buf: Unpin + Debug + AsMut<[u8]> + AsRef<[u8]>;

    /// Creates a fresh instance of [Self::Buf].
    fn make_buf() -> Self::Buf;
}
42
/// Tag for plain zero padding: its pattern is a single 8-byte block of zeros.
///
/// The enum has no variants; it is only ever used as a type parameter.
#[derive(Debug)]
pub enum Pad {}
45
46impl Tag for Pad {
47    const PATTERN: &'static [u8] = &[0; 8];
48
49    type Buf = [u8; 8];
50
51    fn make_buf() -> Self::Buf {
52        [0; 8]
53    }
54}
55
56#[derive(Debug)]
57pub(crate) struct ReadTrailer<R, T: Tag> {
58    reader: R,
59    data_len: u8,
60    filled: u8,
61    buf: T::Buf,
62    _phantom: PhantomData<fn(T) -> T>,
63}
64
65/// read_trailer returns a [Future] that reads a trailer with a given [Tag] from `reader`
66pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>(
67    reader: R,
68    data_len: u8,
69) -> ReadTrailer<R, T> {
70    assert!(data_len <= 8, "payload in trailer must be <= 8 bytes");
71
72    let buf = T::make_buf();
73    assert_eq!(buf.as_ref().len(), T::PATTERN.len());
74    assert_eq!(T::PATTERN.len() % 8, 0);
75
76    ReadTrailer {
77        reader,
78        data_len,
79        filled: if data_len != 0 { 0 } else { 8 },
80        buf,
81        _phantom: PhantomData,
82    }
83}
84
85impl<R, T: Tag> ReadTrailer<R, T> {
86    pub fn len(&self) -> u8 {
87        self.data_len
88    }
89}
90
91impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
92    type Output = io::Result<Trailer>;
93
94    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
95        let this = &mut *self;
96
97        loop {
98            if this.filled >= this.data_len {
99                let check_range = || this.data_len as usize..this.filled as usize;
100
101                if this.buf.as_ref()[check_range()] != T::PATTERN[check_range()] {
102                    return Err(io::Error::new(
103                        io::ErrorKind::InvalidData,
104                        "invalid trailer",
105                    ))
106                    .into();
107                }
108            }
109
110            if this.filled as usize == T::PATTERN.len() {
111                let mut buf = [0; 8];
112                buf.copy_from_slice(&this.buf.as_ref()[..8]);
113
114                return Ok(Trailer {
115                    data_len: this.data_len,
116                    buf,
117                })
118                .into();
119            }
120
121            let mut buf = ReadBuf::new(this.buf.as_mut());
122            buf.advance(this.filled as usize);
123
124            ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?;
125
126            this.filled = {
127                let filled = buf.filled().len() as u8;
128
129                if filled == this.filled {
130                    return Err(io::ErrorKind::UnexpectedEof.into()).into();
131                }
132
133                filled
134            };
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// A two-block tag: one block of zero padding followed by a fixed 8-byte
    /// marker. It exercises patterns longer than a single block.
    #[derive(Debug)]
    enum Marker {}

    impl Tag for Marker {
        const PATTERN: &'static [u8] = &[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8];

        type Buf = [u8; 16];

        fn make_buf() -> Self::Buf {
            [0; 16]
        }
    }

    #[tokio::test]
    async fn unexpected_eof() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x00])
            .build();

        assert_eq!(
            read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[tokio::test]
    async fn invalid_padding() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x01, 0x00])
            .wait(Duration::ZERO)
            .build();

        assert_eq!(
            read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[tokio::test]
    async fn success() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x00])
            .wait(Duration::ZERO)
            .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
            .build();

        assert_eq!(
            &*read_trailer::<_, Pad>(reader, 2).await.unwrap(),
            &[0xed, 0xef]
        );
    }

    #[tokio::test]
    async fn no_padding() {
        assert!(read_trailer::<_, Pad>(io::empty(), 0)
            .await
            .unwrap()
            .is_empty());
    }

    /// A payload that fills the whole block is followed by no padding bytes.
    #[tokio::test]
    async fn full_block_payload() {
        let reader: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8];

        assert_eq!(
            &*read_trailer::<_, Pad>(reader, 8).await.unwrap(),
            &[1, 2, 3, 4, 5, 6, 7, 8]
        );
    }

    /// With no payload, the first block is skipped and only the remainder of a
    /// longer pattern is read and checked.
    #[tokio::test]
    async fn multi_block_no_payload() {
        let reader: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8];

        assert!(read_trailer::<_, Marker>(reader, 0)
            .await
            .unwrap()
            .is_empty());
    }

    /// Payload, padding and the trailing marker block arrive in one read.
    #[tokio::test]
    async fn multi_block_with_payload() {
        let reader: &[u8] = &[0xaa, 0xbb, 0xcc, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8];

        assert_eq!(
            &*read_trailer::<_, Marker>(reader, 3).await.unwrap(),
            &[0xaa, 0xbb, 0xcc]
        );
    }

    /// A corrupted byte in a later block is rejected.
    #[tokio::test]
    async fn multi_block_invalid_marker() {
        let reader: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 0];

        assert_eq!(
            read_trailer::<_, Marker>(reader, 0).await.unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
    }
}