// async_compression/tokio/bufread/generic/encoder.rs

1use core::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5use std::io::{IoSlice, Result};
6
7use crate::{codec::Encode, util::PartialBuffer};
8use futures_core::ready;
9use pin_project_lite::pin_project;
10use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
11
/// Phase of the encoding state machine advanced by `Encoder::do_poll_read`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Pulling bytes from the inner reader and feeding them to the encoder.
    Encoding,
    /// The inner reader returned an empty buffer; `finish` is called until it
    /// reports completion.
    Flushing,
    /// `finish` has completed; subsequent reads produce no further bytes.
    Done,
}

19pin_project! {
20    #[derive(Debug)]
21    pub struct Encoder<R, E> {
22        #[pin]
23        reader: R,
24        encoder: E,
25        state: State,
26    }
27}
28
29impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
30    pub fn new(reader: R, encoder: E) -> Self {
31        Self {
32            reader,
33            encoder,
34            state: State::Encoding,
35        }
36    }
37}
38
39impl<R, E> Encoder<R, E> {
40    pub fn get_ref(&self) -> &R {
41        &self.reader
42    }
43
44    pub fn get_mut(&mut self) -> &mut R {
45        &mut self.reader
46    }
47
48    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
49        self.project().reader
50    }
51
52    pub(crate) fn get_encoder_ref(&self) -> &E {
53        &self.encoder
54    }
55
56    pub fn into_inner(self) -> R {
57        self.reader
58    }
59}
60impl<R: AsyncBufRead, E: Encode> Encoder<R, E> {
61    fn do_poll_read(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        output: &mut PartialBuffer<&mut [u8]>,
65    ) -> Poll<Result<()>> {
66        let mut this = self.project();
67
68        loop {
69            *this.state = match this.state {
70                State::Encoding => {
71                    let input = ready!(this.reader.as_mut().poll_fill_buf(cx))?;
72                    if input.is_empty() {
73                        State::Flushing
74                    } else {
75                        let mut input = PartialBuffer::new(input);
76                        this.encoder.encode(&mut input, output)?;
77                        let len = input.written().len();
78                        this.reader.as_mut().consume(len);
79                        State::Encoding
80                    }
81                }
82
83                State::Flushing => {
84                    if this.encoder.finish(output)? {
85                        State::Done
86                    } else {
87                        State::Flushing
88                    }
89                }
90
91                State::Done => State::Done,
92            };
93
94            if let State::Done = *this.state {
95                return Poll::Ready(Ok(()));
96            }
97            if output.unwritten().is_empty() {
98                return Poll::Ready(Ok(()));
99            }
100        }
101    }
102}
103
104impl<R: AsyncBufRead, E: Encode> AsyncRead for Encoder<R, E> {
105    fn poll_read(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        buf: &mut ReadBuf<'_>,
109    ) -> Poll<Result<()>> {
110        if buf.remaining() == 0 {
111            return Poll::Ready(Ok(()));
112        }
113
114        let mut output = PartialBuffer::new(buf.initialize_unfilled());
115        match self.do_poll_read(cx, &mut output)? {
116            Poll::Pending if output.written().is_empty() => Poll::Pending,
117            _ => {
118                let len = output.written().len();
119                buf.advance(len);
120                Poll::Ready(Ok(()))
121            }
122        }
123    }
124}
125
126impl<R: AsyncWrite, E> AsyncWrite for Encoder<R, E> {
127    fn poll_write(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130        buf: &[u8],
131    ) -> Poll<Result<usize>> {
132        self.get_pin_mut().poll_write(cx, buf)
133    }
134
135    fn poll_write_vectored(
136        mut self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138        mut bufs: &[IoSlice<'_>],
139    ) -> Poll<Result<usize>> {
140        self.get_pin_mut().poll_write_vectored(cx, bufs)
141    }
142
143    fn is_write_vectored(&self) -> bool {
144        self.get_ref().is_write_vectored()
145    }
146
147    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
148        self.get_pin_mut().poll_flush(cx)
149    }
150
151    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
152        self.get_pin_mut().poll_shutdown(cx)
153    }
154}