async_compression/tokio/write/generic/
decoder.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::{
8 codec::Decode,
9 tokio::write::{AsyncBufWrite, BufWriter},
10 util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Lifecycle of a [`Decoder`].
///
/// The only forward transitions are `Decoding -> Finishing -> Done`; nothing moves a
/// decoder backwards.
#[derive(Debug)]
enum State {
    /// Written bytes are handed to `Decode::decode`. Leaves this state when `decode`
    /// reports `true`, or when `poll_shutdown` is called.
    Decoding,
    /// `Decode::finish` is polled until it reports `true`.
    Finishing,
    /// Decoding is complete; further writes are rejected with an error.
    Done,
}
22
23pin_project! {
24 #[derive(Debug)]
25 pub struct Decoder<W, D> {
26 #[pin]
27 writer: BufWriter<W>,
28 decoder: D,
29 state: State,
30 }
31}
32
33impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
34 pub fn new(writer: W, decoder: D) -> Self {
35 Self {
36 writer: BufWriter::new(writer),
37 decoder,
38 state: State::Decoding,
39 }
40 }
41}
42
43impl<W, D> Decoder<W, D> {
44 pub fn get_ref(&self) -> &W {
45 self.writer.get_ref()
46 }
47
48 pub fn get_mut(&mut self) -> &mut W {
49 self.writer.get_mut()
50 }
51
52 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
53 self.project().writer.get_pin_mut()
54 }
55
56 pub fn into_inner(self) -> W {
57 self.writer.into_inner()
58 }
59}
60
61impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
62 fn do_poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 input: &mut PartialBuffer<&[u8]>,
66 ) -> Poll<io::Result<()>> {
67 let mut this = self.project();
68
69 loop {
70 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
71 let mut output = PartialBuffer::new(output);
72
73 *this.state = match this.state {
74 State::Decoding => {
75 if this.decoder.decode(input, &mut output)? {
76 State::Finishing
77 } else {
78 State::Decoding
79 }
80 }
81
82 State::Finishing => {
83 if this.decoder.finish(&mut output)? {
84 State::Done
85 } else {
86 State::Finishing
87 }
88 }
89
90 State::Done => {
91 return Poll::Ready(Err(io::Error::new(
92 io::ErrorKind::Other,
93 "Write after end of stream",
94 )))
95 }
96 };
97
98 let produced = output.written().len();
99 this.writer.as_mut().produce(produced);
100
101 if let State::Done = this.state {
102 return Poll::Ready(Ok(()));
103 }
104
105 if input.unwritten().is_empty() {
106 return Poll::Ready(Ok(()));
107 }
108 }
109 }
110
111 fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
112 let mut this = self.project();
113
114 loop {
115 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
116 let mut output = PartialBuffer::new(output);
117
118 let (state, done) = match this.state {
119 State::Decoding => {
120 let done = this.decoder.flush(&mut output)?;
121 (State::Decoding, done)
122 }
123
124 State::Finishing => {
125 if this.decoder.finish(&mut output)? {
126 (State::Done, false)
127 } else {
128 (State::Finishing, false)
129 }
130 }
131
132 State::Done => (State::Done, true),
133 };
134
135 *this.state = state;
136
137 let produced = output.written().len();
138 this.writer.as_mut().produce(produced);
139
140 if done {
141 return Poll::Ready(Ok(()));
142 }
143 }
144 }
145}
146
147impl<W: AsyncWrite, D: Decode> AsyncWrite for Decoder<W, D> {
148 fn poll_write(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 buf: &[u8],
152 ) -> Poll<io::Result<usize>> {
153 if buf.is_empty() {
154 return Poll::Ready(Ok(0));
155 }
156
157 let mut input = PartialBuffer::new(buf);
158
159 match self.do_poll_write(cx, &mut input)? {
160 Poll::Pending if input.written().is_empty() => Poll::Pending,
161 _ => Poll::Ready(Ok(input.written().len())),
162 }
163 }
164
165 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166 ready!(self.as_mut().do_poll_flush(cx))?;
167 ready!(self.project().writer.as_mut().poll_flush(cx))?;
168 Poll::Ready(Ok(()))
169 }
170
171 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172 if let State::Decoding = self.as_mut().project().state {
173 *self.as_mut().project().state = State::Finishing;
174 }
175
176 ready!(self.as_mut().do_poll_flush(cx))?;
177
178 if let State::Done = self.as_mut().project().state {
179 ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?;
180 Poll::Ready(Ok(()))
181 } else {
182 Poll::Ready(Err(io::Error::new(
183 io::ErrorKind::Other,
184 "Attempt to shutdown before finishing input",
185 )))
186 }
187 }
188}
189
190impl<W: AsyncRead, D> AsyncRead for Decoder<W, D> {
191 fn poll_read(
192 self: Pin<&mut Self>,
193 cx: &mut Context<'_>,
194 buf: &mut ReadBuf<'_>,
195 ) -> Poll<io::Result<()>> {
196 self.get_pin_mut().poll_read(cx, buf)
197 }
198}
199
200impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
201 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
202 self.get_pin_mut().poll_fill_buf(cx)
203 }
204
205 fn consume(self: Pin<&mut Self>, amt: usize) {
206 self.get_pin_mut().consume(amt)
207 }
208}