1use std::future::poll_fn;
2use std::io::{self, Cursor};
3use std::ops::RangeInclusive;
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
7use bytes::{Buf, BufMut, Bytes, BytesMut};
8use pin_project_lite::pin_project;
9use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf};
10
11use crate::wire::{ProtocolVersion, EMPTY_BYTES};
12
13use super::{Error, NixRead};
14
15pub struct NixReaderBuilder {
16 buf: Option<BytesMut>,
17 reserved_buf_size: usize,
18 max_buf_size: usize,
19 version: ProtocolVersion,
20}
21
22impl Default for NixReaderBuilder {
23 fn default() -> Self {
24 Self {
25 buf: Default::default(),
26 reserved_buf_size: 8192,
27 max_buf_size: 8192,
28 version: Default::default(),
29 }
30 }
31}
32
33impl NixReaderBuilder {
34 pub fn set_buffer(mut self, buf: BytesMut) -> Self {
35 self.buf = Some(buf);
36 self
37 }
38
39 pub fn set_reserved_buf_size(mut self, size: usize) -> Self {
40 self.reserved_buf_size = size;
41 self
42 }
43
44 pub fn set_max_buf_size(mut self, size: usize) -> Self {
45 self.max_buf_size = size;
46 self
47 }
48
49 pub fn set_version(mut self, version: ProtocolVersion) -> Self {
50 self.version = version;
51 self
52 }
53
54 pub fn build<R>(self, reader: R) -> NixReader<R> {
55 let buf = self.buf.unwrap_or_else(|| BytesMut::with_capacity(0));
56 NixReader {
57 buf,
58 inner: reader,
59 reserved_buf_size: self.reserved_buf_size,
60 max_buf_size: self.max_buf_size,
61 version: self.version,
62 }
63 }
64}
65
66pin_project! {
67 pub struct NixReader<R> {
68 #[pin]
69 inner: R,
70 buf: BytesMut,
71 reserved_buf_size: usize,
72 max_buf_size: usize,
73 version: ProtocolVersion,
74 }
75}
76
77impl NixReader<Cursor<Vec<u8>>> {
78 pub fn builder() -> NixReaderBuilder {
79 NixReaderBuilder::default()
80 }
81}
82
83impl<R> NixReader<R>
84where
85 R: AsyncReadExt,
86{
87 pub fn new(reader: R) -> NixReader<R> {
88 NixReader::builder().build(reader)
89 }
90
91 pub fn buffer(&self) -> &[u8] {
92 &self.buf[..]
93 }
94
95 #[cfg(test)]
96 pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut {
97 &mut self.buf
98 }
99
100 pub fn remaining_mut(&self) -> usize {
102 self.buf.capacity() - self.buf.len()
103 }
104
105 fn poll_force_fill_buf(
106 mut self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 ) -> Poll<io::Result<usize>> {
109 if self.remaining_mut() < self.reserved_buf_size {
111 let me = self.as_mut().project();
112 me.buf.reserve(*me.reserved_buf_size);
113 }
114 let me = self.project();
115 let n = {
116 let dst = me.buf.spare_capacity_mut();
117 let mut buf = ReadBuf::uninit(dst);
118 let ptr = buf.filled().as_ptr();
119 ready!(me.inner.poll_read(cx, &mut buf)?);
120
121 assert_eq!(ptr, buf.filled().as_ptr());
123 buf.filled().len()
124 };
125
126 unsafe {
129 me.buf.advance_mut(n);
130 }
131 Poll::Ready(Ok(n))
132 }
133}
134
135impl<R> NixReader<R>
136where
137 R: AsyncReadExt + Unpin,
138{
139 async fn force_fill(&mut self) -> io::Result<usize> {
140 let mut p = Pin::new(self);
141 let read = poll_fn(|cx| p.as_mut().poll_force_fill_buf(cx)).await?;
142 Ok(read)
143 }
144}
145
146impl<R> NixRead for NixReader<R>
147where
148 R: AsyncReadExt + Send + Unpin,
149{
150 type Error = io::Error;
151
152 fn version(&self) -> ProtocolVersion {
153 self.version
154 }
155
156 async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> {
157 let mut buf = [0u8; 8];
158 let read = self.read_buf(&mut &mut buf[..]).await?;
159 if read == 0 {
160 return Ok(None);
161 }
162 if read < 8 {
163 self.read_exact(&mut buf[read..]).await?;
164 }
165 let num = Buf::get_u64_le(&mut &buf[..]);
166 Ok(Some(num))
167 }
168
169 async fn try_read_bytes_limited(
170 &mut self,
171 limit: RangeInclusive<usize>,
172 ) -> Result<Option<Bytes>, Self::Error> {
173 assert!(
174 *limit.end() <= self.max_buf_size,
175 "The limit must be smaller than {}",
176 self.max_buf_size
177 );
178 match self.try_read_number().await? {
179 Some(raw_len) => {
180 let len = raw_len
182 .try_into()
183 .ok()
184 .filter(|v| limit.contains(v))
185 .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?;
186
187 let aligned: usize = raw_len
189 .checked_add(7)
190 .map(|v| v & !7)
191 .ok_or_else(|| Self::Error::invalid_data("bytes length out of range"))?
192 .try_into()
193 .map_err(Self::Error::invalid_data)?;
194
195 if self.buf.len() + self.remaining_mut() < aligned {
197 self.buf.reserve(aligned - self.buf.len());
198 }
199 while self.buf.len() < aligned {
200 if self.force_fill().await? == 0 {
201 return Err(Self::Error::missing_data(
202 "unexpected end-of-file reading bytes",
203 ));
204 }
205 }
206 let mut contents = self.buf.split_to(aligned);
207
208 let padding = aligned - len;
209 if contents[len..] != EMPTY_BYTES[..padding] {
211 return Err(Self::Error::invalid_data("non-zero padding"));
212 }
213
214 contents.truncate(len);
215 Ok(Some(contents.freeze()))
216 }
217 None => Ok(None),
218 }
219 }
220
221 fn try_read_bytes(
222 &mut self,
223 ) -> impl std::future::Future<Output = Result<Option<Bytes>, Self::Error>> + Send + '_ {
224 self.try_read_bytes_limited(0..=self.max_buf_size)
225 }
226
227 fn read_bytes(
228 &mut self,
229 ) -> impl std::future::Future<Output = Result<Bytes, Self::Error>> + Send + '_ {
230 self.read_bytes_limited(0..=self.max_buf_size)
231 }
232}
233
234impl<R: AsyncRead> AsyncRead for NixReader<R> {
235 fn poll_read(
236 mut self: Pin<&mut Self>,
237 cx: &mut Context<'_>,
238 buf: &mut ReadBuf<'_>,
239 ) -> Poll<io::Result<()>> {
240 let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
241 let amt = std::cmp::min(rem.len(), buf.remaining());
242 buf.put_slice(&rem[0..amt]);
243 self.consume(amt);
244 Poll::Ready(Ok(()))
245 }
246}
247
248impl<R: AsyncRead> AsyncBufRead for NixReader<R> {
249 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
250 if self.as_ref().project_ref().buf.is_empty() {
251 ready!(self.as_mut().poll_force_fill_buf(cx))?;
252 }
253 let me = self.project();
254 Poll::Ready(Ok(&me.buf[..]))
255 }
256
257 fn consume(self: Pin<&mut Self>, amt: usize) {
258 let me = self.project();
259 me.buf.advance(amt)
260 }
261}
262
#[cfg(test)]
mod test {
    use std::time::Duration;

    use hex_literal::hex;
    use rstest::rstest;
    use tokio_test::io::Builder;

    use super::*;
    use crate::wire::de::NixRead;

    // A number delivered in a single read leaves nothing buffered.
    #[tokio::test]
    async fn test_read_u64() {
        let mock = Builder::new().read(&hex!("0100 0000 0000 0000")).build();
        let mut reader = NixReader::new(mock);

        assert_eq!(1, reader.read_number().await.unwrap());
        assert_eq!(hex!(""), reader.buffer());

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(hex!(""), &buf[..]);
    }

    // Bytes arriving in the same read as the number stay buffered and are
    // returned by later reads.
    #[tokio::test]
    async fn test_read_u64_rest() {
        let mock = Builder::new()
            .read(&hex!("0100 0000 0000 0000 0123 4567 89AB CDEF"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(1, reader.read_number().await.unwrap());
        assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer());

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(hex!("0123 4567 89AB CDEF"), &buf[..]);
    }

    // A number split across two reads is reassembled. The surplus of the
    // second read stays buffered, and later reads come after it.
    #[tokio::test]
    async fn test_read_u64_partial() {
        let mock = Builder::new()
            .read(&hex!("0100 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000 0123 4567 89AB CDEF"))
            .wait(Duration::ZERO)
            .read(&hex!("0100 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(1, reader.read_number().await.unwrap());
        assert_eq!(hex!("0123 4567 89AB CDEF"), reader.buffer());

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(hex!("0123 4567 89AB CDEF 0100 0000"), &buf[..]);
    }

    // `read_number` on an empty stream fails with `UnexpectedEof`.
    #[tokio::test]
    async fn test_read_u64_eof() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.read_number().await.unwrap_err().kind()
        );
    }

    // `try_read_number` on an empty stream reports a clean EOF as `None`.
    #[tokio::test]
    async fn test_try_read_u64_none() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(None, reader.try_read_number().await.unwrap());
    }

    // A stream ending in the middle of a number is an error, not `None`.
    #[tokio::test]
    async fn test_try_read_u64_eof() {
        let mock = Builder::new().read(&hex!("0100 0000 0000")).build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_number().await.unwrap_err().kind()
        );
    }

    // Same as above, but the truncated number is spread over two reads.
    #[tokio::test]
    async fn test_try_read_u64_eof2() {
        let mock = Builder::new()
            .read(&hex!("0100"))
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_number().await.unwrap_err().kind()
        );
    }

    // Byte strings: a u64 length followed by contents zero-padded to a
    // multiple of eight bytes.
    #[rstest]
    #[case::empty(b"", &hex!("0000 0000 0000 0000"))]
    #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
    #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
    #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
    #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
    #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
    #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
    #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
    #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
    #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
    #[tokio::test]
    async fn test_read_bytes(#[case] expected: &[u8], #[case] data: &[u8]) {
        let mock = Builder::new().read(data).build();
        let mut reader = NixReader::new(mock);
        let actual = reader.read_bytes().await.unwrap();
        assert_eq!(&actual[..], expected);
    }

    // `read_bytes` on an empty stream fails with `UnexpectedEof`.
    #[tokio::test]
    async fn test_read_bytes_empty() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.read_bytes().await.unwrap_err().kind()
        );
    }

    // `try_read_bytes` on an empty stream reports `None`.
    #[tokio::test]
    async fn test_try_read_bytes_none() {
        let mock = Builder::new().build();
        let mut reader = NixReader::new(mock);

        assert_eq!(None, reader.try_read_bytes().await.unwrap());
    }

    // A truncated length prefix is an error.
    #[tokio::test]
    async fn test_try_read_bytes_missing_data() {
        let mock = Builder::new()
            .read(&hex!("0500"))
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_bytes().await.unwrap_err().kind()
        );
    }

    // Contents present but the padding missing is an error.
    #[tokio::test]
    async fn test_try_read_bytes_missing_padding() {
        let mock = Builder::new()
            .read(&hex!("0200 0000 0000 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("1234"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::UnexpectedEof,
            reader.try_read_bytes().await.unwrap_err().kind()
        );
    }

    // A non-zero byte in the padding is rejected as invalid data.
    #[tokio::test]
    async fn test_read_bytes_bad_padding() {
        let mock = Builder::new()
            .read(&hex!("0200 0000 0000 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("1234 0100 0000 0000"))
            .build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::InvalidData,
            reader.read_bytes().await.unwrap_err().kind()
        );
    }

    // A length outside the requested limit is rejected as invalid data.
    #[tokio::test]
    async fn test_read_bytes_limited_out_of_range() {
        let mock = Builder::new().read(&hex!("FFFF 0000 0000 0000")).build();
        let mut reader = NixReader::new(mock);

        assert_eq!(
            io::ErrorKind::InvalidData,
            reader.read_bytes_limited(0..=50).await.unwrap_err().kind()
        );
    }

    // 0xFFFF_FFFF_FFFF_FFF9 + 7 overflows u64 while rounding up to the
    // aligned length.
    #[tokio::test]
    async fn test_read_bytes_length_overflow() {
        let mock = Builder::new().read(&hex!("F9FF FFFF FFFF FFFF")).build();
        let mut reader = NixReader::builder()
            .set_max_buf_size(usize::MAX)
            .build(mock);

        assert_eq!(
            io::ErrorKind::InvalidData,
            reader
                .read_bytes_limited(0..=usize::MAX)
                .await
                .unwrap_err()
                .kind()
        );
    }

    // Only meaningful where usize is narrower than u64: a value one past
    // usize::MAX cannot be converted to usize.
    #[tokio::test]
    #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
    async fn test_bytes_length_conversion_overflow() {
        let len = (usize::MAX as u64) + 1;
        let mock = Builder::new().read(&len.to_le_bytes()).build();
        let mut reader = NixReader::new(mock);
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            reader.read_value::<usize>().await.unwrap_err().kind()
        );
    }

    // Only meaningful where usize is narrower than u64. `usize::MAX - 6`
    // itself fits in usize, but its 8-byte aligned length (usize::MAX + 1)
    // does not. This has to go through the byte-string reader, because a
    // plain usize read of this value never performs the alignment step.
    #[tokio::test]
    #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
    async fn test_bytes_aligned_length_conversion_overflow() {
        let len = (usize::MAX - 6) as u64;
        let mock = Builder::new().read(&len.to_le_bytes()).build();
        let mut reader = NixReader::builder()
            .set_max_buf_size(usize::MAX)
            .build(mock);
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            reader
                .read_bytes_limited(0..=usize::MAX)
                .await
                .unwrap_err()
                .kind()
        );
    }

    // The buffer grows on demand to keep `reserved_buf_size` bytes of spare
    // capacity before each read.
    #[tokio::test]
    async fn test_buffer_resize() {
        let mock = Builder::new()
            .read(&hex!("0100"))
            .read(&hex!("0000 0000 0000"))
            .build();
        let mut reader = NixReader::builder().set_reserved_buf_size(8).build(mock);
        // Nothing is allocated until the first read.
        assert_eq!(0, reader.buffer_mut().capacity());

        assert_eq!(2, reader.force_fill().await.unwrap());

        // The first read reserved exactly the configured size.
        assert_eq!(8, reader.buffer_mut().capacity());

        // Only 6 bytes of spare capacity were left, fewer than the 8 required,
        // so the buffer was grown before the second read.
        assert_eq!(6, reader.force_fill().await.unwrap());
        assert_eq!(16, reader.buffer_mut().capacity());

        assert_eq!(1, reader.read_number().await.unwrap());
    }
}