async_compression/tokio/bufread/generic/
encoder.rs1use core::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5use std::io::{IoSlice, Result};
6
7use crate::{codec::Encode, util::PartialBuffer};
8use futures_core::ready;
9use pin_project_lite::pin_project;
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Phase of the encoding read state machine driven by `Encoder::do_poll_read`.
#[derive(Debug)]
enum State {
    /// Pulling input from the inner reader and passing it to the encoder.
    Encoding,
    /// The inner reader returned an empty buffer; `finish` is called until it
    /// returns `true`.
    Flushing,
    /// `finish` returned `true`; later reads produce no more bytes.
    Done,
}
18
pin_project! {
    /// Adapter that pulls bytes from a buffered async reader `R`, passes them
    /// through the encoder `E`, and exposes the encoder's output via `AsyncRead`.
    #[derive(Debug)]
    pub struct Encoder<R, E> {
        // Source of the input bytes. Structurally pinned, so it can be polled
        // through `Pin<&mut Self>` via the generated projection.
        #[pin]
        reader: R,
        // Encoder that every chunk read from `reader` is fed through.
        encoder: E,
        // Current phase of the read state machine.
        state: State,
    }
}
28
29impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
30 pub fn new(reader: R, encoder: E) -> Self {
31 Self {
32 reader,
33 encoder,
34 state: State::Encoding,
35 }
36 }
37}
38
39impl<R, E> Encoder<R, E> {
40 pub fn get_ref(&self) -> &R {
41 &self.reader
42 }
43
44 pub fn get_mut(&mut self) -> &mut R {
45 &mut self.reader
46 }
47
48 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
49 self.project().reader
50 }
51
52 pub(crate) fn get_encoder_ref(&self) -> &E {
53 &self.encoder
54 }
55
56 pub fn into_inner(self) -> R {
57 self.reader
58 }
59}
60impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
61 fn do_poll_read(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 output: &mut PartialBuffer<&mut [u8]>,
65 ) -> Poll<Result<()>> {
66 let mut this = self.project();
67
68 loop {
69 *this.state = match this.state {
70 State::Encoding => {
71 let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
72 if input.is_empty() {
73 State::Flushing
74 } else {
75 let mut input = PartialBuffer::new(input);
76 this.encoder.encode(&mut input, output)?;
77 let len = input.written().len();
78 this.reader.as_mut().consume(len);
79 State::Encoding
80 }
81 }
82
83 State::Flushing => {
84 if this.encoder.finish(output)? {
85 State::Done
86 } else {
87 State::Flushing
88 }
89 }
90
91 State::Done => State::Done,
92 };
93
94 if let State::Done = *this.state {
95 return Poll::Ready(Ok(()));
96 }
97 if output.unwritten().is_empty() {
98 return Poll::Ready(Ok(()));
99 }
100 }
101 }
102}
103
104impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
105 fn poll_read(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &mut ReadBuf<'_>,
109 ) -> Poll<Result<()>> {
110 if buf.remaining() == 0 {
111 return Poll::Ready(Ok(()));
112 }
113
114 let mut output = PartialBuffer::new(buf.initialize_unfilled());
115 match self.do_poll_read(cx, &mut output)? {
116 Poll::Pending if output.written().is_empty() => Poll::Pending,
117 _ => {
118 let len = output.written().len();
119 buf.advance(len);
120 Poll::Ready(Ok(()))
121 }
122 }
123 }
124}
125
126impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
127 fn poll_write(
128 mut self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 buf: &[u8],
131 ) -> Poll<Result<usize>> {
132 self.get_pin_mut().poll_write(cx, buf)
133 }
134
135 fn poll_write_vectored(
136 mut self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 mut bufs: &[IoSlice<'_>],
139 ) -> Poll<Result<usize>> {
140 self.get_pin_mut().poll_write_vectored(cx, bufs)
141 }
142
143 fn is_write_vectored(&self) -> bool {
144 self.get_ref().is_write_vectored()
145 }
146
147 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
148 self.get_pin_mut().poll_flush(cx)
149 }
150
151 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
152 self.get_pin_mut().poll_shutdown(cx)
153 }
154}