async_compression/tokio/bufread/generic/
decoder.rs1use core::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5use std::io::{IoSlice, Result};
6
7use crate::{codec::Decode, util::PartialBuffer};
8use futures_core::ready;
9use pin_project_lite::pin_project;
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Where a [`Decoder`] currently is in its read loop.
#[derive(Debug)]
enum State {
    /// Pulling bytes from the inner reader and passing them to `Decode::decode`.
    Decoding,
    /// Draining remaining output through `Decode::finish` until it reports completion.
    Flushing,
    /// Terminal state: reads return immediately without producing more output.
    Done,
    /// A member finished with multi-member mode on; peeking the inner reader to
    /// decide between resuming decoding and stopping.
    Next,
}
19
pin_project! {
    /// Reads compressed bytes from an inner reader and yields them decoded
    /// through the codec `D` (see the [`AsyncRead`] impl, which requires
    /// `R: AsyncBufRead` and `D: Decode`).
    #[derive(Debug)]
    pub struct Decoder<R, D> {
        // Inner reader supplying compressed input; structurally pinned so it can
        // be polled through `Pin<&mut Self>`.
        #[pin]
        reader: R,
        // Codec that turns the compressed input into output bytes.
        decoder: D,
        // Current position in the read state machine driven by `do_poll_read`.
        state: State,
        // When true, a finished member causes the codec to be reinitialised and
        // decoding to resume if the inner reader has more data; it is cleared
        // once the inner reader reports EOF while decoding.
        multiple_members: bool,
    }
}
30
31impl<R: AsyncBufRead, D: Decode> Decoder<R, D> {
32 pub fn new(reader: R, decoder: D) -> Self {
33 Self {
34 reader,
35 decoder,
36 state: State::Decoding,
37 multiple_members: false,
38 }
39 }
40}
41
42impl<R, D> Decoder<R, D> {
43 pub fn get_ref(&self) -> &R {
44 &self.reader
45 }
46
47 pub fn get_mut(&mut self) -> &mut R {
48 &mut self.reader
49 }
50
51 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
52 self.project().reader
53 }
54
55 pub fn into_inner(self) -> R {
56 self.reader
57 }
58
59 pub fn multiple_members(&mut self, enabled: bool) {
60 self.multiple_members = enabled;
61 }
62}
63
64impl<R: AsyncBufRead, D: Decode> Decoder<R, D> {
65 fn do_poll_read(
66 self: Pin<&mut Self>,
67 cx: &mut Context<'_>,
68 output: &mut PartialBuffer<&mut [u8]>,
69 ) -> Poll<Result<()>> {
70 let mut this = self.project();
71
72 let mut first = true;
73
74 loop {
75 *this.state = match this.state {
76 State::Decoding => {
77 let input = if first {
78 &[][..]
79 } else {
80 ready!(this.reader.as_mut().poll_fill_buf(cx))?
81 };
82
83 if input.is_empty() && !first {
84 *this.multiple_members = false;
87
88 State::Flushing
89 } else {
90 let mut input = PartialBuffer::new(input);
91 let res = this.decoder.decode(&mut input, output).or_else(|err| {
92 if first {
95 Ok(false)
96 } else {
97 Err(err)
98 }
99 });
100
101 if !first {
102 let len = input.written().len();
103 this.reader.as_mut().consume(len);
104 }
105
106 first = false;
107
108 if res? {
109 State::Flushing
110 } else {
111 State::Decoding
112 }
113 }
114 }
115
116 State::Flushing => {
117 if this.decoder.finish(output)? {
118 if *this.multiple_members {
119 this.decoder.reinit()?;
120 State::Next
121 } else {
122 State::Done
123 }
124 } else {
125 State::Flushing
126 }
127 }
128
129 State::Done => State::Done,
130
131 State::Next => {
132 let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
133 if input.is_empty() {
134 State::Done
135 } else {
136 State::Decoding
137 }
138 }
139 };
140
141 if let State::Done = *this.state {
142 return Poll::Ready(Ok(()));
143 }
144 if output.unwritten().is_empty() {
145 return Poll::Ready(Ok(()));
146 }
147 }
148 }
149}
150
151impl<R: AsyncBufRead, D: Decode> AsyncRead for Decoder<R, D> {
152 fn poll_read(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut ReadBuf<'_>,
156 ) -> Poll<Result<()>> {
157 if buf.remaining() == 0 {
158 return Poll::Ready(Ok(()));
159 }
160
161 let mut output = PartialBuffer::new(buf.initialize_unfilled());
162 match self.do_poll_read(cx, &mut output)? {
163 Poll::Pending if output.written().is_empty() => Poll::Pending,
164 _ => {
165 let len = output.written().len();
166 buf.advance(len);
167 Poll::Ready(Ok(()))
168 }
169 }
170 }
171}
172
173impl<R: AsyncWrite, D: Decode> AsyncWrite for Decoder<R, D> {
174 fn poll_write(
175 mut self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<Result<usize>> {
179 self.get_pin_mut().poll_write(cx, buf)
180 }
181
182 fn poll_write_vectored(
183 mut self: Pin<&mut Self>,
184 cx: &mut Context<'_>,
185 mut bufs: &[IoSlice<'_>],
186 ) -> Poll<Result<usize>> {
187 self.get_pin_mut().poll_write_vectored(cx, bufs)
188 }
189
190 fn is_write_vectored(&self) -> bool {
191 self.get_ref().is_write_vectored()
192 }
193
194 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
195 self.get_pin_mut().poll_flush(cx)
196 }
197
198 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
199 self.get_pin_mut().poll_shutdown(cx)
200 }
201}