nix_compat/wire/bytes/
writer.rs

1use pin_project_lite::pin_project;
2use std::task::{Poll, ready};
3
4use tokio::io::AsyncWrite;
5
6use super::{EMPTY_BYTES, LEN_SIZE, padding_len};
7
8pin_project! {
9    /// Writes a "bytes wire packet" to the underlying writer.
10    /// The format is the same as in [crate::wire::bytes::write_bytes],
11    /// however this structure provides a [AsyncWrite] interface,
12    /// allowing to not having to pass around the entire payload in memory.
13    ///
14    /// It internally takes care of writing (non-payload) framing (size and
15    /// padding).
16    ///
17    /// During construction, the expected payload size needs to be provided.
18    ///
19    /// After writing the payload to it, the user MUST call flush (or shutdown),
20    /// which will validate the written payload size to match, and write the
21    /// necessary padding.
22    ///
23    /// In case flush is not called at the end, invalid data might be sent
24    /// silently.
25    ///
26    /// The underlying writer returning `Ok(0)` is considered an EOF situation,
27    /// which is stronger than the "typically means the underlying object is no
28    /// longer able to accept bytes" interpretation from the docs. If such a
29    /// situation occurs, an error is returned.
30    ///
31    /// The struct holds three fields, the underlying writer, the (expected)
32    /// payload length, and an enum, tracking the state.
33    pub struct BytesWriter<W>
34    where
35        W: AsyncWrite,
36    {
37        #[pin]
38        inner: W,
39        payload_len: u64,
40        state: BytesPacketPosition,
41    }
42}
43
/// Tracks where inside a "bytes wire packet" the writer currently stands.
///
/// A packet is made of three consecutive fields (size, payload, padding).
/// Each variant carries the number of bytes of its field written so far.
///
/// States are meant to be unambiguous: reaching the end of a field
/// immediately becomes offset 0 of the following field:
/// - `Size(LEN_SIZE)` must be expressed as `Payload(0)`
/// - `Payload(payload_len)` must be expressed as `Padding(0)`
///
/// (For an empty payload, `Payload(0)` is kept until a flush moves it on to
/// `Padding(0)`.)
///
/// `Padding(padding_len)` means the end of the bytes wire packet is reached.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BytesPacketPosition {
    /// Inside the size field; holds the bytes of it already written.
    Size(usize),
    /// Inside the payload; holds the payload bytes already written.
    Payload(u64),
    /// Inside the padding; holds the padding bytes already written.
    Padding(usize),
}
59
60impl<W> BytesWriter<W>
61where
62    W: AsyncWrite,
63{
64    /// Constructs a new BytesWriter, using the underlying passed writer.
65    pub fn new(w: W, payload_len: u64) -> Self {
66        Self {
67            inner: w,
68            payload_len,
69            state: BytesPacketPosition::Size(0),
70        }
71    }
72}
73
/// Passes a non-zero byte count through unchanged, and turns a count of 0
/// into a [std::io::ErrorKind::WriteZero] error.
#[inline]
fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
    match bytes_written {
        0 => Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            "underlying writer accepted 0 bytes",
        )),
        accepted => Ok(accepted),
    }
}
86
87impl<W> AsyncWrite for BytesWriter<W>
88where
89    W: AsyncWrite,
90{
91    fn poll_write(
92        self: std::pin::Pin<&mut Self>,
93        cx: &mut std::task::Context<'_>,
94        buf: &[u8],
95    ) -> Poll<Result<usize, std::io::Error>> {
96        // Use a loop, so we can deal with (multiple) state transitions.
97        let mut this = self.project();
98
99        loop {
100            match *this.state {
101                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
102                BytesPacketPosition::Size(pos) => {
103                    let size_field = &this.payload_len.to_le_bytes();
104
105                    let bytes_written = ensure_nonzero_bytes_written(ready!(
106                        this.inner.as_mut().poll_write(cx, &size_field[pos..])
107                    )?)?;
108
109                    let new_pos = pos + bytes_written;
110                    if new_pos == LEN_SIZE {
111                        *this.state = BytesPacketPosition::Payload(0);
112                    } else {
113                        *this.state = BytesPacketPosition::Size(new_pos);
114                    }
115                }
116                BytesPacketPosition::Payload(pos) => {
117                    // Ensure we still have space for more payload
118                    if pos + (buf.len() as u64) > *this.payload_len {
119                        return Poll::Ready(Err(std::io::Error::new(
120                            std::io::ErrorKind::InvalidData,
121                            "tried to write excess bytes",
122                        )));
123                    }
124                    let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
125                    ensure_nonzero_bytes_written(bytes_written)?;
126                    let new_pos = pos + (bytes_written as u64);
127                    if new_pos == *this.payload_len {
128                        *this.state = BytesPacketPosition::Padding(0)
129                    } else {
130                        *this.state = BytesPacketPosition::Payload(new_pos)
131                    }
132
133                    return Poll::Ready(Ok(bytes_written));
134                }
135                // If we're already in padding state, there should be no more payload left to write!
136                BytesPacketPosition::Padding(_pos) => {
137                    return Poll::Ready(Err(std::io::Error::new(
138                        std::io::ErrorKind::InvalidData,
139                        "tried to write excess bytes",
140                    )));
141                }
142            }
143        }
144    }
145
146    fn poll_flush(
147        self: std::pin::Pin<&mut Self>,
148        cx: &mut std::task::Context<'_>,
149    ) -> Poll<Result<(), std::io::Error>> {
150        let mut this = self.project();
151
152        loop {
153            match *this.state {
154                BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
155                BytesPacketPosition::Size(pos) => {
156                    // More bytes to write in the size field
157                    let size_field = &this.payload_len.to_le_bytes()[..];
158                    let bytes_written = ensure_nonzero_bytes_written(ready!(
159                        this.inner.as_mut().poll_write(cx, &size_field[pos..])
160                    )?)?;
161                    let new_pos = pos + bytes_written;
162                    if new_pos == LEN_SIZE {
163                        // Size field written, now ready to receive payload
164                        *this.state = BytesPacketPosition::Payload(0);
165                    } else {
166                        *this.state = BytesPacketPosition::Size(new_pos);
167                    }
168                }
169                BytesPacketPosition::Payload(_pos) => {
170                    // If we're at position 0 and want to write 0 bytes of payload
171                    // in total, we can transition to padding.
172                    // Otherwise, break, as we're expecting more payload to
173                    // be written.
174                    if *this.payload_len == 0 {
175                        *this.state = BytesPacketPosition::Padding(0);
176                    } else {
177                        break;
178                    }
179                }
180                BytesPacketPosition::Padding(pos) => {
181                    // Write remaining padding, if there is padding to write.
182                    let total_padding_len = padding_len(*this.payload_len) as usize;
183
184                    if pos != total_padding_len {
185                        let bytes_written = ensure_nonzero_bytes_written(ready!(
186                            this.inner
187                                .as_mut()
188                                .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len])
189                        )?)?;
190                        *this.state = BytesPacketPosition::Padding(pos + bytes_written);
191                    } else {
192                        // everything written, break
193                        break;
194                    }
195                }
196            }
197        }
198        // Flush the underlying writer.
199        this.inner.as_mut().poll_flush(cx)
200    }
201
202    fn poll_shutdown(
203        mut self: std::pin::Pin<&mut Self>,
204        cx: &mut std::task::Context<'_>,
205    ) -> Poll<Result<(), std::io::Error>> {
206        // Call flush.
207        ready!(self.as_mut().poll_flush(cx))?;
208
209        let this = self.project();
210
211        // After a flush, being inside the padding state, and at the end of the padding
212        // is the only way to prevent a dirty shutdown.
213        if let BytesPacketPosition::Padding(pos) = *this.state {
214            let padding_len = padding_len(*this.payload_len) as usize;
215            if padding_len == pos {
216                // Shutdown the underlying writer
217                return this.inner.poll_shutdown(cx);
218            }
219        }
220
221        // Shutdown the underlying writer, bubbling up any errors.
222        ready!(this.inner.poll_shutdown(cx))?;
223
224        // return an error about unclean shutdown
225        Poll::Ready(Err(std::io::Error::new(
226            std::io::ErrorKind::BrokenPipe,
227            "unclean shutdown",
228        )))
229    }
230}
231
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use std::time::Duration;

    use crate::wire::bytes::write_bytes;
    use hex_literal::hex;
    use tokio::io::AsyncWriteExt;
    use tokio_test::{assert_err, assert_ok, io::Builder};

    use super::*;

    /// A large payload: the byte values 0 up to (excluding) 255, repeated
    /// 4096 times.
    pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> =
        LazyLock::new(|| (0u8..255).collect::<Vec<_>>().repeat(4 * 1024));

    /// Helper producing the reference encoding of `payload`, by running it
    /// through the (simpler) write_bytes. This is the data the tests expect
    /// to see on the wire.
    async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        write_bytes(&mut wire, payload).await.unwrap();
        wire
    }

    /// An empty packet, written via an (empty) write followed by a flush.
    #[tokio::test]
    async fn write_empty() {
        let wire = produce_exp_bytes(&[]).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, 0);
        assert_ok!(writer.write_all(&[]).await, "write all data");
        assert_ok!(writer.flush().await, "flush");
    }

    /// An empty packet, written by a flush alone (write is never called).
    #[tokio::test]
    async fn write_empty_only_flush() {
        let wire = produce_exp_bytes(&[]).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, 0);
        assert_ok!(writer.flush().await, "flush");
    }

    /// An empty packet, written by a shutdown alone (neither write nor flush
    /// is called).
    #[tokio::test]
    async fn write_empty_only_shutdown() {
        let wire = produce_exp_bytes(&[]).await;
        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, 0);
        assert_ok!(writer.shutdown().await, "shutdown");
    }

    /// A packet with a single byte of payload.
    #[tokio::test]
    async fn write_1b() {
        let data = &[0xff];
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A packet with 8 bytes of payload (needs no padding).
    #[tokio::test]
    async fn write_8b() {
        let data = &hex!("0001020304050607");
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A packet with 9 bytes of payload (needs 7 bytes of padding).
    #[tokio::test]
    async fn write_9b() {
        let data = &hex!("000102030405060708");
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A 9 byte payload written in small pieces, flushing after every step,
    /// and finishing with a shutdown.
    #[tokio::test]
    async fn write_9b_flush() {
        let data = &hex!("000102030405060708");
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.flush().await);

        // first part of the payload
        assert_ok!(writer.write_all(&data[..4]).await);
        assert_ok!(writer.flush().await);

        // an empty write in between must not hurt
        assert_ok!(writer.write_all(&[]).await);
        assert_ok!(writer.flush().await);

        // rest of the payload
        assert_ok!(writer.write_all(&data[4..]).await);
        assert_ok!(writer.flush().await);
        assert_ok!(writer.shutdown().await);
    }

    /// A 9 byte payload, with a sink that at first accepts only part of the
    /// padding. This ensures that afterwards (only) the rest of the padding
    /// is written. Two more "bait" bytes are written to the sink at the end;
    /// a faulty implementation (pre cl/11384) would have put too many null
    /// bytes in their place.
    #[tokio::test]
    async fn write_9b_write_padding_2steps() {
        let data = &hex!("000102030405060708");
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new()
            .write(&wire[0..8]) // size
            .write(&wire[8..17]) // payload
            .write(&wire[17..19]) // padding (2 of 7 bytes)
            // the wait keeps Mock from merging the two adjacent writes into one
            .wait(Duration::from_nanos(1))
            .write(&hex!("0000000000ffff")) // padding (5 of 7 bytes, plus 2 bytes of "bait")
            .build();

        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);
        assert_ok!(writer.write_all(&data[..]).await);
        assert_ok!(writer.flush().await);
        // Now the bait goes directly to the sink.
        assert_ok!(sink.write_all(&hex!("ffff")).await);
    }

    /// A packet with a larger payload.
    #[tokio::test]
    async fn write_1m() {
        let data = LARGE_PAYLOAD.as_slice();
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_ok!(writer.write_all(data).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// Ending with shutdown instead of flush is fine too, as long as all
    /// promised bytes were written (shutdown implies flush).
    #[tokio::test]
    async fn write_shutdown_without_flush_end() {
        let data = &[0xf0, 0xff];
        let wire = produce_exp_bytes(data).await;

        let mut sink = Builder::new().write(&wire).build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        // an initial flush gets the size field written
        assert_ok!(writer.flush().await);

        // then the payload
        assert_ok!(writer.write_all(data).await);

        // and finally the shutdown
        assert_ok!(writer.shutdown().await);
    }

    /// Writing more bytes than announced at construction fails.
    #[tokio::test]
    async fn write_more_than_signalled_fail() {
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, 2);

        assert_err!(writer.write_all(&hex!("000102")).await);
    }

    /// Writing more bytes than announced fails as well if the bytes arrive
    /// in two parts.
    #[tokio::test]
    async fn write_more_than_signalled_split_fail() {
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, 2);

        // the two announced bytes
        assert_ok!(writer.write_all(&hex!("0001")).await);

        // one byte too many
        assert_err!(writer.write_all(&hex!("02")).await);
    }

    /// Writing more bytes than announced fails as well if there was a flush
    /// after the announced amount.
    #[tokio::test]
    async fn write_more_than_signalled_flush_fail() {
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, 2);

        // the two announced bytes, followed by a flush
        assert_ok!(writer.write_all(&hex!("0001")).await);
        assert_ok!(writer.flush().await);

        // one byte too many
        assert_err!(writer.write_all(&hex!("02")).await);
    }

    /// Shutting down before all promised bytes were written returns an
    /// error.
    /// Note silent corruption is still possible if the user never calls
    /// shutdown explicitly (and only drops the writer).
    #[tokio::test]
    async fn premature_shutdown() {
        let data = &[0xf0, 0xff];
        let mut out = Vec::new();
        let mut writer = BytesWriter::new(&mut out, data.len() as u64);

        // an initial flush gets the size field written
        assert_ok!(writer.flush().await);

        // only half of the payload follows (!)
        assert_ok!(writer.write_all(&data[0..1]).await);

        // the shutdown has to fail
        assert_err!(writer.shutdown().await);
    }

    /// The underlying writer fails inside the size field (after 4 bytes).
    /// The error has to surface on the first call to write.
    #[tokio::test]
    async fn inner_writer_fail_during_size_firstwrite() {
        let data = &[0xf0];

        let mut sink = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_err!(writer.write_all(data).await);
    }

    /// The underlying writer fails inside the size field (after 4 bytes).
    /// The error has to surface during an initial flush.
    #[tokio::test]
    async fn inner_writer_fail_during_size_initial_flush() {
        let data = &[0xf0];

        let mut sink = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_err!(writer.flush().await);
    }

    /// The underlying writer fails inside the payload (after 9 bytes).
    /// The error has to surface on the write of the failing byte.
    #[tokio::test]
    async fn inner_writer_fail_during_write() {
        let data = &hex!("f0ff");

        let mut sink = Builder::new()
            .write(&2u64.to_le_bytes())
            .write(&hex!("f0"))
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_ok!(writer.write(&hex!("f0")).await);
        assert_err!(writer.write(&hex!("ff")).await);
    }

    /// The underlying writer fails inside the padding (after 10 bytes).
    /// The error has to surface during a flush.
    #[tokio::test]
    async fn inner_writer_fail_during_padding_flush() {
        let data = &hex!("f0");

        let mut sink = Builder::new()
            .write(&1u64.to_le_bytes())
            .write(&hex!("f0"))
            .write(&hex!("00"))
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut sink, data.len() as u64);

        assert_ok!(writer.write(&hex!("f0")).await);
        assert_err!(writer.flush().await);
    }
}
534        assert_err!(w.flush().await);
535    }
536}