nix_compat/wire/bytes/reader/
trailer.rs1use std::{
2 fmt::Debug,
3 future::Future,
4 marker::PhantomData,
5 ops::Deref,
6 pin::Pin,
7 task::{self, ready, Poll},
8};
9
10use tokio::io::{self, AsyncRead, ReadBuf};
11
/// Payload bytes captured from a trailer block.
///
/// `buf` always holds a full 8-byte block; only its first `data_len` bytes
/// are payload, and those are what the `Deref` implementation exposes.
#[derive(Debug)]
pub(crate) struct Trailer {
    /// Number of leading bytes of `buf` that are payload.
    data_len: u8,
    /// The first 8 bytes of the trailer as they were read.
    buf: [u8; 8],
}
18
19impl Deref for Trailer {
20 type Target = [u8];
21
22 fn deref(&self) -> &Self::Target {
23 &self.buf[..self.data_len as usize]
24 }
25}
26
/// Describes the fixed byte pattern a trailer is validated against.
pub(crate) trait Tag {
    /// Bytes the trailer must match past its payload.
    ///
    /// `read_trailer` asserts that the length is a multiple of 8 and equal to
    /// the length of the buffer produced by [`Tag::make_buf`].
    const PATTERN: &'static [u8];

    /// Buffer type the trailer bytes are read into.
    type Buf: AsRef<[u8]> + AsMut<[u8]> + Debug + Unpin;

    /// Creates a fresh buffer for reading one trailer.
    fn make_buf() -> Self::Buf;
}
42
/// Tag type selecting zero padding as the trailer pattern.
///
/// The enum has no variants, so no value of it can exist; it is used purely
/// as a type parameter.
#[derive(Debug)]
pub enum Pad {}
45
46impl Tag for Pad {
47 const PATTERN: &'static [u8] = &[0; 8];
48
49 type Buf = [u8; 8];
50
51 fn make_buf() -> Self::Buf {
52 [0; 8]
53 }
54}
55
56#[derive(Debug)]
57pub(crate) struct ReadTrailer<R, T: Tag> {
58 reader: R,
59 data_len: u8,
60 filled: u8,
61 buf: T::Buf,
62 _phantom: PhantomData<fn(T) -> T>,
63}
64
65pub(crate) fn read_trailer<R: AsyncRead + Unpin, T: Tag>(
67 reader: R,
68 data_len: u8,
69) -> ReadTrailer<R, T> {
70 assert!(data_len <= 8, "payload in trailer must be <= 8 bytes");
71
72 let buf = T::make_buf();
73 assert_eq!(buf.as_ref().len(), T::PATTERN.len());
74 assert_eq!(T::PATTERN.len() % 8, 0);
75
76 ReadTrailer {
77 reader,
78 data_len,
79 filled: if data_len != 0 { 0 } else { 8 },
80 buf,
81 _phantom: PhantomData,
82 }
83}
84
85impl<R, T: Tag> ReadTrailer<R, T> {
86 pub fn len(&self) -> u8 {
87 self.data_len
88 }
89}
90
91impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
92 type Output = io::Result<Trailer>;
93
94 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
95 let this = &mut *self;
96
97 loop {
98 if this.filled >= this.data_len {
99 let check_range = || this.data_len as usize..this.filled as usize;
100
101 if this.buf.as_ref()[check_range()] != T::PATTERN[check_range()] {
102 return Err(io::Error::new(
103 io::ErrorKind::InvalidData,
104 "invalid trailer",
105 ))
106 .into();
107 }
108 }
109
110 if this.filled as usize == T::PATTERN.len() {
111 let mut buf = [0; 8];
112 buf.copy_from_slice(&this.buf.as_ref()[..8]);
113
114 return Ok(Trailer {
115 data_len: this.data_len,
116 buf,
117 })
118 .into();
119 }
120
121 let mut buf = ReadBuf::new(this.buf.as_mut());
122 buf.advance(this.filled as usize);
123
124 ready!(Pin::new(&mut this.reader).poll_read(cx, &mut buf))?;
125
126 this.filled = {
127 let filled = buf.filled().len() as u8;
128
129 if filled == this.filled {
130 return Err(io::ErrorKind::UnexpectedEof.into()).into();
131 }
132
133 filled
134 };
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// The mock stream ends after three of the eight trailer bytes.
    #[tokio::test]
    async fn unexpected_eof() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x00])
            .build();

        let outcome = read_trailer::<_, Pad>(reader, 2).await;

        assert_eq!(
            outcome.unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    /// A non-zero byte directly after the two payload bytes is rejected.
    #[tokio::test]
    async fn invalid_padding() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x01, 0x00])
            .wait(Duration::ZERO)
            .build();

        let outcome = read_trailer::<_, Pad>(reader, 2).await;

        assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    /// Two payload bytes followed by six zero bytes, delivered in three reads.
    #[tokio::test]
    async fn success() {
        let reader = tokio_test::io::Builder::new()
            .read(&[0xed])
            .wait(Duration::ZERO)
            .read(&[0xef, 0x00])
            .wait(Duration::ZERO)
            .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
            .build();

        let trailer = read_trailer::<_, Pad>(reader, 2).await.unwrap();

        assert_eq!(&*trailer, &[0xed, 0xef]);
    }

    /// With no payload, an empty reader is enough to produce an empty trailer.
    #[tokio::test]
    async fn no_padding() {
        let trailer = read_trailer::<_, Pad>(io::empty(), 0).await.unwrap();

        assert!(trailer.is_empty());
    }
}