async_compression/tokio/write/generic/
encoder.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::{
8 codec::Encode,
9 tokio::write::{AsyncBufWrite, BufWriter},
10 util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Lifecycle of an `Encoder`: accepting input, emitting the codec's final
/// output during shutdown, or completely finished.
#[derive(Debug)]
enum State {
    /// Normal operation; the only state in which writes and flushes are accepted.
    Encoding,
    /// Shutdown has started, but the codec has not yet reported that its
    /// final output has been fully produced.
    Finishing,
    /// The codec has finished; repeated shutdown calls succeed immediately.
    Done,
}
22
23pin_project! {
24 #[derive(Debug)]
25 pub struct Encoder<W, E> {
26 #[pin]
27 writer: BufWriter<W>,
28 encoder: E,
29 state: State,
30 }
31}
32
33impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
34 pub fn new(writer: W, encoder: E) -> Self {
35 Self {
36 writer: BufWriter::new(writer),
37 encoder,
38 state: State::Encoding,
39 }
40 }
41}
42
43impl<W, E> Encoder<W, E> {
44 pub fn get_ref(&self) -> &W {
45 self.writer.get_ref()
46 }
47
48 pub fn get_mut(&mut self) -> &mut W {
49 self.writer.get_mut()
50 }
51
52 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
53 self.project().writer.get_pin_mut()
54 }
55
56 pub(crate) fn get_encoder_ref(&self) -> &E {
57 &self.encoder
58 }
59
60 pub fn into_inner(self) -> W {
61 self.writer.into_inner()
62 }
63}
64
65impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
66 fn do_poll_write(
67 self: Pin<&mut Self>,
68 cx: &mut Context<'_>,
69 input: &mut PartialBuffer<&[u8]>,
70 ) -> Poll<io::Result<()>> {
71 let mut this = self.project();
72
73 loop {
74 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
75 let mut output = PartialBuffer::new(output);
76
77 *this.state = match this.state {
78 State::Encoding => {
79 this.encoder.encode(input, &mut output)?;
80 State::Encoding
81 }
82
83 State::Finishing | State::Done => {
84 return Poll::Ready(Err(io::Error::new(
85 io::ErrorKind::Other,
86 "Write after shutdown",
87 )))
88 }
89 };
90
91 let produced = output.written().len();
92 this.writer.as_mut().produce(produced);
93
94 if input.unwritten().is_empty() {
95 return Poll::Ready(Ok(()));
96 }
97 }
98 }
99
100 fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101 let mut this = self.project();
102
103 loop {
104 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
105 let mut output = PartialBuffer::new(output);
106
107 let done = match this.state {
108 State::Encoding => this.encoder.flush(&mut output)?,
109
110 State::Finishing | State::Done => {
111 return Poll::Ready(Err(io::Error::new(
112 io::ErrorKind::Other,
113 "Flush after shutdown",
114 )))
115 }
116 };
117
118 let produced = output.written().len();
119 this.writer.as_mut().produce(produced);
120
121 if done {
122 return Poll::Ready(Ok(()));
123 }
124 }
125 }
126
127 fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
128 let mut this = self.project();
129
130 loop {
131 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
132 let mut output = PartialBuffer::new(output);
133
134 *this.state = match this.state {
135 State::Encoding | State::Finishing => {
136 if this.encoder.finish(&mut output)? {
137 State::Done
138 } else {
139 State::Finishing
140 }
141 }
142
143 State::Done => State::Done,
144 };
145
146 let produced = output.written().len();
147 this.writer.as_mut().produce(produced);
148
149 if let State::Done = this.state {
150 return Poll::Ready(Ok(()));
151 }
152 }
153 }
154}
155
156impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
157 fn poll_write(
158 self: Pin<&mut Self>,
159 cx: &mut Context<'_>,
160 buf: &[u8],
161 ) -> Poll<io::Result<usize>> {
162 if buf.is_empty() {
163 return Poll::Ready(Ok(0));
164 }
165
166 let mut input = PartialBuffer::new(buf);
167
168 match self.do_poll_write(cx, &mut input)? {
169 Poll::Pending if input.written().is_empty() => Poll::Pending,
170 _ => Poll::Ready(Ok(input.written().len())),
171 }
172 }
173
174 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175 ready!(self.as_mut().do_poll_flush(cx))?;
176 ready!(self.project().writer.as_mut().poll_flush(cx))?;
177 Poll::Ready(Ok(()))
178 }
179
180 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181 ready!(self.as_mut().do_poll_shutdown(cx))?;
182 ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
183 Poll::Ready(Ok(()))
184 }
185}
186
187impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
188 fn poll_read(
189 self: Pin<&mut Self>,
190 cx: &mut Context<'_>,
191 buf: &mut ReadBuf<'_>,
192 ) -> Poll<io::Result<()>> {
193 self.get_pin_mut().poll_read(cx, buf)
194 }
195}
196
197impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
198 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
199 self.get_pin_mut().poll_fill_buf(cx)
200 }
201
202 fn consume(self: Pin<&mut Self>, amt: usize) {
203 self.get_pin_mut().consume(amt)
204 }
205}