nix_compat/wire/bytes/
mod.rs1#[cfg(feature = "async")]
2use std::mem::MaybeUninit;
3use std::{
4 io::{Error, ErrorKind},
5 ops::RangeInclusive,
6};
7#[cfg(feature = "async")]
8use tokio::io::ReadBuf;
9use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
10
11pub(crate) mod reader;
12pub use reader::BytesReader;
13mod writer;
14pub use writer::BytesWriter;
15
/// Eight zero bytes. Writers slice a prefix of this to emit padding, which is
/// never longer than 7 bytes.
pub(crate) const EMPTY_BYTES: &[u8; 8] = &[0; 8];

/// Size in bytes of a length field: a little-endian `u64`.
const LEN_SIZE: usize = std::mem::size_of::<u64>();
21
22pub async fn read_bytes<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<Vec<u8>>
40where
41 R: AsyncReadExt + Unpin + ?Sized,
42{
43 let len = r.read_u64_le().await?;
45 let len: usize = len
46 .try_into()
47 .ok()
48 .filter(|len| allowed_size.contains(len))
49 .ok_or_else(|| {
50 io::Error::new(
51 io::ErrorKind::InvalidData,
52 "signalled package size not in allowed range",
53 )
54 })?;
55
56 let padded_len = padding_len(len as u64) as u64 + (len as u64);
59 let mut limited_reader = r.take(padded_len);
60
61 let mut buf = Vec::new();
62
63 let s = limited_reader.read_to_end(&mut buf).await?;
64
65 if s as u64 != padded_len {
67 return Err(io::ErrorKind::UnexpectedEof.into());
68 }
69
70 let (_content, padding) = buf.split_at(len);
71
72 if padding.iter().any(|&b| b != 0) {
74 return Err(io::Error::new(
75 io::ErrorKind::InvalidData,
76 "padding is not all zeroes",
77 ));
78 }
79
80 buf.truncate(len);
82 Ok(buf)
83}
84
/// Reads a length-prefixed, zero-padded byte field into the caller-provided
/// buffer `buf` and returns the content (without padding) as a slice of it.
///
/// The wire format matches `read_bytes`, but nothing is allocated: the whole
/// padded payload has to fit in `buf`.
///
/// # Panics
///
/// Panics if `N` is not a multiple of 8, or if `allowed_size.end()` exceeds `N`.
/// These two conditions guarantee the padded payload always fits in `buf`.
///
/// # Errors
///
/// - [`io::ErrorKind::InvalidData`] if the signalled length does not fit in a
///   `usize` or is not in `allowed_size`, or if any padding byte is non-zero.
/// - [`io::ErrorKind::UnexpectedEof`] if `reader` ends before the full padded
///   payload has been read.
/// - Any error returned by the underlying reader.
#[cfg(feature = "async")]
pub(crate) async fn read_bytes_buf<'a, const N: usize, R>(
    reader: &mut R,
    buf: &'a mut [MaybeUninit<u8>; N],
    allowed_size: RangeInclusive<usize>,
) -> io::Result<&'a [u8]>
where
    R: AsyncReadExt + Unpin + ?Sized,
{
    assert_eq!(N % 8, 0);
    assert!(*allowed_size.end() <= N);

    let len = reader.read_u64_le().await?;
    let len: usize = len
        .try_into()
        .ok()
        .filter(|len| allowed_size.contains(len))
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "signalled package size not in allowed range",
            )
        })?;

    // Round up to the next multiple of 8. This cannot overflow and stays within
    // `buf`, because `len <= N` and `N` is a multiple of 8 (asserted above).
    let buf_len = (len + 7) & !7;
    let buf = {
        let mut read_buf = ReadBuf::uninit(&mut buf[..buf_len]);

        while read_buf.filled().len() < buf_len {
            // `read_buf` still has spare capacity here, so a zero-length read
            // means the reader hit EOF. Without this check a truncated stream
            // would make this loop spin forever.
            if reader.read_buf(&mut read_buf).await? == 0 {
                return Err(io::ErrorKind::UnexpectedEof.into());
            }
        }

        // SAFETY: `read_buf` wraps `buf[..buf_len]`, and the loop above only
        // exits once `buf_len` bytes have been filled. So that prefix of `buf`
        // is initialized. The assertion double-checks that the filled region
        // really starts at `buf`.
        unsafe {
            assert_eq!(read_buf.filled().as_ptr(), buf.as_ptr() as *const u8);
            assume_init_bytes(&buf[..buf_len])
        }
    };

    if buf[len..buf_len].iter().any(|&b| b != 0) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "padding is not all zeroes",
        ));
    }

    Ok(&buf[..len])
}
137
138#[cfg(feature = "async")]
/// Reinterprets a slice of `MaybeUninit<u8>` as a slice of `u8`, without copying.
///
/// # Safety
///
/// Every element of `slice` must be initialized.
unsafe fn assume_init_bytes(slice: &[MaybeUninit<u8>]) -> &[u8] {
    let len = slice.len();
    let data = slice.as_ptr().cast::<u8>();
    // SAFETY: `MaybeUninit<u8>` has the same size and alignment as `u8`.
    // `data`/`len` come from a valid slice whose borrow is carried into the
    // result, and the caller guarantees every element is initialized.
    unsafe { std::slice::from_raw_parts(data, len) }
}
143
144pub async fn read_string<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<String>
148where
149 R: AsyncReadExt + Unpin,
150{
151 let bytes = read_bytes(r, allowed_size).await?;
152 String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
153}
154
155pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
165 w: &mut W,
166 b: B,
167) -> io::Result<()> {
168 w.write_u64_le(b.as_ref().len() as u64).await?;
170
171 w.write_all(b.as_ref()).await?;
173
174 let padding_len = padding_len(b.as_ref().len() as u64) as usize;
176 if padding_len != 0 {
177 w.write_all(&EMPTY_BYTES[..padding_len]).await?;
178 }
179 Ok(())
180}
181
/// Returns how many zero bytes must follow a `len`-byte payload so that it ends
/// on an 8-byte boundary. The result is always in `0..=7`.
pub(crate) fn padding_len(len: u64) -> u8 {
    // `len % 8` is in 0..=7, so `8 - len % 8` is in 1..=8 and cannot underflow.
    // The outer `% 8` maps an already-aligned length (which would give 8) to 0.
    let pad = (8 - len % 8) % 8;
    pad as u8
}
188
#[cfg(test)]
mod tests {
    use tokio_test::{assert_ok, io::Builder};

    use super::*;
    use hex_literal::hex;

    /// Upper bound of the allowed size range used by the read tests.
    const MAX_LEN: usize = 1024;

    #[tokio::test]
    async fn test_read_8_bytes() {
        let mut mock = Builder::new()
            .read(&8u64.to_le_bytes())
            .read(&12345678u64.to_le_bytes())
            .build();

        assert_eq!(
            &12345678u64.to_le_bytes(),
            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
        );
    }

    #[tokio::test]
    async fn test_read_9_bytes() {
        let mut mock = Builder::new()
            .read(&9u64.to_le_bytes())
            .read(&hex!("01020304050607080900000000000000"))
            .build();

        assert_eq!(
            hex!("010203040506070809"),
            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
        );
    }

    #[tokio::test]
    async fn test_read_0_bytes() {
        // A zero-length field has no content and no padding.
        let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();

        assert_eq!(
            hex!(""),
            read_bytes(&mut mock, 0..=MAX_LEN).await.unwrap().as_slice()
        );
    }

    #[tokio::test]
    async fn test_read_reject_too_large() {
        // Signal 100 bytes while only exactly 10 are allowed.
        let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();

        read_bytes(&mut mock, 10..=10)
            .await
            .expect_err("expect this to fail");
    }

    #[tokio::test]
    async fn test_read_reject_nonzero_padding() {
        // 9 content bytes, but the first padding byte is 0x01.
        let mut mock = Builder::new()
            .read(&9u64.to_le_bytes())
            .read(&hex!("01020304050607080901000000000000"))
            .build();

        let err = read_bytes(&mut mock, 0..=MAX_LEN)
            .await
            .expect_err("non-zero padding must be rejected");
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_read_reject_truncated() {
        // 9 bytes (16 with padding) are signalled, but only 5 arrive before EOF.
        let mut mock = Builder::new()
            .read(&9u64.to_le_bytes())
            .read(&hex!("0102030405"))
            .build();

        let err = read_bytes(&mut mock, 0..=MAX_LEN)
            .await
            .expect_err("truncated payload must be rejected");
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_read_string() {
        let mut mock = Builder::new()
            .read(&13u64.to_le_bytes())
            .read(&hex!("48656c6c6f2c20576f726c6421000000"))
            .build();

        assert_eq!(
            "Hello, World!",
            read_string(&mut mock, 0..=MAX_LEN).await.unwrap()
        );
    }

    #[tokio::test]
    async fn test_read_string_reject_invalid_utf8() {
        // A lone 0xff byte is never valid UTF-8.
        let mut mock = Builder::new()
            .read(&1u64.to_le_bytes())
            .read(&hex!("ff00000000000000"))
            .build();

        let err = read_string(&mut mock, 0..=MAX_LEN)
            .await
            .expect_err("invalid UTF-8 must be rejected");
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_write_bytes_empty() {
        // Only the zero length is written: no content and no padding.
        let mut mock = Builder::new().write(&0u64.to_le_bytes()).build();
        assert_ok!(write_bytes(&mut mock, b"").await)
    }

    #[tokio::test]
    async fn test_write_bytes_no_padding() {
        let input = hex!("6478696f34657661");
        let len = input.len() as u64;
        let mut mock = Builder::new()
            .write(&len.to_le_bytes())
            .write(&input)
            .build();
        assert_ok!(write_bytes(&mut mock, &input).await)
    }

    #[tokio::test]
    async fn test_write_bytes_with_padding() {
        let input = hex!("322e332e3137");
        let len = input.len() as u64;
        let mut mock = Builder::new()
            .write(&len.to_le_bytes())
            .write(&hex!("322e332e31370000"))
            .build();
        assert_ok!(write_bytes(&mut mock, &input).await)
    }

    #[tokio::test]
    async fn test_write_string() {
        let input = "Hello, World!";
        let len = input.len() as u64;
        let mut mock = Builder::new()
            .write(&len.to_le_bytes())
            .write(&hex!("48656c6c6f2c20576f726c6421000000"))
            .build();
        assert_ok!(write_bytes(&mut mock, &input).await)
    }

    #[test]
    fn padding_len_u64_max() {
        assert_eq!(padding_len(u64::MAX), 1);
    }
}