nix_compat/wire/bytes/
writer.rs1use pin_project_lite::pin_project;
2use std::task::{ready, Poll};
3
4use tokio::io::AsyncWrite;
5
6use super::{padding_len, EMPTY_BYTES, LEN_SIZE};
7
8pin_project! {
9 pub struct BytesWriter<W>
34 where
35 W: AsyncWrite,
36 {
37 #[pin]
38 inner: W,
39 payload_len: u64,
40 state: BytesPacketPosition,
41 }
42}
43
/// Tracks which section of the packet is currently being written, and how
/// many bytes of that section have already been handed to the inner writer.
#[derive(Clone, Debug, PartialEq, Eq)]
enum BytesPacketPosition {
    /// Writing the size field; the value is the number of size-field bytes
    /// written so far.
    Size(usize),
    /// Writing the payload; the value is the number of payload bytes written
    /// so far.
    Payload(u64),
    /// Writing the trailing padding; the value is the number of padding bytes
    /// written so far.
    Padding(usize),
}
59
60impl<W> BytesWriter<W>
61where
62 W: AsyncWrite,
63{
64 pub fn new(w: W, payload_len: u64) -> Self {
66 Self {
67 inner: w,
68 payload_len,
69 state: BytesPacketPosition::Size(0),
70 }
71 }
72}
73
/// Passes a byte count through unchanged unless it is zero.
///
/// A count of zero is turned into a [`std::io::ErrorKind::WriteZero`] error,
/// so that an inner writer which accepts no bytes surfaces as a failure
/// instead of being retried forever.
#[inline]
fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
    match bytes_written {
        0 => Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            "underlying writer accepted 0 bytes",
        )),
        accepted => Ok(accepted),
    }
}
86
87impl<W> AsyncWrite for BytesWriter<W>
88where
89 W: AsyncWrite,
90{
91 fn poll_write(
92 self: std::pin::Pin<&mut Self>,
93 cx: &mut std::task::Context<'_>,
94 buf: &[u8],
95 ) -> Poll<Result<usize, std::io::Error>> {
96 let mut this = self.project();
98
99 loop {
100 match *this.state {
101 BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
102 BytesPacketPosition::Size(pos) => {
103 let size_field = &this.payload_len.to_le_bytes();
104
105 let bytes_written = ensure_nonzero_bytes_written(ready!(this
106 .inner
107 .as_mut()
108 .poll_write(cx, &size_field[pos..]))?)?;
109
110 let new_pos = pos + bytes_written;
111 if new_pos == LEN_SIZE {
112 *this.state = BytesPacketPosition::Payload(0);
113 } else {
114 *this.state = BytesPacketPosition::Size(new_pos);
115 }
116 }
117 BytesPacketPosition::Payload(pos) => {
118 if pos + (buf.len() as u64) > *this.payload_len {
120 return Poll::Ready(Err(std::io::Error::new(
121 std::io::ErrorKind::InvalidData,
122 "tried to write excess bytes",
123 )));
124 }
125 let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
126 ensure_nonzero_bytes_written(bytes_written)?;
127 let new_pos = pos + (bytes_written as u64);
128 if new_pos == *this.payload_len {
129 *this.state = BytesPacketPosition::Padding(0)
130 } else {
131 *this.state = BytesPacketPosition::Payload(new_pos)
132 }
133
134 return Poll::Ready(Ok(bytes_written));
135 }
136 BytesPacketPosition::Padding(_pos) => {
138 return Poll::Ready(Err(std::io::Error::new(
139 std::io::ErrorKind::InvalidData,
140 "tried to write excess bytes",
141 )))
142 }
143 }
144 }
145 }
146
147 fn poll_flush(
148 self: std::pin::Pin<&mut Self>,
149 cx: &mut std::task::Context<'_>,
150 ) -> Poll<Result<(), std::io::Error>> {
151 let mut this = self.project();
152
153 loop {
154 match *this.state {
155 BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
156 BytesPacketPosition::Size(pos) => {
157 let size_field = &this.payload_len.to_le_bytes()[..];
159 let bytes_written = ensure_nonzero_bytes_written(ready!(this
160 .inner
161 .as_mut()
162 .poll_write(cx, &size_field[pos..]))?)?;
163 let new_pos = pos + bytes_written;
164 if new_pos == LEN_SIZE {
165 *this.state = BytesPacketPosition::Payload(0);
167 } else {
168 *this.state = BytesPacketPosition::Size(new_pos);
169 }
170 }
171 BytesPacketPosition::Payload(_pos) => {
172 if *this.payload_len == 0 {
177 *this.state = BytesPacketPosition::Padding(0);
178 } else {
179 break;
180 }
181 }
182 BytesPacketPosition::Padding(pos) => {
183 let total_padding_len = padding_len(*this.payload_len) as usize;
185
186 if pos != total_padding_len {
187 let bytes_written = ensure_nonzero_bytes_written(ready!(this
188 .inner
189 .as_mut()
190 .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?;
191 *this.state = BytesPacketPosition::Padding(pos + bytes_written);
192 } else {
193 break;
195 }
196 }
197 }
198 }
199 this.inner.as_mut().poll_flush(cx)
201 }
202
203 fn poll_shutdown(
204 mut self: std::pin::Pin<&mut Self>,
205 cx: &mut std::task::Context<'_>,
206 ) -> Poll<Result<(), std::io::Error>> {
207 ready!(self.as_mut().poll_flush(cx))?;
209
210 let this = self.project();
211
212 if let BytesPacketPosition::Padding(pos) = *this.state {
215 let padding_len = padding_len(*this.payload_len) as usize;
216 if padding_len == pos {
217 return this.inner.poll_shutdown(cx);
219 }
220 }
221
222 ready!(this.inner.poll_shutdown(cx))?;
224
225 Poll::Ready(Err(std::io::Error::new(
227 std::io::ErrorKind::BrokenPipe,
228 "unclean shutdown",
229 )))
230 }
231}
232
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use std::time::Duration;

    use crate::wire::bytes::write_bytes;
    use hex_literal::hex;
    use tokio::io::AsyncWriteExt;
    use tokio_test::{assert_err, assert_ok, io::Builder};

    use super::*;

    pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> =
        LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024));

    /// Produces the reference wire encoding of `payload` via `write_bytes`,
    /// which the tests compare `BytesWriter`'s output against.
    async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
        let mut wire = vec![];
        write_bytes(&mut wire, payload).await.unwrap();
        wire
    }

    /// Builds a mock writer that expects exactly the reference encoding of
    /// `payload`, scripted as one single write.
    async fn mock_expecting(payload: &[u8]) -> tokio_test::io::Mock {
        Builder::new()
            .write(&produce_exp_bytes(payload).await)
            .build()
    }

    /// The error injected into the mock by the `inner_writer_fail_*` tests.
    fn popcorn_error() -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::Other, "🍿")
    }

    /// Empty payload: an (empty) write_all followed by a flush.
    #[tokio::test]
    async fn write_empty() {
        let mut mock = mock_expecting(&[]).await;

        let mut w = BytesWriter::new(&mut mock, 0);
        assert_ok!(w.write_all(&[]).await, "write all data");
        assert_ok!(w.flush().await, "flush");
    }

    /// Empty payload: a flush alone must produce the whole packet.
    #[tokio::test]
    async fn write_empty_only_flush() {
        let mut mock = mock_expecting(&[]).await;

        let mut w = BytesWriter::new(&mut mock, 0);
        assert_ok!(w.flush().await, "flush");
    }

    /// Empty payload: a shutdown alone must produce the whole packet.
    #[tokio::test]
    async fn write_empty_only_shutdown() {
        let mut mock = mock_expecting(&[]).await;

        let mut w = BytesWriter::new(&mut mock, 0);
        assert_ok!(w.shutdown().await, "shutdown");
    }

    /// A 1-byte payload.
    #[tokio::test]
    async fn write_1b() {
        let payload = &[0xff];
        let mut mock = mock_expecting(payload).await;

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// An 8-byte payload.
    #[tokio::test]
    async fn write_8b() {
        let payload = &hex!("0001020304050607");
        let mut mock = mock_expecting(payload).await;

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// A 9-byte payload.
    #[tokio::test]
    async fn write_9b() {
        let payload = &hex!("000102030405060708");
        let mut mock = mock_expecting(payload).await;

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// A 9-byte payload written in two parts, with flushes sprinkled before,
    /// between and after the writes, and a final shutdown.
    #[tokio::test]
    async fn write_9b_flush() {
        let payload = &hex!("000102030405060708");
        let mut mock = mock_expecting(payload).await;

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.flush().await);

        let (head, tail) = (&payload[..4], &payload[4..]);

        assert_ok!(w.write_all(head).await);
        assert_ok!(w.flush().await);

        // An empty write_all in the middle of the payload must be harmless.
        assert_ok!(w.write_all(&[]).await);
        assert_ok!(w.flush().await);

        assert_ok!(w.write_all(tail).await);
        assert_ok!(w.flush().await);
        assert_ok!(w.shutdown().await);
    }

    /// A 9-byte payload where the mock takes the bytes after the payload in
    /// two separate steps; afterwards the mock itself is written to directly.
    #[tokio::test]
    async fn write_9b_write_padding_2steps() {
        let payload = &hex!("000102030405060708");
        let exp_bytes = produce_exp_bytes(payload).await;

        let mut mock = Builder::new()
            // first 8 bytes of the encoding
            .write(&exp_bytes[0..8])
            // the 9 bytes following them
            .write(&exp_bytes[8..17])
            // the next two bytes
            .write(&exp_bytes[17..19])
            .wait(Duration::from_nanos(1))
            // five zero bytes, plus the `ffff` written directly below
            .write(&hex!("0000000000ffff"))
            .build();

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(&payload[..]).await);
        assert_ok!(w.flush().await);

        // Consume the trailing bytes the mock still expects.
        assert_ok!(mock.write_all(&hex!("ffff")).await);
    }

    /// A large payload (see `LARGE_PAYLOAD`).
    #[tokio::test]
    async fn write_1m() {
        let payload = LARGE_PAYLOAD.as_slice();
        let mut mock = mock_expecting(payload).await;

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(w.write_all(payload).await);
        assert_ok!(w.flush().await, "flush");
    }

    /// Shutting down right after the last payload byte, without an explicit
    /// flush at the end, must still complete the packet.
    #[tokio::test]
    async fn write_shutdown_without_flush_end() {
        let payload = &[0xf0, 0xff];
        let mut mock = mock_expecting(payload).await;

        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        // Initial flush, before any payload.
        assert_ok!(w.flush().await);

        // The whole payload.
        assert_ok!(w.write_all(payload).await);

        // Shutdown has to take care of the rest.
        assert_ok!(w.shutdown().await);
    }

    /// Writing more than announced, in one go, fails.
    #[tokio::test]
    async fn write_more_than_signalled_fail() {
        let mut sink = Vec::new();
        let mut w = BytesWriter::new(&mut sink, 2);

        assert_err!(w.write_all(&hex!("000102")).await);
    }

    /// Writing more than announced, split over two writes, fails on the
    /// second write.
    #[tokio::test]
    async fn write_more_than_signalled_split_fail() {
        let mut sink = Vec::new();
        let mut w = BytesWriter::new(&mut sink, 2);

        // Exactly the announced amount is fine ...
        assert_ok!(w.write_all(&hex!("0001")).await);

        // ... one more byte is not.
        assert_err!(w.write_all(&hex!("02")).await);
    }

    /// Same as above, but with a flush between the two writes.
    #[tokio::test]
    async fn write_more_than_signalled_flush_fail() {
        let mut sink = Vec::new();
        let mut w = BytesWriter::new(&mut sink, 2);

        assert_ok!(w.write_all(&hex!("0001")).await);
        assert_ok!(w.flush().await);

        assert_err!(w.write_all(&hex!("02")).await);
    }

    /// Shutting down before the announced payload has been written fails.
    #[tokio::test]
    async fn premature_shutdown() {
        let payload = &[0xf0, 0xff];
        let mut sink = Vec::new();
        let mut w = BytesWriter::new(&mut sink, payload.len() as u64);

        // Initial flush, before any payload.
        assert_ok!(w.flush().await);

        // Only half of the payload.
        assert_ok!(w.write_all(&payload[0..1]).await);

        // The packet is incomplete, so shutdown must report an error.
        assert_err!(w.shutdown().await);
    }

    /// The inner writer errors after accepting 4 bytes, while the size field
    /// is being written as part of the first payload write.
    #[tokio::test]
    async fn inner_writer_fail_during_size_firstwrite() {
        let payload = &[0xf0];

        let mut mock = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(popcorn_error())
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_err!(w.write_all(payload).await);
    }

    /// The inner writer errors after accepting 4 bytes, while the size field
    /// is being written as part of an initial flush.
    #[tokio::test]
    async fn inner_writer_fail_during_size_initial_flush() {
        let payload = &[0xf0];

        let mut mock = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(popcorn_error())
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_err!(w.flush().await);
    }

    /// The inner writer errors in the middle of the payload.
    #[tokio::test]
    async fn inner_writer_fail_during_write() {
        let payload = &hex!("f0ff");

        let mut mock = Builder::new()
            .write(&2u64.to_le_bytes())
            .write(&hex!("f0"))
            .write_error(popcorn_error())
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(w.write(&hex!("f0")).await);
        assert_err!(w.write(&hex!("ff")).await);
    }

    /// The inner writer errors after the first padding byte, while the
    /// padding is being written by a flush.
    #[tokio::test]
    async fn inner_writer_fail_during_padding_flush() {
        let payload = &hex!("f0");

        let mut mock = Builder::new()
            .write(&1u64.to_le_bytes())
            .write(&hex!("f0"))
            .write(&hex!("00"))
            .write_error(popcorn_error())
            .build();
        let mut w = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(w.write(&hex!("f0")).await);
        assert_err!(w.flush().await);
    }
}