nix_compat/wire/bytes/writer.rs

use pin_project_lite::pin_project;
use std::task::{ready, Poll};

use tokio::io::AsyncWrite;

use super::{padding_len, EMPTY_BYTES, LEN_SIZE};

pin_project! {
    /// Writes a "bytes wire packet" to the underlying writer.
    /// The format is the same as in [crate::wire::bytes::write_bytes],
    /// however this structure provides an [AsyncWrite] interface,
    /// so the entire payload does not have to be held in memory at once.
    ///
    /// It internally takes care of writing the (non-payload) framing (size and
    /// padding).
    ///
    /// During construction, the expected payload size needs to be provided.
    ///
    /// After writing the payload to it, the user MUST call flush (or shutdown),
    /// which will validate that the written payload size matches, and write the
    /// necessary padding.
    ///
    /// If flush is not called at the end, invalid data might be sent
    /// silently.
    ///
    /// The underlying writer returning `Ok(0)` is considered an EOF situation,
    /// which is stronger than the "typically means the underlying object is no
    /// longer able to accept bytes" interpretation from the docs. If such a
    /// situation occurs, an error is returned.
    ///
    /// The struct holds three fields: the underlying writer, the (expected)
    /// payload length, and an enum tracking the state.
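    ///
    /// A minimal usage sketch (illustrative; it assumes a `Vec<u8>` as the
    /// underlying writer and an async context to `.await` in):
    ///
    /// ```ignore
    /// use tokio::io::AsyncWriteExt;
    ///
    /// let payload = b"hello";
    /// let mut buf: Vec<u8> = Vec::new();
    /// let mut w = BytesWriter::new(&mut buf, payload.len() as u64);
    /// w.write_all(payload).await?;
    /// // Flush/shutdown validates the payload length and writes the padding.
    /// w.shutdown().await?;
    /// // `buf` now holds the 8-byte little-endian length, the payload,
    /// // and zero padding up to the next 8-byte boundary.
    /// ```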
    pub struct BytesWriter<W>
    where
        W: AsyncWrite,
    {
        #[pin]
        inner: W,
        payload_len: u64,
        state: BytesPacketPosition,
    }
}

/// Models the position inside a "bytes wire packet" that the writer is in.
/// It can be in one of three stages: inside the size, payload or padding field.
/// The number tracks the number of bytes written inside the specific field.
/// There shall be no ambiguous states; at the end of a stage we immediately
/// move to the beginning of the next one:
/// - Size(LEN_SIZE) must be expressed as Payload(0)
/// - Payload(self.payload_len) must be expressed as Padding(0)
///
/// Padding(padding_len) means we're at the end of the bytes wire packet.
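///
/// As an example walk-through (with a 1-byte payload, and thus 7 bytes of
/// padding): the writer starts in Size(0), moves to Payload(0) once the full
/// size field has been written, to Padding(0) once the payload byte has been
/// written, and ends in Padding(7).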
#[derive(Clone, Debug, PartialEq, Eq)]
enum BytesPacketPosition {
    Size(usize),
    Payload(u64),
    Padding(usize),
}

impl<W> BytesWriter<W>
where
    W: AsyncWrite,
{
    /// Constructs a new BytesWriter, wrapping the passed underlying writer.
    pub fn new(w: W, payload_len: u64) -> Self {
        Self {
            inner: w,
            payload_len,
            state: BytesPacketPosition::Size(0),
        }
    }
}

/// Returns an error if the passed usize is 0.
#[inline]
fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
    if bytes_written == 0 {
        Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            "underlying writer accepted 0 bytes",
        ))
    } else {
        Ok(bytes_written)
    }
}

impl<W> AsyncWrite for BytesWriter<W>
where
    W: AsyncWrite,
{
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Use a loop, so we can deal with (multiple) state transitions.
        let mut this = self.project();

        loop {
            match *this.state {
                // Per the state invariant, Size(LEN_SIZE) is never stored;
                // it is immediately expressed as Payload(0).
                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
                BytesPacketPosition::Size(pos) => {
                    let size_field = &this.payload_len.to_le_bytes();

                    let bytes_written = ensure_nonzero_bytes_written(ready!(this
                        .inner
                        .as_mut()
                        .poll_write(cx, &size_field[pos..]))?)?;

                    let new_pos = pos + bytes_written;
                    if new_pos == LEN_SIZE {
                        *this.state = BytesPacketPosition::Payload(0);
                    } else {
                        *this.state = BytesPacketPosition::Size(new_pos);
                    }
                }
                BytesPacketPosition::Payload(pos) => {
                    // Ensure we still have space for more payload.
                    if pos + (buf.len() as u64) > *this.payload_len {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "tried to write excess bytes",
                        )));
                    }
                    let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
                    ensure_nonzero_bytes_written(bytes_written)?;
                    let new_pos = pos + (bytes_written as u64);
                    if new_pos == *this.payload_len {
                        *this.state = BytesPacketPosition::Padding(0)
                    } else {
                        *this.state = BytesPacketPosition::Payload(new_pos)
                    }

                    return Poll::Ready(Ok(bytes_written));
                }
                // If we're already in padding state, there should be no more payload left to write!
                BytesPacketPosition::Padding(_pos) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "tried to write excess bytes",
                    )))
                }
            }
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut this = self.project();

        loop {
            match *this.state {
                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
                BytesPacketPosition::Size(pos) => {
                    // More bytes to write in the size field
                    let size_field = &this.payload_len.to_le_bytes()[..];
                    let bytes_written = ensure_nonzero_bytes_written(ready!(this
                        .inner
                        .as_mut()
                        .poll_write(cx, &size_field[pos..]))?)?;
                    let new_pos = pos + bytes_written;
                    if new_pos == LEN_SIZE {
                        // Size field written, now ready to receive payload
                        *this.state = BytesPacketPosition::Payload(0);
                    } else {
                        *this.state = BytesPacketPosition::Size(new_pos);
                    }
                }
                BytesPacketPosition::Payload(_pos) => {
                    // If we're at position 0 and want to write 0 bytes of payload
                    // in total, we can transition to padding.
                    // Otherwise, break, as we're expecting more payload to
                    // be written.
                    if *this.payload_len == 0 {
                        *this.state = BytesPacketPosition::Padding(0);
                    } else {
                        break;
                    }
                }
                BytesPacketPosition::Padding(pos) => {
                    // Write remaining padding, if there is padding to write.
                    let total_padding_len = padding_len(*this.payload_len) as usize;

                    if pos != total_padding_len {
                        let bytes_written = ensure_nonzero_bytes_written(ready!(this
                            .inner
                            .as_mut()
                            .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?;
                        *this.state = BytesPacketPosition::Padding(pos + bytes_written);
                    } else {
                        // Everything written, break.
                        break;
                    }
                }
            }
        }
        // Flush the underlying writer.
        this.inner.as_mut().poll_flush(cx)
    }

    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Call flush.
        ready!(self.as_mut().poll_flush(cx))?;

        let this = self.project();

        // After a flush, being inside the padding state, and at the end of the padding,
        // is the only way to prevent a dirty shutdown.
        if let BytesPacketPosition::Padding(pos) = *this.state {
            let padding_len = padding_len(*this.payload_len) as usize;
            if padding_len == pos {
                // Shutdown the underlying writer
                return this.inner.poll_shutdown(cx);
            }
        }

        // Shutdown the underlying writer, bubbling up any errors.
        ready!(this.inner.poll_shutdown(cx))?;

        // Return an error about the unclean shutdown.
        Poll::Ready(Err(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "unclean shutdown",
        )))
    }
}

#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use std::time::Duration;

    use crate::wire::bytes::write_bytes;
    use hex_literal::hex;
    use tokio::io::AsyncWriteExt;
    use tokio_test::{assert_err, assert_ok, io::Builder};

    use super::*;

    pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> =
        LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024));

    /// Helper function, calling the (simpler) write_bytes with the payload.
    /// We use this to create data we want to see on the wire.
    async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
        let mut exp = vec![];
        write_bytes(&mut exp, payload).await.unwrap();
        exp
    }

    /// Write an empty bytes packet.
    #[tokio::test]
    async fn write_empty() {
        let payload = &[];
        let mut mock = Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build();

        let mut w = BytesWriter::new(&mut mock, 0);
        assert_ok!(w.write_all(&[]).await, "write all data");
        assert_ok!(w.flush().await, "flush");
    }

    /// Write an empty bytes packet, without calling write, only flush.
    #[tokio::test]
    async fn write_empty_only_flush() {
        let payload = &[];
        let mut mock = Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build();

        let mut w = BytesWriter::new(&mut mock, 0);
        assert_ok!(w.flush().await, "flush");
    }

    /// Write an empty bytes packet, without calling write or flush, only shutdown.
    #[tokio::test]
    async fn write_empty_only_shutdown() {
        let payload = &[];
        let mut mock = Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build();

        let mut w = BytesWriter::new(&mut mock, 0);
        assert_ok!(w.shutdown().await, "shutdown");
    }

    /// Write a 1-byte packet.
    #[tokio::test]
    async fn write_1b() {
        let payload = &[0xff];

        let mut mock = Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build();

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// Write an 8-byte payload (no padding).
    #[tokio::test]
    async fn write_8b() {
        let payload = &hex!("0001020304050607");

        let mut mock = Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build();

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// Write a 9-byte payload (7 bytes of padding).
    #[tokio::test]
    async fn write_9b() {
        let payload = &hex!("000102030405060708");

        let mut mock = Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build();

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// Write a 9-byte packet very granularly, with a lot of flushing in between,
    /// and a shutdown at the end.
    #[tokio::test]
    async fn write_9b_flush() {
        let payload = &hex!("000102030405060708");
        let exp_bytes = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&exp_bytes).build();

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.flush().await);

        assert_ok!(w.write_all(&payload[..4]).await);
        assert_ok!(w.flush().await);

        // empty write, cause why not
        assert_ok!(w.write_all(&[]).await);
        assert_ok!(w.flush().await);

        assert_ok!(w.write_all(&payload[4..]).await);
        assert_ok!(w.flush().await);
        assert_ok!(w.shutdown().await);
    }

    /// Write a 9-byte packet, but cause the sink to only accept part of the
    /// padding (2 of 7 bytes), ensuring we correctly write (only) the rest of
    /// the padding later.
    /// We write another 2 bytes of "bait", where a faulty implementation (pre
    /// cl/11384) would put too many null bytes.
    #[tokio::test]
    async fn write_9b_write_padding_2steps() {
        let payload = &hex!("000102030405060708");
        let exp_bytes = produce_exp_bytes(payload).await;

        let mut mock = Builder::new()
            .write(&exp_bytes[0..8]) // size
            .write(&exp_bytes[8..17]) // payload
            .write(&exp_bytes[17..19]) // padding (2 of 7 bytes)
            // insert a wait to prevent Mock from merging the two writes into one
            .wait(Duration::from_nanos(1))
            .write(&hex!("0000000000ffff")) // remaining padding (5 of 7 bytes), plus 2 bytes of "bait"
            .build();

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(&payload[..]).await);
        assert_ok!(w.flush().await);
        // Write the bait.
        assert_ok!(mock.write_all(&hex!("ffff")).await);
    }

    /// Write a larger bytes packet.
    #[tokio::test]
    async fn write_1m() {
        let payload = LARGE_PAYLOAD.as_slice();
        let exp_bytes = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&exp_bytes).build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// Not calling flush at the end, but only shutdown, is also ok if we wrote
    /// all bytes we promised to write (as shutdown implies flush).
    #[tokio::test]
    async fn write_shutdown_without_flush_end() {
        let payload = &[0xf0, 0xff];
        let exp_bytes = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&exp_bytes).build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        // call flush to write the size field
        assert_ok!(w.flush().await);

        // write payload
        assert_ok!(w.write_all(payload).await);

        // call shutdown
        assert_ok!(w.shutdown().await);
    }

    /// Writing more bytes than previously signalled should fail.
    #[tokio::test]
    async fn write_more_than_signalled_fail() {
        let mut buf = Vec::new();
        let mut w = BytesWriter::new(&mut buf, 2);

        assert_err!(w.write_all(&hex!("000102")).await);
    }

    /// Writing more bytes than previously signalled, split across two writes,
    /// should also fail.
    #[tokio::test]
    async fn write_more_than_signalled_split_fail() {
        let mut buf = Vec::new();
        let mut w = BytesWriter::new(&mut buf, 2);

        // write two bytes
        assert_ok!(w.write_all(&hex!("0001")).await);

        // write the excess byte.
        assert_err!(w.write_all(&hex!("02")).await);
    }

    /// Writing more bytes than previously signalled, but flushing after the
    /// signalled amount, should also fail.
    #[tokio::test]
    async fn write_more_than_signalled_flush_fail() {
        let mut buf = Vec::new();
        let mut w = BytesWriter::new(&mut buf, 2);

        // write two bytes, then flush
        assert_ok!(w.write_all(&hex!("0001")).await);
        assert_ok!(w.flush().await);

        // write the excess byte.
        assert_err!(w.write_all(&hex!("02")).await);
    }

    /// Calling shutdown while not having written all bytes that were promised
    /// returns an error.
    /// Note there are still cases of silent corruption if the user doesn't call
    /// shutdown explicitly (only drops).
    #[tokio::test]
    async fn premature_shutdown() {
        let payload = &[0xf0, 0xff];
        let mut buf = Vec::new();
        let mut w = BytesWriter::new(&mut buf, payload.len() as u64);

        // call flush to write the size field
        assert_ok!(w.flush().await);

        // write half of the payload (!)
        assert_ok!(w.write_all(&payload[0..1]).await);

        // call shutdown, ensure it fails
        assert_err!(w.shutdown().await);
    }

    /// Write to a writer that fails to write during the size field (after 4 bytes).
    /// Ensure this error gets propagated on the first call to write.
    #[tokio::test]
    async fn inner_writer_fail_during_size_firstwrite() {
        let payload = &[0xf0];

        let mut mock = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_err!(w.write_all(payload).await);
    }

    /// Write to a writer that fails to write during the size field (after 4 bytes).
    /// Ensure this error gets propagated during an initial flush.
    #[tokio::test]
    async fn inner_writer_fail_during_size_initial_flush() {
        let payload = &[0xf0];

        let mut mock = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_err!(w.flush().await);
    }

    /// Write to a writer that fails to write during the payload (after 9 bytes).
    /// Ensure this error gets propagated when we write that byte.
    #[tokio::test]
    async fn inner_writer_fail_during_write() {
        let payload = &hex!("f0ff");

        let mut mock = Builder::new()
            .write(&2u64.to_le_bytes())
            .write(&hex!("f0"))
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(w.write(&hex!("f0")).await);
        assert_err!(w.write(&hex!("ff")).await);
    }

    /// Write to a writer that fails to write during the padding (after 10 bytes).
    /// Ensure this error gets propagated during a flush.
    #[tokio::test]
    async fn inner_writer_fail_during_padding_flush() {
        let payload = &hex!("f0");

        let mut mock = Builder::new()
            .write(&1u64.to_le_bytes())
            .write(&hex!("f0"))
            .write(&hex!("00"))
            .write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(w.write(&hex!("f0")).await);
        assert_err!(w.flush().await);
    }
}