nix_compat/wire/bytes/
mod.rs

1#[cfg(feature = "async")]
2use std::mem::MaybeUninit;
3use std::{
4    io::{Error, ErrorKind},
5    ops::RangeInclusive,
6};
7#[cfg(feature = "async")]
8use tokio::io::ReadBuf;
9use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
10
11pub(crate) mod reader;
12pub use reader::BytesReader;
13mod writer;
14pub use writer::BytesWriter;
15
/// Eight zero bytes; sub-slices of this are written out as padding.
pub(crate) const EMPTY_BYTES: &[u8; 8] = &[0; 8];

/// The length of the size field in bytes, which is always 8 (one `u64`).
const LEN_SIZE: usize = std::mem::size_of::<u64>();
21
22/// Read a "bytes wire packet" from the AsyncRead.
23/// Rejects reading more than `allowed_size` bytes of payload.
24///
25/// The packet is made up of three parts:
26/// - a length header, u64, LE-encoded
27/// - the payload itself
28/// - null bytes to the next 8 byte boundary
29///
30/// Ensures the payload size fits into the `allowed_size` passed,
31/// and that the padding is actual null bytes.
32///
33/// On success, the returned `Vec<u8>` only contains the payload itself.
34/// On failure (for example if a too large byte packet was sent), the reader
35/// becomes unusable.
36///
37/// This buffers the entire payload into memory,
38/// a streaming version is available at [crate::wire::bytes::BytesReader].
39pub async fn read_bytes<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<Vec<u8>>
40where
41    R: AsyncReadExt + Unpin + ?Sized,
42{
43    // read the length field
44    let len = r.read_u64_le().await?;
45    let len: usize = len
46        .try_into()
47        .ok()
48        .filter(|len| allowed_size.contains(len))
49        .ok_or_else(|| {
50            io::Error::new(
51                io::ErrorKind::InvalidData,
52                "signalled package size not in allowed range",
53            )
54        })?;
55
56    // calculate the total length, including padding.
57    // byte packets are padded to 8 byte blocks each.
58    let padded_len = padding_len(len as u64) as u64 + (len as u64);
59    let mut limited_reader = r.take(padded_len);
60
61    let mut buf = Vec::new();
62
63    let s = limited_reader.read_to_end(&mut buf).await?;
64
65    // make sure we got exactly the number of bytes, and not less.
66    if s as u64 != padded_len {
67        return Err(io::ErrorKind::UnexpectedEof.into());
68    }
69
70    let (_content, padding) = buf.split_at(len);
71
72    // ensure the padding is all zeroes.
73    if padding.iter().any(|&b| b != 0) {
74        return Err(io::Error::new(
75            io::ErrorKind::InvalidData,
76            "padding is not all zeroes",
77        ));
78    }
79
80    // return the data without the padding
81    buf.truncate(len);
82    Ok(buf)
83}
84
/// Read a "bytes wire packet" from the AsyncRead into the caller-provided,
/// possibly uninitialized buffer `buf`, and return the payload as a slice
/// borrowed from that buffer.
///
/// The wire format is the same as for [read_bytes]: a u64 LE length header,
/// the payload, and null bytes up to the next 8 byte boundary. Payload and
/// padding are both read into `buf`; only the payload part is returned.
///
/// # Panics
///
/// Panics if `N` is not a multiple of 8, or if the upper bound of
/// `allowed_size` exceeds `N`.
///
/// # Errors
///
/// - [io::ErrorKind::InvalidData] if the signalled length is outside of
///   `allowed_size`, or if the padding contains non-null bytes.
/// - [io::ErrorKind::UnexpectedEof] if the reader ends before the payload and
///   padding were read completely.
/// - Any error returned by the underlying reader.
#[cfg(feature = "async")]
pub(crate) async fn read_bytes_buf<'a, const N: usize, R>(
    reader: &mut R,
    buf: &'a mut [MaybeUninit<u8>; N],
    allowed_size: RangeInclusive<usize>,
) -> io::Result<&'a [u8]>
where
    R: AsyncReadExt + Unpin + ?Sized,
{
    // With N a multiple of 8 and every allowed length <= N, the padded
    // length computed below always fits into the buffer.
    assert_eq!(N % 8, 0);
    assert!(*allowed_size.end() <= N);

    let len = reader.read_u64_le().await?;
    let len: usize = len
        .try_into()
        .ok()
        .filter(|len| allowed_size.contains(len))
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "signalled package size not in allowed range",
            )
        })?;

    // payload length rounded up to the next multiple of 8
    let buf_len = (len + 7) & !7;
    let buf = {
        let mut read_buf = ReadBuf::uninit(&mut buf[..buf_len]);

        while read_buf.filled().len() < buf_len {
            // The buffer still has room here, so a zero-byte read means the
            // reader hit EOF. Without this check a truncated stream would
            // make this loop spin forever.
            if reader.read_buf(&mut read_buf).await? == 0 {
                return Err(io::ErrorKind::UnexpectedEof.into());
            }
        }

        // ReadBuf::filled does not pass the underlying buffer's lifetime through,
        // so we must make a trip to hell.
        //
        // SAFETY: `read_buf` is filled up to `buf_len`, and we verify that it is
        // still pointing at the same underlying buffer.
        unsafe {
            assert_eq!(read_buf.filled().as_ptr(), buf.as_ptr() as *const u8);
            assume_init_bytes(&buf[..buf_len])
        }
    };

    // everything between the payload and the padded length must be zero
    if buf[len..buf_len].iter().any(|&b| b != 0) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "padding is not all zeroes",
        ));
    }

    Ok(&buf[..len])
}
137
138/// SAFETY: The bytes have to actually be initialized.
139#[cfg(feature = "async")]
unsafe fn assume_init_bytes(slice: &[MaybeUninit<u8>]) -> &[u8] {
    // `MaybeUninit<u8>` has the same layout as `u8`, so the same memory can be
    // viewed as a `[u8]` of identical length. The caller guarantees that every
    // byte in `slice` has been initialized.
    let start = slice.as_ptr().cast::<u8>();
    std::slice::from_raw_parts(start, slice.len())
}
143
144/// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
145/// Internally uses [read_bytes].
146/// Rejects reading more than `allowed_size` bytes of payload.
147pub async fn read_string<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<String>
148where
149    R: AsyncReadExt + Unpin,
150{
151    let bytes = read_bytes(r, allowed_size).await?;
152    String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
153}
154
155/// Writes a "bytes wire packet" to a (hopefully buffered) [AsyncWriteExt].
156///
157/// Accepts anything implementing AsRef<[u8]> as payload.
158///
159/// See [read_bytes] for a description of the format.
160///
161/// Note: if performance matters to you, make sure your
162/// [AsyncWriteExt] handle is buffered. This function is quite
163/// write-intesive.
164pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
165    w: &mut W,
166    b: B,
167) -> io::Result<()> {
168    // write the size packet.
169    w.write_u64_le(b.as_ref().len() as u64).await?;
170
171    // write the payload
172    w.write_all(b.as_ref()).await?;
173
174    // write padding if needed
175    let padding_len = padding_len(b.as_ref().len() as u64) as usize;
176    if padding_len != 0 {
177        w.write_all(&EMPTY_BYTES[..padding_len]).await?;
178    }
179    Ok(())
180}
181
/// Returns how many padding bytes have to follow `len` bytes of payload so
/// that the total is a multiple of 8 bytes (64 bits).
///
/// The result is always in `0..=7`, and is 0 if `len` is already aligned.
pub(crate) fn padding_len(len: u64) -> u8 {
    // Distance from `len` to the next multiple of 8; the outer `% 8` maps an
    // already aligned length to 0 instead of 8.
    let remainder = len % 8;
    ((8 - remainder) % 8) as u8
}
188
#[cfg(test)]
mod tests {
    use tokio_test::{assert_ok, io::Builder};

    use super::*;
    use hex_literal::hex;

    /// Upper bound for the payload size of bytes packets accepted by the
    /// test cases.
    const MAX_LEN: usize = 1024;

    /// A payload of exactly 8 bytes needs no padding.
    #[tokio::test]
    async fn test_read_8_bytes() {
        let expected = 12345678u64.to_le_bytes();
        let mut mock = Builder::new()
            .read(&8u64.to_le_bytes())
            .read(&12345678u64.to_le_bytes())
            .build();

        let payload = read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap();

        assert_eq!(&expected, payload.as_slice());
    }

    /// A payload of 9 bytes is followed by 7 bytes of padding, which must
    /// not show up in the result.
    #[tokio::test]
    async fn test_read_9_bytes() {
        let mut mock = Builder::new()
            .read(&9u64.to_le_bytes())
            .read(&hex!("01020304050607080900000000000000"))
            .build();

        let payload = read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap();

        assert_eq!(hex!("010203040506070809"), payload.as_slice());
    }

    /// An empty bytes packet is just the length field, set to 0: there is
    /// no payload to read and no padding either.
    #[tokio::test]
    async fn test_read_0_bytes() {
        let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();

        let payload = read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap();

        assert_eq!(hex!(""), payload.as_slice());
    }

    #[tokio::test]
    /// If the signalled length is outside of the accepted range, reading has
    /// to fail without consuming anything past the size field.
    async fn test_read_reject_too_large() {
        let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();

        let result = read_bytes(&mut mock, 10..=10).await;

        result.expect_err("expect this to fail");
    }

    /// An 8 byte payload is written as length + payload, without padding.
    #[tokio::test]
    async fn test_write_bytes_no_padding() {
        let input = hex!("6478696f34657661");
        let len_field = (input.len() as u64).to_le_bytes();
        let mut mock = Builder::new().write(&len_field).write(&input).build();

        assert_ok!(write_bytes(&mut mock, &input).await);
    }

    /// A 6 byte payload is followed by 2 null bytes of padding.
    #[tokio::test]
    async fn test_write_bytes_with_padding() {
        let input = hex!("322e332e3137");
        let len_field = (input.len() as u64).to_le_bytes();
        let mut mock = Builder::new()
            .write(&len_field)
            .write(&hex!("322e332e31370000"))
            .build();

        assert_ok!(write_bytes(&mut mock, &input).await);
    }

    /// Strings can be passed as payload too; 13 bytes get 3 bytes of padding.
    #[tokio::test]
    async fn test_write_string() {
        let input = "Hello, World!";
        let len_field = (input.len() as u64).to_le_bytes();
        let mut mock = Builder::new()
            .write(&len_field)
            .write(&hex!("48656c6c6f2c20576f726c6421000000"))
            .build();

        assert_ok!(write_bytes(&mut mock, &input).await);
    }

    /// The padding calculation must not overflow for the largest length.
    #[test]
    fn padding_len_u64_max() {
        let padding = padding_len(u64::MAX);

        assert_eq!(padding, 1);
    }
}