nix_compat/nix_daemon/framing/
framed_read.rs1use std::{
2 num::NonZeroU64,
3 pin::Pin,
4 task::{self, ready, Poll},
5};
6
7use pin_project_lite::pin_project;
8use tokio::io::{self, AsyncRead, ReadBuf};

/// Progress of a [`NixFramedReader`] through the framed byte stream.
///
/// The reader cycles `Length` -> `Chunk` -> `Length` -> ... until it decodes a
/// zero-length header, after which it stays in `Eof`.
#[derive(PartialEq, Eq, Debug)]
enum State {
    /// Collecting the 8-byte little-endian length header that precedes each
    /// frame. `filled` counts how many header bytes in `buf` have arrived so
    /// far; the header may be delivered across several reads.
    Length { buf: [u8; 8], filled: u8 },
    /// Forwarding frame payload; `remaining` payload bytes are still expected.
    Chunk { remaining: NonZeroU64 },
    /// A zero-length header has been read: the framed stream is finished.
    Eof,
}
21
22pin_project! {
23 pub struct NixFramedReader<R> {
30 #[pin]
31 reader: R,
32 state: State,
33 }
34}
35
36impl<R> NixFramedReader<R> {
37 pub fn new(reader: R) -> Self {
38 Self {
39 reader,
40 state: State::Length {
41 buf: [0; 8],
42 filled: 0,
43 },
44 }
45 }
46
47 #[must_use]
49 pub fn is_eof(&self) -> bool {
50 matches!(self.state, State::Eof)
51 }
52}
53
54impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
55 fn poll_read(
56 mut self: Pin<&mut Self>,
57 cx: &mut task::Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<io::Result<()>> {
60 let mut this = self.as_mut().project();
61
62 if buf.remaining() == 0 {
64 return Ok(()).into();
65 }
66
67 loop {
68 let reader = this.reader.as_mut();
69 match this.state {
70 State::Eof => {
71 return Ok(()).into();
72 }
73 State::Length { buf, filled: 8 } => {
74 *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
75 None => State::Eof,
76 Some(remaining) => State::Chunk { remaining },
77 };
78 }
79 State::Length { buf, filled } => {
80 let bytes_read = {
81 let mut b = ReadBuf::new(&mut buf[*filled as usize..]);
82 ready!(reader.poll_read(cx, &mut b))?;
83 b.filled().len() as u8
84 };
85
86 if bytes_read == 0 {
87 return Err(io::ErrorKind::UnexpectedEof.into()).into();
88 }
89
90 *filled += bytes_read;
91 }
92 State::Chunk { remaining } => {
93 let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| {
94 reader.poll_read(cx, buf).map_ok(|()| buf.filled().len())
95 }))?;
96
97 *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) {
98 None => State::Length {
99 buf: [0; 8],
100 filled: 0,
101 },
102 Some(remaining) => State::Chunk { remaining },
103 };
104
105 return if bytes_read == 0 {
106 Err(io::ErrorKind::UnexpectedEof.into())
107 } else {
108 Ok(())
109 }
110 .into();
111 }
112 }
113 }
114 }
115}
116
117fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
121 let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
122 let ptr = nbuf.initialized().as_ptr();
123 let ret = f(&mut nbuf);
124
125 unsafe {
131 assert_eq!(nbuf.initialized().as_ptr(), ptr);
133
134 let n = nbuf.filled().len();
135 buf.assume_init(n);
136 buf.advance(n);
137 }
138
139 ret
140}

#[cfg(test)]
mod nix_framed_tests {
    use std::{
        cmp::min,
        pin::Pin,
        task::{self, Poll},
        time::Duration,
    };

    use tokio::io::{self, AsyncRead, AsyncReadExt, ReadBuf};
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    /// The inner stream ends right after a complete frame, without the
    /// terminating zero-length frame.
    #[tokio::test]
    async fn read_unexpected_eof_after_frame() {
        let mut mock = Builder::new()
            // Frame "hello".
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            // Insert a pause in the script between the two frames.
            .wait(Duration::ZERO)
            // Frame " world"; no terminator follows.
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert!(!reader.is_eof());
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// The inner stream ends in the middle of a frame's payload.
    #[tokio::test]
    async fn read_unexpected_eof_in_frame() {
        let mut mock = Builder::new()
            // Frame "hello".
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            // Insert a pause in the script between the two frames.
            .wait(Duration::ZERO)
            // Header announces 6 bytes but only 5 arrive.
            .read(&6u64.to_le_bytes())
            .read(" worl".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert!(!reader.is_eof());
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// The inner stream ends in the middle of a length header.
    #[tokio::test]
    async fn read_unexpected_eof_in_length() {
        let mut mock = Builder::new()
            // Frame "hello".
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            // Insert a pause in the script before the truncated header.
            .wait(Duration::ZERO)
            // Only 7 of the 8 header bytes arrive.
            .read(&[0; 7])
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let err = reader.read_to_string(&mut String::new()).await.unwrap_err();
        assert!(!reader.is_eof());
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    /// Two frames followed by the zero-length terminator decode cleanly.
    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            // Frame "hello".
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            // Insert a pause in the script between the two frames.
            .wait(Duration::ZERO)
            // Frame " world".
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            // Terminating zero-length frame.
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
        assert!(reader.is_eof());
    }

    /// Reader that hands out `data` as fast as the caller's buffer allows and
    /// returns `Pending` (recording that in `pending`) once `data` is empty.
    struct SplitMock<'a> {
        data: &'a [u8],
        pending: bool,
    }

    impl<'a> SplitMock<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                pending: false,
            }
        }
    }

    impl AsyncRead for SplitMock<'_> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            if self.data.is_empty() {
                self.pending = true;
                Poll::Pending
            } else {
                let n = min(buf.remaining(), self.data.len());
                buf.put_slice(&self.data[..n]);
                self.data = &self.data[n..];

                Poll::Ready(Ok(()))
            }
        }
    }

    /// Polls `dut` into `read_buf` until it returns `Pending`, which must come
    /// from the mock running dry, or until a fixed iteration budget is used up
    /// (a reader at EOF keeps returning `Ready` forever).
    fn poll_until_pending(
        dut: &mut NixFramedReader<SplitMock<'_>>,
        cx: &mut task::Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) {
        for _ in 0..256 {
            match Pin::new(&mut *dut).poll_read(cx, read_buf) {
                Poll::Ready(res) => res.unwrap(),
                Poll::Pending => {
                    assert!(dut.reader.pending);
                    break;
                }
            }
        }
    }

    /// Feeding the reader its input in two pieces, split at any byte position,
    /// must end in the same state, with the same output and the same
    /// unconsumed input, as feeding it everything at once.
    #[test]
    fn split_verif() {
        let mut cx = task::Context::from_waker(task::Waker::noop());
        let mut input = make_framed(&[b"hello", b"world", b"!", b""]);
        let framed_end = input.len();
        input.extend_from_slice(b"trailing data");

        for end_point in 0..input.len() {
            let input = &input[..end_point];

            let unsplit_res = {
                let mut dut = NixFramedReader::new(SplitMock::new(input));
                let mut data_buf = vec![0; input.len()];
                let mut read_buf = ReadBuf::new(&mut data_buf);

                poll_until_pending(&mut dut, &mut cx, &mut read_buf);

                let len = read_buf.filled().len();
                data_buf.truncate(len);

                assert_eq!(
                    end_point >= framed_end,
                    dut.is_eof(),
                    "end_point = {end_point}, state = {:?}",
                    dut.state
                );
                (dut.state, data_buf, dut.reader.data)
            };

            // Every split that leaves both pieces non-empty, including the
            // one where the second piece is a single byte.
            for split_point in 1..end_point {
                let split_res = {
                    let mut dut = NixFramedReader::new(SplitMock::new(&input[..split_point]));
                    let mut data_buf = vec![0; input.len()];
                    let mut read_buf = ReadBuf::new(&mut data_buf);

                    poll_until_pending(&mut dut, &mut cx, &mut read_buf);

                    // Resume from the first byte the reader has not consumed yet.
                    let resume_at = split_point - dut.reader.data.len();
                    dut.reader = SplitMock::new(&input[resume_at..]);
                    poll_until_pending(&mut dut, &mut cx, &mut read_buf);

                    let len = read_buf.filled().len();
                    data_buf.truncate(len);

                    (dut.state, data_buf, dut.reader.data)
                };

                assert_eq!(
                    split_res, unsplit_res,
                    "end_point = {end_point}, split_point = {split_point}"
                );
            }
        }
    }

    /// Encodes `frames` as consecutive (8-byte little-endian length, payload)
    /// pairs. Pass an empty final frame to produce the terminator.
    fn make_framed(frames: &[&[u8]]) -> Vec<u8> {
        let mut buf = vec![];

        for &data in frames {
            buf.extend_from_slice(&(data.len() as u64).to_le_bytes());
            buf.extend_from_slice(data);
        }

        buf
    }
}