nix_compat/nix_daemon/framing/
framed_read.rs1use std::{
2 num::NonZeroU64,
3 pin::Pin,
4 task::{self, Poll, ready},
5};
6
7use pin_project_lite::pin_project;
8use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
9
/// State of the [`NixFramedReader`] frame decoder.
#[derive(Debug, Eq, PartialEq)]
enum State {
    /// Reading the length header of the next frame.
    /// `buf` accumulates the header bytes (decoded as a little-endian `u64`
    /// once complete); `filled` counts how many of its 8 bytes have been read.
    Length { buf: [u8; 8], filled: u8 },
    /// Inside a frame's payload; `remaining` payload bytes are still unread.
    Chunk { remaining: NonZeroU64 },
    /// A zero-length frame was read, ending the framed stream. No further
    /// bytes are read from the inner reader in this state.
    Eof,
}
21
pin_project! {
    /// Reader for a framed byte stream: a sequence of frames, each consisting of
    /// a little-endian `u64` length followed by that many payload bytes, and
    /// terminated by a zero-length frame.
    ///
    /// Reading from it yields the concatenated payloads, and reports end of
    /// stream once the terminator frame has been consumed. If the inner reader
    /// ends before the terminator, reads fail with `UnexpectedEof`.
    pub struct NixFramedReader<R> {
        // The wrapped reader the framed stream is read from.
        #[pin]
        reader: R,
        // Current position in the framing state machine.
        state: State,
    }
}
35
36impl<R> NixFramedReader<R> {
37 pub fn new(reader: R) -> Self {
38 Self {
39 reader,
40 state: State::Length {
41 buf: [0; 8],
42 filled: 0,
43 },
44 }
45 }
46}
47
48impl<R: AsyncRead + Unpin> NixFramedReader<R> {
49 pub async fn is_eof_unpin(&mut self) -> io::Result<bool> {
51 Pin::new(self).is_eof().await
52 }
53}
54
55impl<R: AsyncRead> NixFramedReader<R> {
56 pub async fn is_eof(self: Pin<&mut Self>) -> io::Result<bool> {
58 let mut this = self.project();
59 loop {
63 match this.state {
64 State::Length { buf, filled: 8 } => {
65 *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
66 None => State::Eof,
67 Some(remaining) => State::Chunk { remaining },
68 };
69 }
70 State::Length { buf, filled } => {
71 let bytes_read = this.reader.read(&mut buf[*filled as usize..]).await? as u8;
72
73 if bytes_read == 0 {
74 return Err(io::ErrorKind::UnexpectedEof.into());
75 }
76
77 *filled += bytes_read;
78 }
79 State::Chunk { .. } => {
80 return Ok(false);
81 }
82 State::Eof => {
83 return Ok(true);
84 }
85 }
86 }
87 }
88}
89
90impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
91 fn poll_read(
92 mut self: Pin<&mut Self>,
93 cx: &mut task::Context<'_>,
94 buf: &mut ReadBuf<'_>,
95 ) -> Poll<io::Result<()>> {
96 let mut this = self.as_mut().project();
97
98 if buf.remaining() == 0 {
100 return Ok(()).into();
101 }
102
103 loop {
104 let reader = this.reader.as_mut();
105 match this.state {
106 State::Eof => {
107 return Ok(()).into();
108 }
109 State::Length { buf, filled: 8 } => {
110 *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
111 None => State::Eof,
112 Some(remaining) => State::Chunk { remaining },
113 };
114 }
115 State::Length { buf, filled } => {
116 let bytes_read = {
117 let mut b = ReadBuf::new(&mut buf[*filled as usize..]);
118 ready!(reader.poll_read(cx, &mut b))?;
119 b.filled().len() as u8
120 };
121
122 if bytes_read == 0 {
123 return Err(io::ErrorKind::UnexpectedEof.into()).into();
124 }
125
126 *filled += bytes_read;
127 }
128 State::Chunk { remaining } => {
129 let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| {
130 reader.poll_read(cx, buf).map_ok(|()| buf.filled().len())
131 }))?;
132
133 *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) {
134 None => State::Length {
135 buf: [0; 8],
136 filled: 0,
137 },
138 Some(remaining) => State::Chunk { remaining },
139 };
140
141 return if bytes_read == 0 {
142 Err(io::ErrorKind::UnexpectedEof.into())
143 } else {
144 Ok(())
145 }
146 .into();
147 }
148 }
149 }
150 }
151}
152
153fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
157 let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
158 let ptr = nbuf.initialized().as_ptr();
159 let ret = f(&mut nbuf);
160
161 unsafe {
167 assert_eq!(nbuf.initialized().as_ptr(), ptr);
169
170 let n = nbuf.filled().len();
171 buf.assume_init(n);
172 buf.advance(n);
173 }
174
175 ret
176}
177
#[cfg(test)]
mod nix_framed_tests {
    use std::{
        cmp::min,
        pin::Pin,
        task::{self, Poll},
        time::Duration,
    };

    use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    /// The inner stream ends right after a complete frame, without the
    /// zero-length terminator frame.
    #[tokio::test]
    async fn read_unexpected_eof_after_frame() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let err = reader.is_eof_unpin().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// The inner stream ends in the middle of a frame's payload.
    #[tokio::test]
    async fn read_unexpected_eof_in_frame() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            .read(&6u64.to_le_bytes())
            .read(" worl".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let is_eof = reader.is_eof_unpin().await.map_err(|e| e.kind());
        assert!(matches!(
            is_eof,
            Ok(false) | Err(io::ErrorKind::UnexpectedEof)
        ));
    }

    /// The inner stream ends in the middle of a length header.
    #[tokio::test]
    async fn read_unexpected_eof_in_length() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            .read(&[0; 7])
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        let err = reader.is_eof_unpin().await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// Two frames followed by the terminator frame read back as one string.
    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
        assert!(reader.is_eof_unpin().await.unwrap());
    }

    /// Reader that hands out `data` as fast as the caller's buffer allows, then
    /// returns `Pending` forever, recording that it did so in `pending`.
    struct SplitMock<'a> {
        data: &'a [u8],
        pending: bool,
    }

    impl<'a> SplitMock<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                pending: false,
            }
        }
    }

    impl AsyncRead for SplitMock<'_> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            if self.data.is_empty() {
                self.pending = true;
                Poll::Pending
            } else {
                let n = min(buf.remaining(), self.data.len());
                buf.put_slice(&self.data[..n]);
                self.data = &self.data[n..];

                Poll::Ready(Ok(()))
            }
        }
    }

    /// Polls `dut` into `read_buf` until the inner [`SplitMock`] runs out of
    /// data (returns `Pending`), giving up after 256 polls. Any read error
    /// fails the test.
    fn poll_until_pending(
        dut: &mut NixFramedReader<SplitMock<'_>>,
        cx: &mut task::Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) {
        for _ in 0..256 {
            match Pin::new(&mut *dut).poll_read(cx, read_buf) {
                Poll::Ready(res) => res.unwrap(),
                Poll::Pending => {
                    assert!(dut.reader.pending);
                    break;
                }
            }
        }
    }

    /// For every prefix of a framed stream (plus trailing data), feeding it in
    /// one piece or in two pieces, split at every interior point, must yield
    /// the same final state, output bytes and unconsumed input.
    #[test]
    fn split_verif() {
        let mut cx = task::Context::from_waker(task::Waker::noop());
        let mut input = make_framed(&[b"hello", b"world", b"!", b""]);
        let framed_end = input.len();
        input.extend_from_slice(b"trailing data");

        for end_point in 0..input.len() {
            let input = &input[..end_point];

            let unsplit_res = {
                let mut dut = NixFramedReader::new(SplitMock::new(input));
                let mut data_buf = vec![0; input.len()];
                let mut read_buf = ReadBuf::new(&mut data_buf);

                poll_until_pending(&mut dut, &mut cx, &mut read_buf);

                let len = read_buf.filled().len();
                data_buf.truncate(len);

                assert_eq!(
                    end_point >= framed_end,
                    matches!(dut.state, super::State::Eof),
                    "end_point = {end_point}, state = {:?}",
                    dut.state
                );
                (dut.state, data_buf, dut.reader.data)
            };

            // Every split point leaving a non-empty first and second piece,
            // including `end_point - 1`.
            for split_point in 1..end_point {
                let split_res = {
                    let mut dut = NixFramedReader::new(SplitMock::new(&input[..split_point]));
                    let mut data_buf = vec![0; input.len()];
                    let mut read_buf = ReadBuf::new(&mut data_buf);

                    poll_until_pending(&mut dut, &mut cx, &mut read_buf);

                    // Re-offer whatever the first piece left unread, plus the rest.
                    dut.reader = SplitMock::new(&input[split_point - dut.reader.data.len()..]);
                    poll_until_pending(&mut dut, &mut cx, &mut read_buf);

                    let len = read_buf.filled().len();
                    data_buf.truncate(len);

                    (dut.state, data_buf, dut.reader.data)
                };

                assert_eq!(split_res, unsplit_res);
            }
        }
    }

    /// Encodes each of `frames` as a little-endian `u64` length followed by its
    /// bytes, concatenated in order.
    fn make_framed(frames: &[&[u8]]) -> Vec<u8> {
        let mut buf = vec![];

        for &data in frames {
            buf.extend_from_slice(&(data.len() as u64).to_le_bytes());
            buf.extend_from_slice(data);
        }

        buf
    }
}