nix_compat/wire/bytes/
writer.rs1use pin_project_lite::pin_project;
2use std::task::{Poll, ready};
3
4use tokio::io::AsyncWrite;
5
6use super::{EMPTY_BYTES, LEN_SIZE, padding_len};
7
8pin_project! {
9 pub struct BytesWriter<W>
34 where
35 W: AsyncWrite,
36 {
37 #[pin]
38 inner: W,
39 payload_len: u64,
40 state: BytesPacketPosition,
41 }
42}
43
/// Tracks which section of the packet the writer is currently in, together
/// with the number of bytes of that section already handed to the inner
/// writer.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BytesPacketPosition {
    /// Writing the size header; the value counts header bytes already written.
    Size(usize),
    /// Writing the payload; the value counts payload bytes already written.
    Payload(u64),
    /// Writing the trailing padding; the value counts padding bytes already
    /// written.
    Padding(usize),
}
59
60impl<W> BytesWriter<W>
61where
62 W: AsyncWrite,
63{
64 pub fn new(w: W, payload_len: u64) -> Self {
66 Self {
67 inner: w,
68 payload_len,
69 state: BytesPacketPosition::Size(0),
70 }
71 }
72}
73
/// Passes a non-zero write count through unchanged and turns a count of
/// zero into an [`std::io::ErrorKind::WriteZero`] error.
///
/// Used after writes to the inner writer so that a writer which accepts no
/// bytes surfaces as an error instead of stalling progress.
#[inline]
fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
    match bytes_written {
        0 => Err(std::io::Error::new(
            std::io::ErrorKind::WriteZero,
            "underlying writer accepted 0 bytes",
        )),
        accepted => Ok(accepted),
    }
}
86
87impl<W> AsyncWrite for BytesWriter<W>
88where
89 W: AsyncWrite,
90{
91 fn poll_write(
92 self: std::pin::Pin<&mut Self>,
93 cx: &mut std::task::Context<'_>,
94 buf: &[u8],
95 ) -> Poll<Result<usize, std::io::Error>> {
96 let mut this = self.project();
98
99 loop {
100 match *this.state {
101 BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
102 BytesPacketPosition::Size(pos) => {
103 let size_field = &this.payload_len.to_le_bytes();
104
105 let bytes_written = ensure_nonzero_bytes_written(ready!(
106 this.inner.as_mut().poll_write(cx, &size_field[pos..])
107 )?)?;
108
109 let new_pos = pos + bytes_written;
110 if new_pos == LEN_SIZE {
111 *this.state = BytesPacketPosition::Payload(0);
112 } else {
113 *this.state = BytesPacketPosition::Size(new_pos);
114 }
115 }
116 BytesPacketPosition::Payload(pos) => {
117 if pos + (buf.len() as u64) > *this.payload_len {
119 return Poll::Ready(Err(std::io::Error::new(
120 std::io::ErrorKind::InvalidData,
121 "tried to write excess bytes",
122 )));
123 }
124 let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
125 ensure_nonzero_bytes_written(bytes_written)?;
126 let new_pos = pos + (bytes_written as u64);
127 if new_pos == *this.payload_len {
128 *this.state = BytesPacketPosition::Padding(0)
129 } else {
130 *this.state = BytesPacketPosition::Payload(new_pos)
131 }
132
133 return Poll::Ready(Ok(bytes_written));
134 }
135 BytesPacketPosition::Padding(_pos) => {
137 return Poll::Ready(Err(std::io::Error::new(
138 std::io::ErrorKind::InvalidData,
139 "tried to write excess bytes",
140 )));
141 }
142 }
143 }
144 }
145
146 fn poll_flush(
147 self: std::pin::Pin<&mut Self>,
148 cx: &mut std::task::Context<'_>,
149 ) -> Poll<Result<(), std::io::Error>> {
150 let mut this = self.project();
151
152 loop {
153 match *this.state {
154 BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
155 BytesPacketPosition::Size(pos) => {
156 let size_field = &this.payload_len.to_le_bytes()[..];
158 let bytes_written = ensure_nonzero_bytes_written(ready!(
159 this.inner.as_mut().poll_write(cx, &size_field[pos..])
160 )?)?;
161 let new_pos = pos + bytes_written;
162 if new_pos == LEN_SIZE {
163 *this.state = BytesPacketPosition::Payload(0);
165 } else {
166 *this.state = BytesPacketPosition::Size(new_pos);
167 }
168 }
169 BytesPacketPosition::Payload(_pos) => {
170 if *this.payload_len == 0 {
175 *this.state = BytesPacketPosition::Padding(0);
176 } else {
177 break;
178 }
179 }
180 BytesPacketPosition::Padding(pos) => {
181 let total_padding_len = padding_len(*this.payload_len) as usize;
183
184 if pos != total_padding_len {
185 let bytes_written = ensure_nonzero_bytes_written(ready!(
186 this.inner
187 .as_mut()
188 .poll_write(cx, &EMPTY_BYTES[pos..total_padding_len])
189 )?)?;
190 *this.state = BytesPacketPosition::Padding(pos + bytes_written);
191 } else {
192 break;
194 }
195 }
196 }
197 }
198 this.inner.as_mut().poll_flush(cx)
200 }
201
202 fn poll_shutdown(
203 mut self: std::pin::Pin<&mut Self>,
204 cx: &mut std::task::Context<'_>,
205 ) -> Poll<Result<(), std::io::Error>> {
206 ready!(self.as_mut().poll_flush(cx))?;
208
209 let this = self.project();
210
211 if let BytesPacketPosition::Padding(pos) = *this.state {
214 let padding_len = padding_len(*this.payload_len) as usize;
215 if padding_len == pos {
216 return this.inner.poll_shutdown(cx);
218 }
219 }
220
221 ready!(this.inner.poll_shutdown(cx))?;
223
224 Poll::Ready(Err(std::io::Error::new(
226 std::io::ErrorKind::BrokenPipe,
227 "unclean shutdown",
228 )))
229 }
230}
231
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use std::time::Duration;

    use crate::wire::bytes::write_bytes;
    use hex_literal::hex;
    use tokio::io::AsyncWriteExt;
    use tokio_test::{assert_err, assert_ok, io::Builder};

    use super::*;

    /// A large payload (255 distinct byte values repeated 4096 times).
    pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> =
        LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024));

    /// Produces the reference encoding of `payload` using `write_bytes`,
    /// against which the output of `BytesWriter` is compared.
    async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
        let mut encoded = vec![];
        write_bytes(&mut encoded, payload).await.unwrap();
        encoded
    }

    /// Empty payload: an empty write followed by a flush.
    #[tokio::test]
    async fn write_empty() {
        let payload = &[];
        let expected = produce_exp_bytes(payload).await;
        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, 0);
        assert_ok!(writer.write_all(&[]).await, "write all data");
        assert_ok!(writer.flush().await, "flush");
    }

    /// Empty payload: a flush alone emits the whole packet.
    #[tokio::test]
    async fn write_empty_only_flush() {
        let payload = &[];
        let expected = produce_exp_bytes(payload).await;
        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, 0);
        assert_ok!(writer.flush().await, "flush");
    }

    /// Empty payload: a shutdown alone emits the whole packet.
    #[tokio::test]
    async fn write_empty_only_shutdown() {
        let payload = &[];
        let expected = produce_exp_bytes(payload).await;
        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, 0);
        assert_ok!(writer.shutdown().await, "shutdown");
    }

    /// A one-byte payload.
    #[tokio::test]
    async fn write_1b() {
        let payload = &[0xff];
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(writer.write_all(payload).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// An eight-byte payload.
    #[tokio::test]
    async fn write_8b() {
        let payload = &hex!("0001020304050607");
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(writer.write_all(payload).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A nine-byte payload.
    #[tokio::test]
    async fn write_9b() {
        let payload = &hex!("000102030405060708");
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(writer.write_all(payload).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// A nine-byte payload written in pieces, with flushes (and an empty
    /// write) interleaved, followed by a shutdown.
    #[tokio::test]
    async fn write_9b_flush() {
        let payload = &hex!("000102030405060708");
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&expected).build();

        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(writer.flush().await);

        assert_ok!(writer.write_all(&payload[..4]).await);
        assert_ok!(writer.flush().await);

        // An empty write in between must not disturb anything.
        assert_ok!(writer.write_all(&[]).await);
        assert_ok!(writer.flush().await);

        assert_ok!(writer.write_all(&payload[4..]).await);
        assert_ok!(writer.flush().await);
        assert_ok!(writer.shutdown().await);
    }

    /// The mock accepts the padding in two separate steps; the trailing
    /// `ffff` is written directly to the mock after the packet is done.
    #[tokio::test]
    async fn write_9b_write_padding_2steps() {
        let payload = &hex!("000102030405060708");
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new()
            // size header
            .write(&expected[0..8])
            // payload
            .write(&expected[8..17])
            // first two padding bytes
            .write(&expected[17..19])
            .wait(Duration::from_nanos(1))
            // remaining zero bytes, plus the bytes written to the mock below
            .write(&hex!("0000000000ffff"))
            .build();

        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);
        assert_ok!(writer.write_all(&payload[..]).await);
        assert_ok!(writer.flush().await);
        assert_ok!(mock.write_all(&hex!("ffff")).await);
    }

    /// A large payload.
    #[tokio::test]
    async fn write_1m() {
        let payload = LARGE_PAYLOAD.as_slice();
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&expected).build();
        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(writer.write_all(payload).await);
        assert_ok!(writer.flush().await, "flush");
    }

    /// Shutting down right after the last payload write, without a flush in
    /// between, still produces the complete packet.
    #[tokio::test]
    async fn write_shutdown_without_flush_end() {
        let payload = &[0xf0, 0xff];
        let expected = produce_exp_bytes(payload).await;

        let mut mock = Builder::new().write(&expected).build();
        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);

        // Emits the size header.
        assert_ok!(writer.flush().await);

        // Emits the payload.
        assert_ok!(writer.write_all(payload).await);

        // Emits the padding and shuts down.
        assert_ok!(writer.shutdown().await);
    }

    /// Writing more bytes than announced in a single call fails.
    #[tokio::test]
    async fn write_more_than_signalled_fail() {
        let mut sink = Vec::new();
        let mut writer = BytesWriter::new(&mut sink, 2);

        assert_err!(writer.write_all(&hex!("000102")).await);
    }

    /// Writing more bytes than announced, split over two calls, fails on the
    /// second call.
    #[tokio::test]
    async fn write_more_than_signalled_split_fail() {
        let mut sink = Vec::new();
        let mut writer = BytesWriter::new(&mut sink, 2);

        // The announced amount is accepted.
        assert_ok!(writer.write_all(&hex!("0001")).await);

        // Anything beyond it is rejected.
        assert_err!(writer.write_all(&hex!("02")).await);
    }

    /// Writing more bytes than announced still fails when a flush happened
    /// after the announced amount was written.
    #[tokio::test]
    async fn write_more_than_signalled_flush_fail() {
        let mut sink = Vec::new();
        let mut writer = BytesWriter::new(&mut sink, 2);

        // The announced amount is accepted and flushed.
        assert_ok!(writer.write_all(&hex!("0001")).await);
        assert_ok!(writer.flush().await);

        // Anything beyond it is rejected.
        assert_err!(writer.write_all(&hex!("02")).await);
    }

    /// Shutting down before the whole payload has been written is an error.
    #[tokio::test]
    async fn premature_shutdown() {
        let payload = &[0xf0, 0xff];
        let mut sink = Vec::new();
        let mut writer = BytesWriter::new(&mut sink, payload.len() as u64);

        // Emits the size header.
        assert_ok!(writer.flush().await);

        // Only the first payload byte is written.
        assert_ok!(writer.write_all(&payload[0..1]).await);

        // The packet is incomplete, so shutdown must fail.
        assert_err!(writer.shutdown().await);
    }

    /// The inner writer fails partway through the size header, which is
    /// being written as part of the first payload write.
    #[tokio::test]
    async fn inner_writer_fail_during_size_firstwrite() {
        let payload = &[0xf0];

        let mut mock = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_err!(writer.write_all(payload).await);
    }

    /// The inner writer fails partway through the size header, which is
    /// being written as part of an initial flush.
    #[tokio::test]
    async fn inner_writer_fail_during_size_initial_flush() {
        let payload = &[0xf0];

        let mut mock = Builder::new()
            .write(&1u32.to_le_bytes())
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_err!(writer.flush().await);
    }

    /// The inner writer fails while the payload is being written.
    #[tokio::test]
    async fn inner_writer_fail_during_write() {
        let payload = &hex!("f0ff");

        let mut mock = Builder::new()
            .write(&2u64.to_le_bytes())
            .write(&hex!("f0"))
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(writer.write(&hex!("f0")).await);
        assert_err!(writer.write(&hex!("ff")).await);
    }

    /// The inner writer fails while the padding is being written during a
    /// flush.
    #[tokio::test]
    async fn inner_writer_fail_during_padding_flush() {
        let payload = &hex!("f0");

        let mut mock = Builder::new()
            .write(&1u64.to_le_bytes())
            .write(&hex!("f0"))
            .write(&hex!("00"))
            .write_error(std::io::Error::other("🍿"))
            .build();
        let mut writer = BytesWriter::new(&mut mock, payload.len() as u64);

        assert_ok!(writer.write(&hex!("f0")).await);
        assert_err!(writer.flush().await);
    }
}