tower_http/
compression_utils.rs

1//! Types used by compression and decompression middleware.
2
3use crate::{content_encoding::SupportedEncodings, BoxError};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use http::HeaderValue;
7use http_body::{Body, Frame};
8use pin_project_lite::pin_project;
9use std::{
10    io,
11    pin::Pin,
12    task::{ready, Context, Poll},
13};
14use tokio::io::AsyncRead;
15use tokio_util::io::StreamReader;
16
/// Set of content encodings that are switched on.
///
/// The `SupportedEncodings` impl for this type only reports an encoding as enabled when the
/// matching cargo feature is compiled in as well.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AcceptEncoding {
    /// Whether `gzip` is switched on.
    pub(crate) gzip: bool,
    /// Whether `deflate` is switched on.
    pub(crate) deflate: bool,
    /// Whether `br` (brotli) is switched on.
    pub(crate) br: bool,
    /// Whether `zstd` is switched on.
    pub(crate) zstd: bool,
}
24
25impl AcceptEncoding {
26    #[allow(dead_code)]
27    pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
28        let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
29            (true, true, true, false) => "gzip,deflate,br",
30            (true, true, false, false) => "gzip,deflate",
31            (true, false, true, false) => "gzip,br",
32            (true, false, false, false) => "gzip",
33            (false, true, true, false) => "deflate,br",
34            (false, true, false, false) => "deflate",
35            (false, false, true, false) => "br",
36            (true, true, true, true) => "zstd,gzip,deflate,br",
37            (true, true, false, true) => "zstd,gzip,deflate",
38            (true, false, true, true) => "zstd,gzip,br",
39            (true, false, false, true) => "zstd,gzip",
40            (false, true, true, true) => "zstd,deflate,br",
41            (false, true, false, true) => "zstd,deflate",
42            (false, false, true, true) => "zstd,br",
43            (false, false, false, true) => "zstd",
44            (false, false, false, false) => return None,
45        };
46        Some(HeaderValue::from_static(accept))
47    }
48
49    #[allow(dead_code)]
50    pub(crate) fn set_gzip(&mut self, enable: bool) {
51        self.gzip = enable;
52    }
53
54    #[allow(dead_code)]
55    pub(crate) fn set_deflate(&mut self, enable: bool) {
56        self.deflate = enable;
57    }
58
59    #[allow(dead_code)]
60    pub(crate) fn set_br(&mut self, enable: bool) {
61        self.br = enable;
62    }
63
64    #[allow(dead_code)]
65    pub(crate) fn set_zstd(&mut self, enable: bool) {
66        self.zstd = enable;
67    }
68}
69
70impl SupportedEncodings for AcceptEncoding {
71    #[allow(dead_code)]
72    fn gzip(&self) -> bool {
73        #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))]
74        return self.gzip;
75
76        #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))]
77        return false;
78    }
79
80    #[allow(dead_code)]
81    fn deflate(&self) -> bool {
82        #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))]
83        return self.deflate;
84
85        #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))]
86        return false;
87    }
88
89    #[allow(dead_code)]
90    fn br(&self) -> bool {
91        #[cfg(any(feature = "decompression-br", feature = "compression-br"))]
92        return self.br;
93
94        #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))]
95        return false;
96    }
97
98    #[allow(dead_code)]
99    fn zstd(&self) -> bool {
100        #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
101        return self.zstd;
102
103        #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
104        return false;
105    }
106}
107
108impl Default for AcceptEncoding {
109    fn default() -> Self {
110        AcceptEncoding {
111            gzip: true,
112            deflate: true,
113            br: true,
114            zstd: true,
115        }
116    }
117}
118
119/// A `Body` that has been converted into an `AsyncRead`.
120pub(crate) type AsyncReadBody<B> =
121    StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
122
123/// Trait for applying some decorator to an `AsyncRead`
124pub(crate) trait DecorateAsyncRead {
125    type Input: AsyncRead;
126    type Output: AsyncRead;
127
128    /// Apply the decorator
129    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
130
131    /// Get a pinned mutable reference to the original input.
132    ///
133    /// This is necessary to implement `Body::poll_trailers`.
134    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
135}
136
137pin_project! {
138    /// `Body` that has been decorated by an `AsyncRead`
139    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
140        #[pin]
141        // rust-analyer thinks this field is private if its `pub(crate)` but works fine when its
142        // `pub`
143        pub read: M::Output,
144        // A buffer to temporarily store the data read from the underlying body.
145        // Reused as much as possible to optimize allocations.
146        buf: BytesMut,
147        read_all_data: bool,
148    }
149}
150
151impl<M: DecorateAsyncRead> WrapBody<M> {
152    const INTERNAL_BUF_CAPACITY: usize = 4096;
153}
154
155impl<M: DecorateAsyncRead> WrapBody<M> {
156    #[allow(dead_code)]
157    pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
158    where
159        B: Body,
160        M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
161    {
162        // convert `Body` into a `Stream`
163        let stream = BodyIntoStream::new(body);
164
165        // an adapter that converts the error type into `io::Error` while storing the actual error
166        // `StreamReader` requires the error type is `io::Error`
167        let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
168
169        // convert `Stream` into an `AsyncRead`
170        let read = StreamReader::new(stream);
171
172        // apply decorator to `AsyncRead` yielding another `AsyncRead`
173        let read = M::apply(read, quality);
174
175        Self {
176            read,
177            buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
178            read_all_data: false,
179        }
180    }
181}
182
183impl<B, M> Body for WrapBody<M>
184where
185    B: Body,
186    B::Error: Into<BoxError>,
187    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
188{
189    type Data = Bytes;
190    type Error = BoxError;
191
192    fn poll_frame(
193        self: Pin<&mut Self>,
194        cx: &mut Context<'_>,
195    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
196        let mut this = self.project();
197
198        if !*this.read_all_data {
199            if this.buf.capacity() == 0 {
200                this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
201            }
202
203            let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);
204
205            match ready!(result) {
206                Ok(0) => {
207                    *this.read_all_data = true;
208                }
209                Ok(_) => {
210                    let chunk = this.buf.split().freeze();
211                    return Poll::Ready(Some(Ok(Frame::data(chunk))));
212                }
213                Err(err) => {
214                    let body_error: Option<B::Error> = M::get_pin_mut(this.read)
215                        .get_pin_mut()
216                        .project()
217                        .error
218                        .take();
219
220                    if let Some(body_error) = body_error {
221                        return Poll::Ready(Some(Err(body_error.into())));
222                    } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
223                        // SENTINEL_ERROR_CODE only gets used when storing
224                        // an underlying body error
225                        unreachable!()
226                    } else {
227                        return Poll::Ready(Some(Err(err.into())));
228                    }
229                }
230            }
231        }
232
233        // poll any remaining frames, such as trailers
234        let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
235        body.poll_frame(cx).map(|option| {
236            option.map(|result| {
237                result
238                    .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
239                    .map_err(|err| err.into())
240            })
241        })
242    }
243}
244
245pin_project! {
246    pub(crate) struct BodyIntoStream<B>
247    where
248        B: Body,
249    {
250        #[pin]
251        body: B,
252        yielded_all_data: bool,
253        non_data_frame: Option<Frame<B::Data>>,
254    }
255}
256
257#[allow(dead_code)]
258impl<B> BodyIntoStream<B>
259where
260    B: Body,
261{
262    pub(crate) fn new(body: B) -> Self {
263        Self {
264            body,
265            yielded_all_data: false,
266            non_data_frame: None,
267        }
268    }
269
270    /// Get a reference to the inner body
271    pub(crate) fn get_ref(&self) -> &B {
272        &self.body
273    }
274
275    /// Get a mutable reference to the inner body
276    pub(crate) fn get_mut(&mut self) -> &mut B {
277        &mut self.body
278    }
279
280    /// Get a pinned mutable reference to the inner body
281    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
282        self.project().body
283    }
284
285    /// Consume `self`, returning the inner body
286    pub(crate) fn into_inner(self) -> B {
287        self.body
288    }
289}
290
291impl<B> Stream for BodyIntoStream<B>
292where
293    B: Body,
294{
295    type Item = Result<B::Data, B::Error>;
296
297    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
298        loop {
299            let this = self.as_mut().project();
300
301            if *this.yielded_all_data {
302                return Poll::Ready(None);
303            }
304
305            match std::task::ready!(this.body.poll_frame(cx)) {
306                Some(Ok(frame)) => match frame.into_data() {
307                    Ok(data) => return Poll::Ready(Some(Ok(data))),
308                    Err(frame) => {
309                        *this.yielded_all_data = true;
310                        *this.non_data_frame = Some(frame);
311                    }
312                },
313                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
314                None => {
315                    *this.yielded_all_data = true;
316                }
317            }
318        }
319    }
320}
321
322impl<B> Body for BodyIntoStream<B>
323where
324    B: Body,
325{
326    type Data = B::Data;
327    type Error = B::Error;
328
329    fn poll_frame(
330        mut self: Pin<&mut Self>,
331        cx: &mut Context<'_>,
332    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
333        // First drive the stream impl. This consumes all data frames and buffer at most one
334        // trailers frame.
335        if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
336            return Poll::Ready(Some(frame.map(Frame::data)));
337        }
338
339        let this = self.project();
340
341        // Yield the trailers frame `poll_next` hit.
342        if let Some(frame) = this.non_data_frame.take() {
343            return Poll::Ready(Some(Ok(frame)));
344        }
345
346        // Yield any remaining frames in the body. There shouldn't be any after the trailers but
347        // you never know.
348        this.body.poll_frame(cx)
349    }
350
351    #[inline]
352    fn size_hint(&self) -> http_body::SizeHint {
353        self.body.size_hint()
354    }
355}
356
357pin_project! {
358    pub(crate) struct StreamErrorIntoIoError<S, E> {
359        #[pin]
360        inner: S,
361        error: Option<E>,
362    }
363}
364
365impl<S, E> StreamErrorIntoIoError<S, E> {
366    pub(crate) fn new(inner: S) -> Self {
367        Self { inner, error: None }
368    }
369
370    /// Get a reference to the inner body
371    pub(crate) fn get_ref(&self) -> &S {
372        &self.inner
373    }
374
375    /// Get a mutable reference to the inner inner
376    pub(crate) fn get_mut(&mut self) -> &mut S {
377        &mut self.inner
378    }
379
380    /// Get a pinned mutable reference to the inner inner
381    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
382        self.project().inner
383    }
384
385    /// Consume `self`, returning the inner inner
386    pub(crate) fn into_inner(self) -> S {
387        self.inner
388    }
389}
390
391impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
392where
393    S: Stream<Item = Result<T, E>>,
394{
395    type Item = Result<T, io::Error>;
396
397    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398        let this = self.project();
399        match ready!(this.inner.poll_next(cx)) {
400            None => Poll::Ready(None),
401            Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
402            Some(Err(err)) => {
403                *this.error = Some(err);
404                Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
405            }
406        }
407    }
408}
409
/// Raw OS error code of the placeholder `io::Error` that `StreamErrorIntoIoError` emits in place
/// of an inner stream error. The real error is stored separately, next to the stream.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837_459_418;
411
/// Compression level to encode data with.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
pub enum CompressionLevel {
    /// Favor speed over output size; the output is usually larger.
    Fastest,
    /// Favor output size over speed; the output is usually the smallest.
    Best,
    /// The default level defined by the selected compression algorithm.
    ///
    /// This is also what `CompressionLevel::default()` returns.
    #[default]
    Default,
    /// An exact quality value on the underlying algorithm's own scale.
    ///
    /// How the value is interpreted depends on the chosen algorithm and the implementation
    /// backing it.
    ///
    /// Qualities are implicitly clamped to the algorithm's maximum.
    Precise(i32),
}
433
434#[cfg(any(
435    feature = "compression-br",
436    feature = "compression-gzip",
437    feature = "compression-deflate",
438    feature = "compression-zstd"
439))]
440use async_compression::Level as AsyncCompressionLevel;
441
#[cfg(any(
    feature = "compression-br",
    feature = "compression-gzip",
    feature = "compression-deflate",
    feature = "compression-zstd"
))]
impl CompressionLevel {
    /// Maps this level onto the matching `async_compression` level, variant for variant.
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Fastest => AsyncCompressionLevel::Fastest,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Default => AsyncCompressionLevel::Default,
            Self::Precise(q) => AsyncCompressionLevel::Precise(q),
        }
    }
}
457}