// async_compression/tokio/write/generic/decoder.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::{
8    codec::Decode,
9    tokio::write::{AsyncBufWrite, BufWriter},
10    util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Lifecycle of a [`Decoder`]'s input stream.
///
/// Transitions only ever move forward: `Decoding` -> `Finishing` -> `Done`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Written bytes are still being passed to the decoder.
    Decoding,
    /// The decoder's `finish` step is being driven to completion. Entered either
    /// when `decode` reports `true` or when shutdown is requested.
    Finishing,
    /// `finish` has completed; any further write is rejected with an error.
    Done,
}
22
23pin_project! {
24    #[derive(Debug)]
25    pub struct Decoder<W, D> {
26        #[pin]
27        writer: BufWriter<W>,
28        decoder: D,
29        state: State,
30    }
31}
32
33impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
34    pub fn new(writer: W, decoder: D) -> Self {
35        Self {
36            writer: BufWriter::new(writer),
37            decoder,
38            state: State::Decoding,
39        }
40    }
41}
42
43impl<W, D> Decoder<W, D> {
44    pub fn get_ref(&self) -> &W {
45        self.writer.get_ref()
46    }
47
48    pub fn get_mut(&mut self) -> &mut W {
49        self.writer.get_mut()
50    }
51
52    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
53        self.project().writer.get_pin_mut()
54    }
55
56    pub fn into_inner(self) -> W {
57        self.writer.into_inner()
58    }
59}
60
61impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
62    fn do_poll_write(
63        self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65        input: &mut PartialBuffer<&[u8]>,
66    ) -> Poll<io::Result<()>> {
67        let mut this = self.project();
68
69        loop {
70            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
71            let mut output = PartialBuffer::new(output);
72
73            *this.state = match this.state {
74                State::Decoding => {
75                    if this.decoder.decode(input, &mut output)? {
76                        State::Finishing
77                    } else {
78                        State::Decoding
79                    }
80                }
81
82                State::Finishing => {
83                    if this.decoder.finish(&mut output)? {
84                        State::Done
85                    } else {
86                        State::Finishing
87                    }
88                }
89
90                State::Done => {
91                    return Poll::Ready(Err(io::Error::new(
92                        io::ErrorKind::Other,
93                        "Write after end of stream",
94                    )))
95                }
96            };
97
98            let produced = output.written().len();
99            this.writer.as_mut().produce(produced);
100
101            if let State::Done = this.state {
102                return Poll::Ready(Ok(()));
103            }
104
105            if input.unwritten().is_empty() {
106                return Poll::Ready(Ok(()));
107            }
108        }
109    }
110
111    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
112        let mut this = self.project();
113
114        loop {
115            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
116            let mut output = PartialBuffer::new(output);
117
118            let (state, done) = match this.state {
119                State::Decoding => {
120                    let done = this.decoder.flush(&mut output)?;
121                    (State::Decoding, done)
122                }
123
124                State::Finishing => {
125                    if this.decoder.finish(&mut output)? {
126                        (State::Done, false)
127                    } else {
128                        (State::Finishing, false)
129                    }
130                }
131
132                State::Done => (State::Done, true),
133            };
134
135            *this.state = state;
136
137            let produced = output.written().len();
138            this.writer.as_mut().produce(produced);
139
140            if done {
141                return Poll::Ready(Ok(()));
142            }
143        }
144    }
145}
146
147impl<W: AsyncWrite, D: Decode> AsyncWrite for Decoder<W, D> {
148    fn poll_write(
149        self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        buf: &[u8],
152    ) -> Poll<io::Result<usize>> {
153        if buf.is_empty() {
154            return Poll::Ready(Ok(0));
155        }
156
157        let mut input = PartialBuffer::new(buf);
158
159        match self.do_poll_write(cx, &mut input)? {
160            Poll::Pending if input.written().is_empty() => Poll::Pending,
161            _ => Poll::Ready(Ok(input.written().len())),
162        }
163    }
164
165    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
166        ready!(self.as_mut().do_poll_flush(cx))?;
167        ready!(self.project().writer.as_mut().poll_flush(cx))?;
168        Poll::Ready(Ok(()))
169    }
170
171    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
172        if let State::Decoding = self.as_mut().project().state {
173            *self.as_mut().project().state = State::Finishing;
174        }
175
176        ready!(self.as_mut().do_poll_flush(cx))?;
177
178        if let State::Done = self.as_mut().project().state {
179            ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?;
180            Poll::Ready(Ok(()))
181        } else {
182            Poll::Ready(Err(io::Error::new(
183                io::ErrorKind::Other,
184                "Attempt to shutdown before finishing input",
185            )))
186        }
187    }
188}
189
190impl<W: AsyncRead, D> AsyncRead for Decoder<W, D> {
191    fn poll_read(
192        self: Pin<&mut Self>,
193        cx: &mut Context<'_>,
194        buf: &mut ReadBuf<'_>,
195    ) -> Poll<io::Result<()>> {
196        self.get_pin_mut().poll_read(cx, buf)
197    }
198}
199
200impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
201    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
202        self.get_pin_mut().poll_fill_buf(cx)
203    }
204
205    fn consume(self: Pin<&mut Self>, amt: usize) {
206        self.get_pin_mut().consume(amt)
207    }
208}