async_compression/tokio/write/generic/
encoder.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::{
8    codec::Encode,
9    tokio::write::{AsyncBufWrite, BufWriter},
10    util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Lifecycle of an `Encoder`'s output stream.
///
/// The state only ever moves forward: `Encoding` -> `Finishing` -> `Done`
/// (or straight from `Encoding` to `Done`).
#[derive(Debug)]
enum State {
    /// Normal operation: writes are encoded and flushes are honoured.
    Encoding,
    /// Shutdown has started and `Encode::finish` has returned `false` at least once;
    /// writes and flushes are rejected from here on.
    Finishing,
    /// `Encode::finish` has returned `true`; the codec will not be called again.
    Done,
}
22
23pin_project! {
24    #[derive(Debug)]
25    pub struct Encoder<W, E> {
26        #[pin]
27        writer: BufWriter<W>,
28        encoder: E,
29        state: State,
30    }
31}
32
33impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
34    pub fn new(writer: W, encoder: E) -> Self {
35        Self {
36            writer: BufWriter::new(writer),
37            encoder,
38            state: State::Encoding,
39        }
40    }
41}
42
43impl<W, E> Encoder<W, E> {
44    pub fn get_ref(&self) -> &W {
45        self.writer.get_ref()
46    }
47
48    pub fn get_mut(&mut self) -> &mut W {
49        self.writer.get_mut()
50    }
51
52    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
53        self.project().writer.get_pin_mut()
54    }
55
56    pub(crate) fn get_encoder_ref(&self) -> &E {
57        &self.encoder
58    }
59
60    pub fn into_inner(self) -> W {
61        self.writer.into_inner()
62    }
63}
64
65impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
66    fn do_poll_write(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>,
69        input: &mut PartialBuffer<&[u8]>,
70    ) -> Poll<io::Result<()>> {
71        let mut this = self.project();
72
73        loop {
74            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
75            let mut output = PartialBuffer::new(output);
76
77            *this.state = match this.state {
78                State::Encoding => {
79                    this.encoder.encode(input, &mut output)?;
80                    State::Encoding
81                }
82
83                State::Finishing | State::Done => {
84                    return Poll::Ready(Err(io::Error::new(
85                        io::ErrorKind::Other,
86                        "Write after shutdown",
87                    )))
88                }
89            };
90
91            let produced = output.written().len();
92            this.writer.as_mut().produce(produced);
93
94            if input.unwritten().is_empty() {
95                return Poll::Ready(Ok(()));
96            }
97        }
98    }
99
100    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101        let mut this = self.project();
102
103        loop {
104            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
105            let mut output = PartialBuffer::new(output);
106
107            let done = match this.state {
108                State::Encoding => this.encoder.flush(&mut output)?,
109
110                State::Finishing | State::Done => {
111                    return Poll::Ready(Err(io::Error::new(
112                        io::ErrorKind::Other,
113                        "Flush after shutdown",
114                    )))
115                }
116            };
117
118            let produced = output.written().len();
119            this.writer.as_mut().produce(produced);
120
121            if done {
122                return Poll::Ready(Ok(()));
123            }
124        }
125    }
126
127    fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
128        let mut this = self.project();
129
130        loop {
131            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
132            let mut output = PartialBuffer::new(output);
133
134            *this.state = match this.state {
135                State::Encoding | State::Finishing => {
136                    if this.encoder.finish(&mut output)? {
137                        State::Done
138                    } else {
139                        State::Finishing
140                    }
141                }
142
143                State::Done => State::Done,
144            };
145
146            let produced = output.written().len();
147            this.writer.as_mut().produce(produced);
148
149            if let State::Done = this.state {
150                return Poll::Ready(Ok(()));
151            }
152        }
153    }
154}
155
156impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
157    fn poll_write(
158        self: Pin<&mut Self>,
159        cx: &mut Context<'_>,
160        buf: &[u8],
161    ) -> Poll<io::Result<usize>> {
162        if buf.is_empty() {
163            return Poll::Ready(Ok(0));
164        }
165
166        let mut input = PartialBuffer::new(buf);
167
168        match self.do_poll_write(cx, &mut input)? {
169            Poll::Pending if input.written().is_empty() => Poll::Pending,
170            _ => Poll::Ready(Ok(input.written().len())),
171        }
172    }
173
174    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175        ready!(self.as_mut().do_poll_flush(cx))?;
176        ready!(self.project().writer.as_mut().poll_flush(cx))?;
177        Poll::Ready(Ok(()))
178    }
179
180    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181        ready!(self.as_mut().do_poll_shutdown(cx))?;
182        ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
183        Poll::Ready(Ok(()))
184    }
185}
186
187impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
188    fn poll_read(
189        self: Pin<&mut Self>,
190        cx: &mut Context<'_>,
191        buf: &mut ReadBuf<'_>,
192    ) -> Poll<io::Result<()>> {
193        self.get_pin_mut().poll_read(cx, buf)
194    }
195}
196
197impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
198    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
199        self.get_pin_mut().poll_fill_buf(cx)
200    }
201
202    fn consume(self: Pin<&mut Self>, amt: usize) {
203        self.get_pin_mut().consume(amt)
204    }
205}