nix_compat/nix_daemon/framing/
stderr_read.rs

1use std::{
2    io::Result,
3    pin::Pin,
4    task::{ready, Poll},
5};
6
7use bytes::{BufMut, BytesMut};
8use pin_project_lite::pin_project;
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11use crate::worker_protocol::STDERR_READ;
12
/// Progress tracker for pushing one 8-byte value into the writer, which may
/// take several partial writes.
#[derive(Debug)]
struct U64WriteState {
    /// The eight bytes that have to be written.
    bytes: [u8; 8],
    /// Number of leading bytes of `bytes` that have already been written.
    written: usize,
}

impl U64WriteState {
    /// Returns the tail of `bytes` that still has to be written.
    ///
    /// The slice is empty once `written` has reached 8.
    fn remaining(&self) -> &[u8] {
        let (_already_sent, pending) = self.bytes.split_at(self.written);
        pending
    }
}
24
25/// State machine for [`StderrReadFramedReader`].
26///
27/// As the reader progresses it linearly cycles through the states.
28#[derive(Debug)]
29enum StderrReaderState {
30    /// Represents the state indicating that we are about to request a new frame.
31    ///
32    /// When poll_read is called, it writes STDERR_READ into the writer and
33    /// progresses to the [`StderrReaderState::RequestingFrameLen`] state
34    ///
35    /// The reader always starts in this state and is reached after every frame has
36    /// been fully read.
37    RequestingNextFrame { write_state: U64WriteState },
38    /// At this point the reader writes the desired payload length we want to receive
39    /// based on read_buf.remaining().
40    RequestingFrameLen {
41        // We need to write 8 bytes of the length u64 value,
42        // this variable stores how many we've written so far.
43        write_state: U64WriteState,
44    },
45    /// At this point the reader just flushes the writer and gets ready to receive
46    /// the actual payload size that is about to be sent to us by transitioning to
47    /// the [`StderrReaderState::ReadingSize`] state.
48    FrameLenRequested,
49    /// The size is a u64 which is 8 bytes long, while it's likely that we will receive
50    /// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
51    /// So in this state we read up to 8 bytes and transition to
52    /// [`StderrReaderState::ReadingPayload`] when done.
53    ReadingSize { buf: [u8; 8], filled: usize },
54    /// This is where we read the actual payload that is sent to us.
55    /// All of the previous states were just internal bookkeeping where we did not return
56    /// any data to the conumer, and only returned Poll::Pending.
57    ///
58    /// Having read the full payload, progresses to the [`StderrReaderState::RequestingNextFrame`]
59    /// state to read the next frame when/if requested.
60    ReadingPayload {
61        /// Represents the remaining number of bytes we expect to read based on the value
62        /// read in the previous state.
63        remaining: u64,
64        /// Represents the remaining of padding we expect to read before switching back
65        /// to the RequestingNextFrame state.
66        pad: usize,
67        /// In an ideal case this reader does not allocate, but in the scenario where
68        /// we've read the whol payload frame but still have padding remaining, it's not
69        /// safe to return the payload to the consumer as there is risk that the reader
70        /// won't be called again, leaving dangling padding. In this case we store the
71        /// payload in this buffer until we've read the padding, and then return the data
72        /// from here.
73        tmp_buf: BytesMut,
74    },
75}
76
77impl StderrReaderState {
78    fn request_next_frame() -> Self {
79        Self::RequestingNextFrame {
80            write_state: U64WriteState {
81                bytes: STDERR_READ.to_le_bytes(),
82                written: 0,
83            },
84        }
85    }
86
87    fn read_written(len: u64) -> Self {
88        Self::RequestingFrameLen {
89            write_state: U64WriteState {
90                bytes: len.to_le_bytes(),
91                written: 0,
92            },
93        }
94    }
95}
96
97pin_project! {
98    /// Implements the reader protocol for STDERR_READ in nix protocol version 1.21..1.23.
99    ///
100    /// See logging.md#stderr_read and [`StderrReaderState`] for details.
101    ///
102    /// FUTUREWORK: As per the nix protocol, it should be possible to send logging messages
103    /// concurrently with reads, however this reader currently monopolizes the writer until eof is
104    /// reached or the writer is dropped. It's important we don't allow certain interleavings of
105    /// log writes, i.e. it's not ok to issue a log message right after we've requested
106    /// STDERR_READ, but before requesting the length.
107    pub struct StderrReadFramedReader<R, W> {
108        #[pin]
109        reader: R,
110        #[pin]
111        writer: W,
112        state: StderrReaderState
113    }
114}
115
116impl<R, W> StderrReadFramedReader<R, W> {
117    pub fn new(reader: R, writer: W) -> Self {
118        Self {
119            reader,
120            writer,
121            state: StderrReaderState::request_next_frame(),
122        }
123    }
124}
125
126impl<R: AsyncRead, W: AsyncWrite> AsyncRead for StderrReadFramedReader<R, W> {
127    fn poll_read(
128        mut self: Pin<&mut Self>,
129        cx: &mut std::task::Context<'_>,
130        read_buf: &mut ReadBuf<'_>,
131    ) -> Poll<Result<()>> {
132        loop {
133            let mut this = self.as_mut().project();
134            match this.state {
135                StderrReaderState::RequestingNextFrame { write_state } => {
136                    write_state.written +=
137                        ready!(this.writer.poll_write(cx, write_state.remaining()))?;
138                    if write_state.written == 8 {
139                        *this.state = StderrReaderState::read_written(read_buf.remaining() as u64);
140                    }
141                }
142                StderrReaderState::RequestingFrameLen { write_state } => {
143                    write_state.written +=
144                        ready!(this.writer.poll_write(cx, write_state.remaining()))?;
145                    if write_state.written == 8 {
146                        *this.state = StderrReaderState::FrameLenRequested;
147                    }
148                }
149                StderrReaderState::FrameLenRequested => {
150                    ready!(this.writer.poll_flush(cx))?;
151                    *this.state = StderrReaderState::ReadingSize {
152                        buf: [0u8; 8],
153                        filled: 0,
154                    };
155                }
156                StderrReaderState::ReadingSize { buf, filled } => {
157                    if *filled < buf.len() {
158                        let mut size_buf = ReadBuf::new(buf);
159                        size_buf.advance(*filled);
160
161                        ready!(this.reader.poll_read(cx, &mut size_buf))?;
162                        let bytes_read = size_buf.filled().len() - *filled;
163                        if bytes_read == 0 {
164                            // oef
165                            return Poll::Ready(Ok(()));
166                        }
167                        *filled += bytes_read;
168                        continue;
169                    }
170                    let size = u64::from_le_bytes(*buf);
171                    if size == 0 {
172                        // eof
173                        *this.state = StderrReaderState::request_next_frame();
174                        return Poll::Ready(Ok(()));
175                    }
176                    let pad = (8 - (size % 8) as usize) % 8;
177                    *this.state = StderrReaderState::ReadingPayload {
178                        remaining: size,
179                        pad,
180                        tmp_buf: BytesMut::new(),
181                    };
182                }
183                StderrReaderState::ReadingPayload {
184                    remaining,
185                    pad,
186                    tmp_buf,
187                } => {
188                    // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
189                    let safe_remaining = if *remaining <= (usize::MAX - *pad) as u64 {
190                        *remaining as usize + *pad
191                    } else {
192                        usize::MAX
193                    };
194                    if safe_remaining - *pad > 0 {
195                        // The buffer is no larger than the amount of data that we expect.
196                        // Otherwise we will trim the buffer below and come back here.
197                        if read_buf.remaining() <= safe_remaining {
198                            let filled_before = read_buf.filled().len();
199
200                            ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
201                            let bytes_read = read_buf.filled().len() - filled_before;
202                            let payload_size = std::cmp::min(bytes_read, safe_remaining - *pad);
203
204                            // we don't want to include padding bytes in the result, so we remove them from read_buf.
205                            read_buf.set_filled(filled_before + payload_size);
206
207                            *remaining -= payload_size as u64;
208                            if *remaining > 0 {
209                                // We have more data to read so we just return ok, knowing that the consumer
210                                // will read again.
211                                return Poll::Ready(Ok(()));
212                            }
213
214                            // If we don't have any remaining data to read, consume any padding we may have just read.
215                            *pad -= bytes_read - payload_size;
216                            if *pad != 0 {
217                                // We haven't read all the padding yet, so we stash it away to return to the caller
218                                // once we've read the remaining padding.
219                                tmp_buf.clear();
220                                tmp_buf.put_slice(&read_buf.filled()[filled_before..payload_size]);
221                                read_buf.set_filled(filled_before);
222                                continue;
223                            }
224                            *this.state = StderrReaderState::request_next_frame();
225                            return Poll::Ready(Ok(()));
226                        }
227
228                        // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
229                        // internal bookkeeping simpler.
230                        let mut smaller_buf = read_buf.take(safe_remaining);
231                        ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
232
233                        let bytes_read = smaller_buf.filled().len();
234
235                        // SAFETY: we just read this number of bytes into read_buf's backing slice above.
236                        unsafe { read_buf.assume_init(bytes_read) };
237                        read_buf.advance(bytes_read);
238                        return Poll::Ready(Ok(()));
239                    } else if *pad > 0 {
240                        // if we've read the whole payload but there is still padding remaining,
241                        // we read it into a stack allocated array
242                        let mut pad_arr = [0u8; 7];
243                        let mut pad_buf = ReadBuf::new(&mut pad_arr);
244                        pad_buf.advance(7 - *pad);
245                        ready!(this.reader.poll_read(cx, &mut pad_buf))?;
246                        *pad = pad_buf.remaining();
247                        if *pad != 0 {
248                            continue;
249                        }
250                    }
251                    // now it's finally time to hand out the read data to the caller and reset to the RequestingNextFrame state.
252                    read_buf.put_slice(tmp_buf);
253                    tmp_buf.clear();
254                    *this.state = StderrReaderState::request_next_frame();
255                    return Poll::Ready(Ok(()));
256                }
257            }
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use hex_literal::hex;
    use tokio::io::{split, AsyncReadExt, BufReader};
    use tokio_test::io::Builder;

    use crate::{nix_daemon::framing::StderrReadFramedReader, worker_protocol::STDERR_READ};

    /// Ten bytes are requested, but the other side answers with a two byte frame.
    #[tokio::test]
    async fn test_single_two_byte_read_with_desired_size_ten() {
        let mut script = Builder::new();
        script
            // First the reader has to write STDERR_READ followed by the requested byte count
            .write(&STDERR_READ.to_le_bytes())
            .write(&10u64.to_le_bytes())
            .wait(Duration::ZERO)
            // The client answers with 2 bytes rather than 10
            .read(&2u64.to_le_bytes())
            // The payload and its padding follow right away
            .read("hi".as_bytes())
            .read(&hex!("0000 0000 0000"));
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);
        let mut buffered = BufReader::with_capacity(10, &mut framed);

        let mut received = [0u8; 2];
        let n = buffered.read_exact(&mut received).await.unwrap();

        assert_eq!(2, n);
        assert_eq!("hi".as_bytes(), received);
    }

    /// The payload arrives immediately while its padding shows up after a delay.
    #[tokio::test]
    async fn test_single_read_with_padding_delayed() {
        let mut script = Builder::new();
        script
            // First the reader has to write STDERR_READ followed by the requested byte count
            .write(&STDERR_READ.to_le_bytes())
            .write(&10u64.to_le_bytes())
            // The client answers with 9 bytes rather than 10.
            .read(&9u64.to_le_bytes())
            // The payload follows immediately
            .read(&hex!("0202 0104 ffff ffaa 00"))
            // The padding only arrives after a delay
            .wait(Duration::ZERO)
            .read(&hex!("0000 0000 0000 00"));
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);
        let mut buffered = BufReader::with_capacity(10, &mut framed);

        let mut received = [0u8; 9];
        let n = buffered.read_exact(&mut received).await.unwrap();

        assert_eq!(9, n);
        assert_eq!(hex!("0202 0104 ffff ffaa 00"), received);
    }

    /// Two data frames and a terminating empty frame, with delays sprinkled in between.
    #[tokio::test]
    async fn test_multiple_consecutive_reads_with_arbitrary_delays() {
        let mut script = Builder::new();
        script
            // First the reader has to write STDERR_READ followed by the requested byte count
            .write(&STDERR_READ.to_le_bytes())
            .write(&8192u64.to_le_bytes())
            .wait(Duration::ZERO)
            // The client sends the 6 bytes 'hello ' plus padding
            .read(&6u64.to_le_bytes())
            .wait(Duration::ZERO)
            .read("hello ".as_bytes())
            .read(&hex!("0000"))
            // The reader asks for the next frame
            .write(&STDERR_READ.to_le_bytes())
            .write(&8192u64.to_le_bytes())
            // The client sends the 11 bytes 'racerunners'; the 's' and the padding are delayed
            .wait(Duration::ZERO)
            .read(&11u64.to_le_bytes())
            .read("racerunner".as_bytes())
            .wait(Duration::ZERO)
            .read("s".as_bytes())
            .read(&hex!("0000 0000"))
            .wait(Duration::ZERO)
            .read(&hex!("00"))
            // The reader asks once more and gets an empty frame back
            .write(&STDERR_READ.to_le_bytes())
            .write(&8192u64.to_le_bytes())
            .wait(Duration::ZERO)
            .read(&0u64.to_le_bytes());
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);
        let mut buffered = BufReader::with_capacity(8192, &mut framed);

        let mut text = String::new();
        let n = buffered.read_to_string(&mut text).await.unwrap();

        assert_eq!(17, n);
        assert_eq!("hello racerunners", &text);
    }

    /// Both request words need two writes each before they are fully sent.
    #[tokio::test]
    async fn test_single_read_where_writing_stderr_and_desired_size_take_more_than_one_write() {
        let stderr_bytes = STDERR_READ.to_le_bytes();
        let length_bytes = 10u64.to_le_bytes();
        let mut script = Builder::new();
        script
            .write(&stderr_bytes[..4])
            .wait(Duration::ZERO)
            .write(&stderr_bytes[4..])
            .wait(Duration::ZERO)
            .write(&length_bytes[..4])
            .wait(Duration::ZERO)
            .write(&length_bytes[4..])
            .wait(Duration::ZERO)
            // The client answers with 2 bytes rather than 10
            .read(&2u64.to_le_bytes())
            // The payload and its padding follow right away
            .read("hi".as_bytes())
            .read(&hex!("0000 0000 0000"));
        let (read_half, write_half) = split(script.build());
        let mut framed = StderrReadFramedReader::new(read_half, write_half);
        let mut buffered = BufReader::with_capacity(10, &mut framed);

        let mut received = [0u8; 2];
        let n = buffered.read_exact(&mut received).await.unwrap();

        assert_eq!(2, n);
        assert_eq!("hi".as_bytes(), received);
    }

    /// Empty placeholder test; it exercises nothing.
    #[tokio::test]
    async fn hello() {}
}
391    #[tokio::test]
392    async fn hello() {}
393}