// async_compression/tokio/bufread/generic/decoder.rs

1use core::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5use std::io::{IoSlice, Result};
6
7use crate::{codec::Decode, util::PartialBuffer};
8use futures_core::ready;
9use pin_project_lite::pin_project;
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Phases of the decoding state machine driven by `Decoder::do_poll_read`.
#[derive(Debug)]
enum State {
    /// Feeding bytes pulled from the reader into the decoder.
    Decoding,
    /// Input for the current member is over; draining the decoder via `finish`.
    Flushing,
    /// A member finished and the decoder was reinitialised; checking whether the
    /// reader still has bytes for another member.
    Next,
    /// Decoding is complete; further reads produce no bytes.
    Done,
}
19
pin_project! {
    /// Adapter that decodes the bytes of an [`AsyncBufRead`] `reader` with a [`Decode`]
    /// implementation and exposes the decoded bytes through [`AsyncRead`].
    #[derive(Debug)]
    pub struct Decoder<R, D> {
        // Source of encoded bytes; pinned because it is polled through `Pin<&mut R>`.
        #[pin]
        reader: R,
        // Codec that turns encoded input into decoded output.
        decoder: D,
        // Current phase of the decoding state machine driven by `do_poll_read`.
        state: State,
        // When true, after a member finishes the decoder is reinitialised and decoding
        // continues if the reader has more bytes. Cleared when the reader reports EOF
        // while in the decoding phase.
        multiple_members: bool,
    }
}
30
31impl<R: AsyncBufRead, D: Decode> Decoder<R, D> {
32    pub fn new(reader: R, decoder: D) -> Self {
33        Self {
34            reader,
35            decoder,
36            state: State::Decoding,
37            multiple_members: false,
38        }
39    }
40}
41
42impl<R, D> Decoder<R, D> {
43    pub fn get_ref(&self) -> &R {
44        &self.reader
45    }
46
47    pub fn get_mut(&mut self) -> &mut R {
48        &mut self.reader
49    }
50
51    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
52        self.project().reader
53    }
54
55    pub fn into_inner(self) -> R {
56        self.reader
57    }
58
59    pub fn multiple_members(&mut self, enabled: bool) {
60        self.multiple_members = enabled;
61    }
62}
63
64impl<R: AsyncBufRead, D: Decode> Decoder<R, D> {
65    fn do_poll_read(
66        self: Pin<&mut Self>,
67        cx: &mut Context<'_>,
68        output: &mut PartialBuffer<&mut [u8]>,
69    ) -> Poll<Result<()>> {
70        let mut this = self.project();
71
72        let mut first = true;
73
74        loop {
75            *this.state = match this.state {
76                State::Decoding => {
77                    let input = if first {
78                        &[][..]
79                    } else {
80                        ready!(this.reader.as_mut().poll_fill_buf(cx))?
81                    };
82
83                    if input.is_empty() && !first {
84                        // Avoid attempting to reinitialise the decoder if the reader
85                        // has returned EOF.
86                        *this.multiple_members = false;
87
88                        State::Flushing
89                    } else {
90                        let mut input = PartialBuffer::new(input);
91                        let res = this.decoder.decode(&mut input, output).or_else(|err| {
92                            // ignore the first error, occurs when input is empty
93                            // but we need to run decode to flush
94                            if first {
95                                Ok(false)
96                            } else {
97                                Err(err)
98                            }
99                        });
100
101                        if !first {
102                            let len = input.written().len();
103                            this.reader.as_mut().consume(len);
104                        }
105
106                        first = false;
107
108                        if res? {
109                            State::Flushing
110                        } else {
111                            State::Decoding
112                        }
113                    }
114                }
115
116                State::Flushing => {
117                    if this.decoder.finish(output)? {
118                        if *this.multiple_members {
119                            this.decoder.reinit()?;
120                            State::Next
121                        } else {
122                            State::Done
123                        }
124                    } else {
125                        State::Flushing
126                    }
127                }
128
129                State::Done => State::Done,
130
131                State::Next => {
132                    let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
133                    if input.is_empty() {
134                        State::Done
135                    } else {
136                        State::Decoding
137                    }
138                }
139            };
140
141            if let State::Done = *this.state {
142                return Poll::Ready(Ok(()));
143            }
144            if output.unwritten().is_empty() {
145                return Poll::Ready(Ok(()));
146            }
147        }
148    }
149}
150
151impl<R: AsyncBufRead, D: Decode> AsyncRead for Decoder<R, D> {
152    fn poll_read(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut ReadBuf<'_>,
156    ) -> Poll<Result<()>> {
157        if buf.remaining() == 0 {
158            return Poll::Ready(Ok(()));
159        }
160
161        let mut output = PartialBuffer::new(buf.initialize_unfilled());
162        match self.do_poll_read(cx, &mut output)? {
163            Poll::Pending if output.written().is_empty() => Poll::Pending,
164            _ => {
165                let len = output.written().len();
166                buf.advance(len);
167                Poll::Ready(Ok(()))
168            }
169        }
170    }
171}
172
173impl<R: AsyncWrite, D: Decode> AsyncWrite for Decoder<R, D> {
174    fn poll_write(
175        mut self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: &[u8],
178    ) -> Poll<Result<usize>> {
179        self.get_pin_mut().poll_write(cx, buf)
180    }
181
182    fn poll_write_vectored(
183        mut self: Pin<&mut Self>,
184        cx: &mut Context<'_>,
185        mut bufs: &[IoSlice<'_>],
186    ) -> Poll<Result<usize>> {
187        self.get_pin_mut().poll_write_vectored(cx, bufs)
188    }
189
190    fn is_write_vectored(&self) -> bool {
191        self.get_ref().is_write_vectored()
192    }
193
194    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
195        self.get_pin_mut().poll_flush(cx)
196    }
197
198    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
199        self.get_pin_mut().poll_shutdown(cx)
200    }
201}