nix_compat/nix_daemon/framing/stderr_read.rs
1use std::{
2 io::Result,
3 pin::Pin,
4 task::{ready, Poll},
5};
6
7use bytes::{BufMut, BytesMut};
8use pin_project_lite::pin_project;
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11use crate::worker_protocol::STDERR_READ;
12
/// Tracks a little-endian encoded `u64` that is being written out piece by piece.
#[derive(Debug)]
struct U64WriteState {
    // The eight encoded bytes of the value.
    bytes: [u8; 8],
    // How many leading bytes of `bytes` have been written so far.
    written: usize,
}

impl U64WriteState {
    /// Returns the tail of `bytes` that has not been written yet.
    fn remaining(&self) -> &[u8] {
        let (_already_written, pending) = self.bytes.split_at(self.written);
        pending
    }
}
24
25/// State machine for [`StderrReadFramedReader`].
26///
27/// As the reader progresses it linearly cycles through the states.
28#[derive(Debug)]
29enum StderrReaderState {
30 /// Represents the state indicating that we are about to request a new frame.
31 ///
32 /// When poll_read is called, it writes STDERR_READ into the writer and
33 /// progresses to the [`StderrReaderState::RequestingFrameLen`] state
34 ///
35 /// The reader always starts in this state and is reached after every frame has
36 /// been fully read.
37 RequestingNextFrame { write_state: U64WriteState },
38 /// At this point the reader writes the desired payload length we want to receive
39 /// based on read_buf.remaining().
40 RequestingFrameLen {
41 // We need to write 8 bytes of the length u64 value,
42 // this variable stores how many we've written so far.
43 write_state: U64WriteState,
44 },
45 /// At this point the reader just flushes the writer and gets ready to receive
46 /// the actual payload size that is about to be sent to us by transitioning to
47 /// the [`StderrReaderState::ReadingSize`] state.
48 FrameLenRequested,
49 /// The size is a u64 which is 8 bytes long, while it's likely that we will receive
50 /// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
51 /// So in this state we read up to 8 bytes and transition to
52 /// [`StderrReaderState::ReadingPayload`] when done.
53 ReadingSize { buf: [u8; 8], filled: usize },
54 /// This is where we read the actual payload that is sent to us.
55 /// All of the previous states were just internal bookkeeping where we did not return
56 /// any data to the conumer, and only returned Poll::Pending.
57 ///
58 /// Having read the full payload, progresses to the [`StderrReaderState::RequestingNextFrame`]
59 /// state to read the next frame when/if requested.
60 ReadingPayload {
61 /// Represents the remaining number of bytes we expect to read based on the value
62 /// read in the previous state.
63 remaining: u64,
64 /// Represents the remaining of padding we expect to read before switching back
65 /// to the RequestingNextFrame state.
66 pad: usize,
67 /// In an ideal case this reader does not allocate, but in the scenario where
68 /// we've read the whol payload frame but still have padding remaining, it's not
69 /// safe to return the payload to the consumer as there is risk that the reader
70 /// won't be called again, leaving dangling padding. In this case we store the
71 /// payload in this buffer until we've read the padding, and then return the data
72 /// from here.
73 tmp_buf: BytesMut,
74 },
75}
76
77impl StderrReaderState {
78 fn request_next_frame() -> Self {
79 Self::RequestingNextFrame {
80 write_state: U64WriteState {
81 bytes: STDERR_READ.to_le_bytes(),
82 written: 0,
83 },
84 }
85 }
86
87 fn read_written(len: u64) -> Self {
88 Self::RequestingFrameLen {
89 write_state: U64WriteState {
90 bytes: len.to_le_bytes(),
91 written: 0,
92 },
93 }
94 }
95}
96
97pin_project! {
98 /// Implements the reader protocol for STDERR_READ in nix protocol version 1.21..1.23.
99 ///
100 /// See logging.md#stderr_read and [`StderrReaderState`] for details.
101 ///
102 /// FUTUREWORK: As per the nix protocol, it should be possible to send logging messages
103 /// concurrently with reads, however this reader currently monopolizes the writer until eof is
104 /// reached or the writer is dropped. It's important we don't allow certain interleavings of
105 /// log writes, i.e. it's not ok to issue a log message right after we've requested
106 /// STDERR_READ, but before requesting the length.
107 pub struct StderrReadFramedReader<R, W> {
108 #[pin]
109 reader: R,
110 #[pin]
111 writer: W,
112 state: StderrReaderState
113 }
114}
115
116impl<R, W> StderrReadFramedReader<R, W> {
117 pub fn new(reader: R, writer: W) -> Self {
118 Self {
119 reader,
120 writer,
121 state: StderrReaderState::request_next_frame(),
122 }
123 }
124}
125
126impl<R: AsyncRead, W: AsyncWrite> AsyncRead for StderrReadFramedReader<R, W> {
127 fn poll_read(
128 mut self: Pin<&mut Self>,
129 cx: &mut std::task::Context<'_>,
130 read_buf: &mut ReadBuf<'_>,
131 ) -> Poll<Result<()>> {
132 loop {
133 let mut this = self.as_mut().project();
134 match this.state {
135 StderrReaderState::RequestingNextFrame { write_state } => {
136 write_state.written +=
137 ready!(this.writer.poll_write(cx, write_state.remaining()))?;
138 if write_state.written == 8 {
139 *this.state = StderrReaderState::read_written(read_buf.remaining() as u64);
140 }
141 }
142 StderrReaderState::RequestingFrameLen { write_state } => {
143 write_state.written +=
144 ready!(this.writer.poll_write(cx, write_state.remaining()))?;
145 if write_state.written == 8 {
146 *this.state = StderrReaderState::FrameLenRequested;
147 }
148 }
149 StderrReaderState::FrameLenRequested => {
150 ready!(this.writer.poll_flush(cx))?;
151 *this.state = StderrReaderState::ReadingSize {
152 buf: [0u8; 8],
153 filled: 0,
154 };
155 }
156 StderrReaderState::ReadingSize { buf, filled } => {
157 if *filled < buf.len() {
158 let mut size_buf = ReadBuf::new(buf);
159 size_buf.advance(*filled);
160
161 ready!(this.reader.poll_read(cx, &mut size_buf))?;
162 let bytes_read = size_buf.filled().len() - *filled;
163 if bytes_read == 0 {
164 // oef
165 return Poll::Ready(Ok(()));
166 }
167 *filled += bytes_read;
168 continue;
169 }
170 let size = u64::from_le_bytes(*buf);
171 if size == 0 {
172 // eof
173 *this.state = StderrReaderState::request_next_frame();
174 return Poll::Ready(Ok(()));
175 }
176 let pad = (8 - (size % 8) as usize) % 8;
177 *this.state = StderrReaderState::ReadingPayload {
178 remaining: size,
179 pad,
180 tmp_buf: BytesMut::new(),
181 };
182 }
183 StderrReaderState::ReadingPayload {
184 remaining,
185 pad,
186 tmp_buf,
187 } => {
188 // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
189 let safe_remaining = if *remaining <= (usize::MAX - *pad) as u64 {
190 *remaining as usize + *pad
191 } else {
192 usize::MAX
193 };
194 if safe_remaining - *pad > 0 {
195 // The buffer is no larger than the amount of data that we expect.
196 // Otherwise we will trim the buffer below and come back here.
197 if read_buf.remaining() <= safe_remaining {
198 let filled_before = read_buf.filled().len();
199
200 ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
201 let bytes_read = read_buf.filled().len() - filled_before;
202 let payload_size = std::cmp::min(bytes_read, safe_remaining - *pad);
203
204 // we don't want to include padding bytes in the result, so we remove them from read_buf.
205 read_buf.set_filled(filled_before + payload_size);
206
207 *remaining -= payload_size as u64;
208 if *remaining > 0 {
209 // We have more data to read so we just return ok, knowing that the consumer
210 // will read again.
211 return Poll::Ready(Ok(()));
212 }
213
214 // If we don't have any remaining data to read, consume any padding we may have just read.
215 *pad -= bytes_read - payload_size;
216 if *pad != 0 {
217 // We haven't read all the padding yet, so we stash it away to return to the caller
218 // once we've read the remaining padding.
219 tmp_buf.clear();
220 tmp_buf.put_slice(&read_buf.filled()[filled_before..payload_size]);
221 read_buf.set_filled(filled_before);
222 continue;
223 }
224 *this.state = StderrReaderState::request_next_frame();
225 return Poll::Ready(Ok(()));
226 }
227
228 // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
229 // internal bookkeeping simpler.
230 let mut smaller_buf = read_buf.take(safe_remaining);
231 ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
232
233 let bytes_read = smaller_buf.filled().len();
234
235 // SAFETY: we just read this number of bytes into read_buf's backing slice above.
236 unsafe { read_buf.assume_init(bytes_read) };
237 read_buf.advance(bytes_read);
238 return Poll::Ready(Ok(()));
239 } else if *pad > 0 {
240 // if we've read the whole payload but there is still padding remaining,
241 // we read it into a stack allocated array
242 let mut pad_arr = [0u8; 7];
243 let mut pad_buf = ReadBuf::new(&mut pad_arr);
244 pad_buf.advance(7 - *pad);
245 ready!(this.reader.poll_read(cx, &mut pad_buf))?;
246 *pad = pad_buf.remaining();
247 if *pad != 0 {
248 continue;
249 }
250 }
251 // now it's finally time to hand out the read data to the caller and reset to the RequestingNextFrame state.
252 read_buf.put_slice(tmp_buf);
253 tmp_buf.clear();
254 *this.state = StderrReaderState::request_next_frame();
255 return Poll::Ready(Ok(()));
256 }
257 }
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use hex_literal::hex;
    use tokio::io::{split, AsyncReadExt, BufReader};
    use tokio_test::io::Builder;

    use crate::{nix_daemon::framing::StderrReadFramedReader, worker_protocol::STDERR_READ};

    /// Ten bytes are requested, the peer answers with a two byte frame plus padding.
    #[tokio::test]
    async fn test_single_two_byte_read_with_desired_size_ten() {
        let mut script = Builder::new();
        script
            // First the reader writes STDERR_READ followed by the number of bytes it wants.
            .write(&STDERR_READ.to_le_bytes())
            .write(&10u64.to_le_bytes())
            .wait(Duration::ZERO)
            // The client answers with 2 bytes instead of 10,
            .read(&2u64.to_le_bytes())
            // directly followed by the payload and its padding.
            .read(b"hi")
            .read(&hex!("0000 0000 0000"));
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);

        let mut payload = [0u8; 2];
        let mut buffered = BufReader::with_capacity(10, &mut framed);
        let count = buffered.read_exact(&mut payload).await.unwrap();

        assert_eq!(2, count);
        assert_eq!(*b"hi", payload);
    }

    /// The payload arrives right away, the padding only after a delay.
    #[tokio::test]
    async fn test_single_read_with_padding_delayed() {
        let expected = hex!("0202 0104 ffff ffaa 00");

        let mut script = Builder::new();
        script
            // First the reader writes STDERR_READ followed by the number of bytes it wants.
            .write(&STDERR_READ.to_le_bytes())
            .write(&10u64.to_le_bytes())
            // The client answers with 9 bytes instead of 10,
            .read(&9u64.to_le_bytes())
            // immediately followed by the payload,
            .read(&expected)
            // while the padding shows up later.
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000 0000 00"));
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);

        let mut payload = [0u8; 9];
        let mut buffered = BufReader::with_capacity(10, &mut framed);
        let count = buffered.read_exact(&mut payload).await.unwrap();

        assert_eq!(9, count);
        assert_eq!(expected, payload);
    }

    /// Two frames followed by a zero-length frame, with delays sprinkled in between.
    #[tokio::test]
    async fn test_multiple_consecutive_reads_with_arbitrary_delays() {
        let request = STDERR_READ.to_le_bytes();
        let desired_len = 8192u64.to_le_bytes();

        let mut script = Builder::new();
        script
            // First the reader writes STDERR_READ followed by the number of bytes it wants.
            .write(&request)
            .write(&desired_len)
            .wait(Duration::ZERO)
            // The client sends the 6 bytes 'hello ' plus padding.
            .read(&6u64.to_le_bytes())
            .wait(Duration::ZERO)
            .read(b"hello ")
            .read(&hex!("0000"))
            // The reader asks for the next frame.
            .write(&request)
            .write(&desired_len)
            // The client sends the 11 bytes 'racerunners', with the 's' and the padding delayed.
            .wait(Duration::ZERO)
            .read(&11u64.to_le_bytes())
            .read(b"racerunner")
            .wait(Duration::ZERO)
            .read(b"s")
            .read(&hex!("0000 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("00"))
            // The reader asks once more and receives a zero-length frame.
            .write(&request)
            .write(&desired_len)
            .wait(Duration::ZERO)
            .read(&0u64.to_le_bytes());
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);

        let mut text = String::new();
        let mut buffered = BufReader::with_capacity(8192, &mut framed);
        let count = buffered.read_to_string(&mut text).await.unwrap();

        assert_eq!(17, count);
        assert_eq!("hello racerunners", &text);
    }

    /// Both the STDERR_READ value and the desired length need two writes each.
    #[tokio::test]
    async fn test_single_read_where_writing_stderr_and_desired_size_take_more_than_one_write() {
        let stderr_bytes = STDERR_READ.to_le_bytes();
        let length_bytes = 10u64.to_le_bytes();
        let (stderr_head, stderr_tail) = stderr_bytes.split_at(4);
        let (length_head, length_tail) = length_bytes.split_at(4);

        let mut script = Builder::new();
        script
            .write(stderr_head)
            .wait(Duration::ZERO)
            .write(stderr_tail)
            .wait(Duration::ZERO)
            .write(length_head)
            .wait(Duration::ZERO)
            .write(length_tail)
            .wait(Duration::ZERO)
            // The client answers with 2 bytes instead of 10,
            .read(&2u64.to_le_bytes())
            // directly followed by the payload and its padding.
            .read(b"hi")
            .read(&hex!("0000 0000 0000"));
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);

        let mut payload = [0u8; 2];
        let mut buffered = BufReader::with_capacity(10, &mut framed);
        let count = buffered.read_exact(&mut payload).await.unwrap();

        assert_eq!(2, count);
        assert_eq!(*b"hi", payload);
    }

    /// Empty test body; only checks that the `#[tokio::test]` harness builds and runs.
    #[tokio::test]
    async fn hello() {}
}